diff --git a/.github/codecov.yml b/.github/codecov.yml deleted file mode 100644 index 5f721427d7a..00000000000 --- a/.github/codecov.yml +++ /dev/null @@ -1,10 +0,0 @@ -# we measure coverage but don't enforce it -# https://docs.codecov.com/docs/codecov-yaml -coverage: - status: - patch: - default: - target: 0% - project: - default: - target: 0% diff --git a/.github/generate-codecov-yml.sh b/.github/generate-codecov-yml.sh new file mode 100755 index 00000000000..ddb60d0ce80 --- /dev/null +++ b/.github/generate-codecov-yml.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Run this from the repository root: +# +# .github/generate-codecov-yml.sh >> .github/codecov.yml + +cat <> $GITHUB_ENV - name: "Create Parsers badge" - uses: schneegans/dynamic-badges-action@v1.6.0 + uses: schneegans/dynamic-badges-action@v1.7.0 if: ${{ github.ref == 'refs/heads/master' && github.repository_owner == 'crowdsecurity' }} with: auth: ${{ secrets.GIST_BADGES_SECRET }} @@ -66,7 +64,7 @@ jobs: color: ${{ env.SCENARIO_BADGE_COLOR }} - name: "Create Scenarios badge" - uses: schneegans/dynamic-badges-action@v1.6.0 + uses: schneegans/dynamic-badges-action@v1.7.0 if: ${{ github.ref == 'refs/heads/master' && github.repository_owner == 'crowdsecurity' }} with: auth: ${{ secrets.GIST_BADGES_SECRET }} diff --git a/.github/workflows/bats-mysql.yml b/.github/workflows/bats-mysql.yml index 897122f632b..211d856bc34 100644 --- a/.github/workflows/bats-mysql.yml +++ b/.github/workflows/bats-mysql.yml @@ -1,4 +1,4 @@ -name: Functional tests (MySQL) +name: (sub) Bats / MySQL on: workflow_call: @@ -7,16 +7,9 @@ on: required: true type: string -env: - PREFIX_TEST_NAMES_WITH_FILE: true - jobs: build: - strategy: - matrix: - go-version: ["1.20.6"] - - name: "Build + tests" + name: "Functional tests" runs-on: ubuntu-latest timeout-minutes: 30 services: @@ -35,22 +28,21 @@ jobs: echo githubciXXXXXXXXXXXXXXXXXXXXXXXX | sudo tee /etc/machine-id - name: "Check out CrowdSec repository" - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 submodules: true - - name: "Set up Go ${{ matrix.go-version }}" - uses: actions/setup-go@v4 + - name: "Set up Go" + uses: actions/setup-go@v5 with: - go-version: ${{ matrix.go-version }} - cache-dependency-path: "**/go.sum" + go-version: "1.22" - name: "Install bats dependencies" env: GOBIN: /usr/local/bin run: | - sudo apt -qq -y -o=Dpkg::Use-Pty=0 install build-essential daemonize jq netcat-openbsd libre2-dev + sudo apt -qq -y -o=Dpkg::Use-Pty=0 install build-essential daemonize jq libre2-dev - name: "Build crowdsec and fixture" run: | @@ -63,7 +55,7 @@ jobs: MYSQL_USER: root - name: "Run tests" - run: make bats-test + run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter env: DB_BACKEND: mysql MYSQL_HOST: 127.0.0.1 diff --git a/.github/workflows/bats-postgres.yml b/.github/workflows/bats-postgres.yml index c2aefef0458..aec707f0c03 100644 --- a/.github/workflows/bats-postgres.yml +++ b/.github/workflows/bats-postgres.yml @@ -1,23 +1,16 @@ -name: Functional tests (Postgres) +name: (sub) Bats / Postgres on: workflow_call: -env: - PREFIX_TEST_NAMES_WITH_FILE: true - jobs: build: - strategy: - matrix: - go-version: ["1.20.6"] - - name: "Build + tests" + name: "Functional tests" runs-on: ubuntu-latest timeout-minutes: 30 services: database: - image: postgres:15 + image: postgres:16 env: POSTGRES_PASSWORD: "secret" ports: @@ -30,13 +23,13 @@ jobs: steps: - - name: "Install pg_dump v15" + - name: "Install pg_dump v16" # we can remove this when it's released on ubuntu-latest 
run: | sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo tee /etc/apt/trusted.gpg.d/pgdg.asc &>/dev/null sudo apt update - sudo apt -qq -y -o=Dpkg::Use-Pty=0 install postgresql-client-15 + sudo apt -qq -y -o=Dpkg::Use-Pty=0 install postgresql-client-16 - name: "Force machineid" run: | @@ -44,22 +37,21 @@ jobs: echo githubciXXXXXXXXXXXXXXXXXXXXXXXX | sudo tee /etc/machine-id - name: "Check out CrowdSec repository" - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 submodules: true - - name: "Set up Go ${{ matrix.go-version }}" - uses: actions/setup-go@v4 + - name: "Set up Go" + uses: actions/setup-go@v5 with: - go-version: ${{ matrix.go-version }} - cache-dependency-path: "**/go.sum" + go-version: "1.22" - name: "Install bats dependencies" env: GOBIN: /usr/local/bin run: | - sudo apt -qq -y -o=Dpkg::Use-Pty=0 install build-essential daemonize jq netcat-openbsd libre2-dev + sudo apt -qq -y -o=Dpkg::Use-Pty=0 install build-essential daemonize jq libre2-dev - name: "Build crowdsec and fixture (DB_BACKEND: pgx)" run: | @@ -72,7 +64,7 @@ jobs: PGUSER: postgres - name: "Run tests (DB_BACKEND: pgx)" - run: make bats-test + run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter env: DB_BACKEND: pgx PGHOST: 127.0.0.1 diff --git a/.github/workflows/bats-sqlite-coverage.yml b/.github/workflows/bats-sqlite-coverage.yml index 93f85b72e6e..a089aa53532 100644 --- a/.github/workflows/bats-sqlite-coverage.yml +++ b/.github/workflows/bats-sqlite-coverage.yml @@ -1,19 +1,17 @@ -name: Functional tests (sqlite) +name: (sub) Bats / sqlite + coverage on: workflow_call: + secrets: + CODECOV_TOKEN: + required: true env: - PREFIX_TEST_NAMES_WITH_FILE: true TEST_COVERAGE: true jobs: build: - strategy: - matrix: - go-version: ["1.20.6"] - - name: "Build + tests" + name: "Functional tests" runs-on: ubuntu-latest timeout-minutes: 20 @@ -25,29 +23,32 @@ jobs: echo githubciXXXXXXXXXXXXXXXXXXXXXXXX | sudo tee /etc/machine-id - name: "Check out CrowdSec repository" - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 submodules: true - - name: "Set up Go ${{ matrix.go-version }}" - uses: actions/setup-go@v4 + - name: "Set up Go" + uses: actions/setup-go@v5 with: - go-version: ${{ matrix.go-version }} - cache-dependency-path: "**/go.sum" + go-version: "1.22" - name: "Install bats dependencies" env: GOBIN: /usr/local/bin run: | - sudo apt -qq -y -o=Dpkg::Use-Pty=0 install build-essential daemonize jq netcat-openbsd libre2-dev + sudo apt -qq -y -o=Dpkg::Use-Pty=0 install build-essential daemonize jq libre2-dev - name: "Build crowdsec and fixture" run: | make clean bats-build bats-fixture BUILD_STATIC=1 + - name: Generate codecov configuration + run: | + .github/generate-codecov-yml.sh >> .github/codecov.yml + - name: "Run tests" - run: make bats-test + run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter - name: "Collect coverage data" run: | @@ -82,8 +83,9 @@ jobs: run: for file in $(find ./test/local/var/log -type f); do echo ">>>>> $file"; cat $file; echo; done if: ${{ always() }} - - name: Upload crowdsec coverage to codecov - uses: codecov/codecov-action@v3 + - name: Upload bats coverage to codecov + uses: codecov/codecov-action@v4 with: files: ./coverage-bats.out flags: bats + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/bats.yml b/.github/workflows/bats.yml 
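All three functional-test workflows above now invoke the bats runner directly instead of going through `make bats-test`, and the sqlite job generates `.github/codecov.yml` on the fly before uploading coverage with an explicit token. A rough local equivalent of the sqlite job's build-and-test phase, assuming a full checkout with submodules and the bats dependencies installed, might be:

    # same switch the sqlite workflow sets at the job level
    export TEST_COVERAGE=true
    # build crowdsec plus the test fixture, then run the suite with the color formatter
    make clean bats-build bats-fixture BUILD_STATIC=1
    ./test/run-tests ./test/bats --formatter "$(pwd)/test/lib/color-formatter"
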
index 46b1414ef19..59976bad87d 100644 --- a/.github/workflows/bats.yml +++ b/.github/workflows/bats.yml @@ -28,10 +28,12 @@ on: jobs: sqlite: uses: ./.github/workflows/bats-sqlite-coverage.yml + secrets: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} # Jobs for Postgres (and sometimes MySQL) can have failing tests on GitHub # CI, but they pass when run on devs' machines or in the release checks. We - # disable them here by default. Remove the if..false to enable them. + # disable them here by default. Remove if...false to enable them. mariadb: uses: ./.github/workflows/bats-mysql.yml diff --git a/.github/workflows/cache-cleanup.yaml b/.github/workflows/cache-cleanup.yaml index d193650246b..4f320cf2442 100644 --- a/.github/workflows/cache-cleanup.yaml +++ b/.github/workflows/cache-cleanup.yaml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Cleanup run: | diff --git a/.github/workflows/ci-windows-build-msi.yml b/.github/workflows/ci-windows-build-msi.yml index 47f5b905d6a..a37aa43e2d0 100644 --- a/.github/workflows/ci-windows-build-msi.yml +++ b/.github/workflows/ci-windows-build-msi.yml @@ -21,31 +21,26 @@ on: jobs: build: - strategy: - matrix: - go-version: ["1.20.6"] - name: Build runs-on: windows-2019 steps: - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - - name: "Set up Go ${{ matrix.go-version }}" - uses: actions/setup-go@v4 + - name: "Set up Go" + uses: actions/setup-go@v5 with: - go-version: ${{ matrix.go-version }} - cache-dependency-path: "**/go.sum" + go-version: "1.22" - name: Build run: make windows_installer BUILD_RE2_WASM=1 - name: Upload MSI - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: path: crowdsec*msi name: crowdsec.msi diff --git a/.github/workflows/ci_release-drafter.yml b/.github/workflows/ci_release-drafter.yml index 2ccb6977cfd..0b8c9b386e6 100644 --- a/.github/workflows/ci_release-drafter.yml +++ b/.github/workflows/ci_release-drafter.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: # Drafts your next Release notes as Pull Requests are merged into "master" - - uses: release-drafter/release-drafter@v5 + - uses: release-drafter/release-drafter@v6 with: config-name: release-drafter.yml # (Optional) specify config name to use, relative to .github/. Default: release-drafter.yml diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index c1995cd8d89..2715c6590c3 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -44,11 +44,20 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 + with: + # required to pick up tags for BUILD_VERSION + fetch-depth: 0 + + - name: "Set up Go" + uses: actions/setup-go@v5 + with: + go-version: "1.22" + cache-dependency-path: "**/go.sum" # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -58,8 +67,8 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 
# If this step fails, then you should remove it and run the build manually (see below) - - name: Autobuild - uses: github/codeql-action/autobuild@v2 + # - name: Autobuild + # uses: github/codeql-action/autobuild@v3 # ℹī¸ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -68,9 +77,8 @@ jobs: # and modify them (or add more) to build your code if your project # uses a compiled language - #- run: | - # make bootstrap - # make release + - run: | + make clean build BUILD_RE2_WASM=1 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/docker-tests.yml b/.github/workflows/docker-tests.yml index 913c4766238..918f3bcaf1d 100644 --- a/.github/workflows/docker-tests.yml +++ b/.github/workflows/docker-tests.yml @@ -15,78 +15,50 @@ on: - 'README.md' jobs: - test_docker_image: + test_flavor: + strategy: + # we could test all the flavors in a single pytest job, + # but let's split them (and the image build) in multiple runners for performance + matrix: + # can be slim, full or debian (no debian slim). + flavor: ["slim", "debian"] + runs-on: ubuntu-latest timeout-minutes: 30 steps: - name: Check out the repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - with: - config: .github/buildkit.toml - - - name: "Build flavor: slim" - uses: docker/build-push-action@v4 + uses: docker/setup-buildx-action@v3 with: - context: . - file: ./Dockerfile - tags: crowdsecurity/crowdsec:test-slim - target: slim - platforms: linux/amd64 - load: true - cache-from: type=gha - cache-to: type=gha,mode=min - - - name: "Build flavor: full" - uses: docker/build-push-action@v4 - with: - context: . - file: ./Dockerfile - tags: crowdsecurity/crowdsec:test - target: full - platforms: linux/amd64 - load: true - cache-from: type=gha - cache-to: type=gha,mode=min + buildkitd-config: .github/buildkit.toml - - name: "Build flavor: full (debian)" - uses: docker/build-push-action@v4 + - name: "Build image" + uses: docker/build-push-action@v6 with: context: . 
- file: ./Dockerfile.debian - tags: crowdsecurity/crowdsec:test-debian - target: full + file: ./Dockerfile${{ matrix.flavor == 'debian' && '.debian' || '' }} + tags: crowdsecurity/crowdsec:test${{ matrix.flavor == 'full' && '' || '-' }}${{ matrix.flavor == 'full' && '' || matrix.flavor }} + target: ${{ matrix.flavor == 'debian' && 'full' || matrix.flavor }} platforms: linux/amd64 load: true cache-from: type=gha cache-to: type=gha,mode=min - name: "Setup Python" - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.x" - - - name: "Install pipenv" - run: | - cd docker/test - python -m pip install --upgrade pipenv wheel - - - name: "Cache virtualenvs" - id: cache-pipenv - uses: actions/cache@v3 - with: - path: ~/.local/share/virtualenvs - key: ${{ runner.os }}-pipenv-${{ hashFiles('**/Pipfile.lock') }} + cache: 'pipenv' - name: "Install dependencies" - if: steps.cache-pipenv.outputs.cache-hit != 'true' run: | cd docker/test + python -m pip install --upgrade pipenv wheel pipenv install --deploy - name: "Create Docker network" @@ -95,9 +67,10 @@ jobs: - name: "Run tests" env: CROWDSEC_TEST_VERSION: test - CROWDSEC_TEST_FLAVORS: slim,debian + CROWDSEC_TEST_FLAVORS: ${{ matrix.flavor }} CROWDSEC_TEST_NETWORK: net-test CROWDSEC_TEST_TIMEOUT: 90 + # running serially to reduce test flakiness run: | cd docker/test - pipenv run pytest -n 2 --durations=0 --color=yes + pipenv run pytest -n 1 --durations=0 --color=yes diff --git a/.github/workflows/go-tests-windows.yml b/.github/workflows/go-tests-windows.yml index 772c574abef..ba283f3890a 100644 --- a/.github/workflows/go-tests-windows.yml +++ b/.github/workflows/go-tests-windows.yml @@ -20,51 +20,47 @@ env: jobs: build: - strategy: - matrix: - go-version: ["1.20.6"] - name: "Build + tests" runs-on: windows-2022 steps: - name: Check out CrowdSec repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - - name: "Set up Go ${{ matrix.go-version }}" - uses: actions/setup-go@v4 + - name: "Set up Go" + uses: actions/setup-go@v5 with: - go-version: ${{ matrix.go-version }} - cache-dependency-path: "**/go.sum" + go-version: "1.22" - name: Build run: | make build BUILD_RE2_WASM=1 + - name: Generate codecov configuration + run: | + .github/generate-codecov-yml.sh >> .github/codecov.yml + - name: Run tests run: | go install github.com/kyoh86/richgo@v0.3.10 - go test -coverprofile coverage.out -covermode=atomic ./... > out.txt + go test -tags expr_debug -coverprofile coverage.out -covermode=atomic ./... > out.txt if(!$?) 
{ cat out.txt | sed 's/ *coverage:.*of statements in.*//' | richgo testfilter; Exit 1 } cat out.txt | sed 's/ *coverage:.*of statements in.*//' | richgo testfilter - name: Upload unit coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: files: coverage.out flags: unit-windows + token: ${{ secrets.CODECOV_TOKEN }} - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: - version: v1.51 + version: v1.61 args: --issues-exit-code=1 --timeout 10m only-new-issues: false - # the cache is already managed above, enabling it here - # gives errors when extracting - skip-pkg-cache: true - skip-build-cache: true diff --git a/.github/workflows/go-tests.yml b/.github/workflows/go-tests.yml index 2dff8af2169..3fdfb8a3e82 100644 --- a/.github/workflows/go-tests.yml +++ b/.github/workflows/go-tests.yml @@ -24,23 +24,18 @@ env: RICHGO_FORCE_COLOR: 1 AWS_HOST: localstack # these are to mimic aws config - AWS_ACCESS_KEY_ID: AKIAIOSFODNN7EXAMPLE - AWS_SECRET_ACCESS_KEY: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY + AWS_ACCESS_KEY_ID: test + AWS_SECRET_ACCESS_KEY: test AWS_REGION: us-east-1 - KINESIS_INITIALIZE_STREAMS: "stream-1-shard:1,stream-2-shards:2" CROWDSEC_FEATURE_DISABLE_HTTP_RETRY_BACKOFF: true jobs: build: - strategy: - matrix: - go-version: ["1.20.6"] - name: "Build + tests" runs-on: ubuntu-latest services: localstack: - image: localstack/localstack:1.3.0 + image: localstack/localstack:3.0 ports: - 4566:4566 # Localstack exposes all services on the same port env: @@ -49,7 +44,7 @@ jobs: KINESIS_ERROR_PROBABILITY: "" DOCKER_HOST: unix:///var/run/docker.sock KINESIS_INITIALIZE_STREAMS: ${{ env.KINESIS_INITIALIZE_STREAMS }} - HOSTNAME_EXTERNAL: ${{ env.AWS_HOST }} # Required so that resource urls are provided properly + LOCALSTACK_HOST: ${{ env.AWS_HOST }} # Required so that resource urls are provided properly # e.g sqs url will get localhost if we don't set this env to map our service options: >- --name=localstack @@ -58,7 +53,7 @@ jobs: --health-timeout=5s --health-retries=3 zoo1: - image: confluentinc/cp-zookeeper:7.3.0 + image: confluentinc/cp-zookeeper:7.4.3 ports: - "2181:2181" env: @@ -108,19 +103,62 @@ jobs: --health-timeout 10s --health-retries 5 + loki: + image: grafana/loki:2.9.1 + ports: + - "3100:3100" + options: >- + --name=loki1 + --health-cmd "wget -q -O - http://localhost:3100/ready | grep 'ready'" + --health-interval 30s + --health-timeout 10s + --health-retries 5 + --health-start-period 30s + steps: - name: Check out CrowdSec repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - - name: "Set up Go ${{ matrix.go-version }}" - uses: actions/setup-go@v4 + - name: "Set up Go" + uses: actions/setup-go@v5 with: - go-version: ${{ matrix.go-version }} - cache-dependency-path: "**/go.sum" + go-version: "1.22" + + - name: Run "go generate" and check for changes + run: | + set -e + # ensure the version of 'protoc' matches the one that generated the files + PROTOBUF_VERSION="21.12" + # don't pollute the repo + pushd $HOME + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-linux-x86_64.zip + unzip protoc-${PROTOBUF_VERSION}-linux-x86_64.zip -d $HOME/.protoc + popd + export PATH="$HOME/.protoc/bin:$PATH" + go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 + go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.5.1 + go generate ./... 
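+          # guardrail: 'go generate' must be a no-op on a clean checkout; the
+          # git-status check below fails the job if any stub was regenerated
+          # differently (e.g. protoc or plugin versions drifted from the pins above)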
+ protoc --version + if [[ $(git status --porcelain) ]]; then + echo "Error: Uncommitted changes found after running 'make generate'. Please commit all generated code." + git diff + exit 1 + else + echo "No changes detected after running 'make generate'." + fi + + - name: Create localstack streams + run: | + aws --endpoint-url=http://127.0.0.1:4566 --region us-east-1 kinesis create-stream --stream-name stream-1-shard --shard-count 1 + aws --endpoint-url=http://127.0.0.1:4566 --region us-east-1 kinesis create-stream --stream-name stream-2-shards --shard-count 2 + + - name: Generate codecov configuration + run: | + .github/generate-codecov-yml.sh >> .github/codecov.yml - name: Build and run tests, static run: | @@ -129,26 +167,29 @@ jobs: go install github.com/kyoh86/richgo@v0.3.10 set -o pipefail make build BUILD_STATIC=1 - make go-acc | richgo testfilter + make go-acc | sed 's/ *coverage:.*of statements in.*//' | richgo testfilter + + # check if some component stubs are missing + - name: "Build profile: minimal" + run: | + make build BUILD_PROFILE=minimal - name: Run tests again, dynamic run: | make clean build - make go-acc | richgo testfilter + set -o pipefail + make go-acc | sed 's/ *coverage:.*of statements in.*//' | richgo testfilter - name: Upload unit coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: files: coverage.out flags: unit-linux + token: ${{ secrets.CODECOV_TOKEN }} - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: - version: v1.51 + version: v1.61 args: --issues-exit-code=1 --timeout 10m only-new-issues: false - # the cache is already managed above, enabling it here - # gives errors when extracting - skip-pkg-cache: true - skip-build-cache: true diff --git a/.github/workflows/governance-bot.yaml b/.github/workflows/governance-bot.yaml index 5c08cabf5d1..c9e73e7811a 100644 --- a/.github/workflows/governance-bot.yaml +++ b/.github/workflows/governance-bot.yaml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest steps: # Semantic versioning, lock to different version: v2, v2.0 or a commit hash. 
- - uses: BirthdayResearch/oss-governance-bot@v3 + - uses: BirthdayResearch/oss-governance-bot@v4 with: # You can use a PAT to post a comment/label/status so that it shows up as a user instead of github-actions github-token: ${{secrets.GITHUB_TOKEN}} # optional, default to '${{ github.token }}' diff --git a/.github/workflows/publish-docker-master.yml b/.github/workflows/publish-docker-master.yml new file mode 100644 index 00000000000..e8bfb10ddb1 --- /dev/null +++ b/.github/workflows/publish-docker-master.yml @@ -0,0 +1,47 @@ +name: (push-master) Publish latest Docker images + +on: + push: + branches: [ master ] + paths: + - 'pkg/**' + - 'cmd/**' + - 'mk/**' + - 'docker/docker_start.sh' + - 'docker/config.yaml' + - '.github/workflows/publish-docker-master.yml' + - '.github/workflows/publish-docker.yml' + - 'Dockerfile' + - 'Dockerfile.debian' + - 'go.mod' + - 'go.sum' + - 'Makefile' + +jobs: + dev-alpine: + uses: ./.github/workflows/publish-docker.yml + with: + platform: linux/amd64 + crowdsec_version: "" + image_version: dev + latest: false + push: true + slim: false + debian: false + secrets: + DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} + DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }} + + dev-debian: + uses: ./.github/workflows/publish-docker.yml + with: + platform: linux/amd64 + crowdsec_version: "" + image_version: dev + latest: false + push: true + slim: false + debian: true + secrets: + DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} + DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }} diff --git a/.github/workflows/publish-docker-release.yml b/.github/workflows/publish-docker-release.yml new file mode 100644 index 00000000000..5ec2d0e143e --- /dev/null +++ b/.github/workflows/publish-docker-release.yml @@ -0,0 +1,48 @@ +name: (manual) Publish Docker images + +on: + workflow_dispatch: + inputs: + image_version: + description: Docker Image version (base tag, i.e. v1.6.0-2) + required: true + crowdsec_version: + description: Crowdsec version (BUILD_VERSION) + required: true + latest: + description: Overwrite latest (and slim) tags? + default: false + required: true + push: + description: Really push? 
+ default: false + required: true + +jobs: + alpine: + uses: ./.github/workflows/publish-docker.yml + secrets: + DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} + DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }} + with: + image_version: ${{ github.event.inputs.image_version }} + crowdsec_version: ${{ github.event.inputs.crowdsec_version }} + latest: ${{ github.event.inputs.latest == 'true' }} + push: ${{ github.event.inputs.push == 'true' }} + slim: true + debian: false + platform: "linux/amd64,linux/386,linux/arm64,linux/arm/v7,linux/arm/v6" + + debian: + uses: ./.github/workflows/publish-docker.yml + secrets: + DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} + DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }} + with: + image_version: ${{ github.event.inputs.image_version }} + crowdsec_version: ${{ github.event.inputs.crowdsec_version }} + latest: ${{ github.event.inputs.latest == 'true' }} + push: ${{ github.event.inputs.push == 'true' }} + slim: false + debian: true + platform: "linux/amd64,linux/386,linux/arm64" diff --git a/.github/workflows/publish-docker.yml b/.github/workflows/publish-docker.yml new file mode 100644 index 00000000000..11b4401c6da --- /dev/null +++ b/.github/workflows/publish-docker.yml @@ -0,0 +1,125 @@ +name: (sub) Publish Docker images + +on: + workflow_call: + secrets: + DOCKER_USERNAME: + required: true + DOCKER_PASSWORD: + required: true + inputs: + platform: + required: true + type: string + image_version: + required: true + type: string + crowdsec_version: + required: true + type: string + latest: + required: true + type: boolean + push: + required: true + type: boolean + slim: + required: true + type: boolean + debian: + required: true + type: boolean + +jobs: + push_to_registry: + name: Push Docker image to registries + runs-on: ubuntu-latest + steps: + + - name: Check out the repo + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + with: + buildkitd-config: .github/buildkit.toml + + - name: Login to DockerHub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Prepare (slim) + if: ${{ inputs.slim }} + id: slim + run: | + DOCKERHUB_IMAGE=${{ secrets.DOCKER_USERNAME }}/crowdsec + GHCR_IMAGE=ghcr.io/${{ github.repository_owner }}/crowdsec + VERSION=${{ inputs.image_version }} + DEBIAN=${{ inputs.debian && '-debian' || '' }} + TAGS="${DOCKERHUB_IMAGE}:${VERSION}-slim${DEBIAN},${GHCR_IMAGE}:${VERSION}-slim${DEBIAN}" + if [[ ${{ inputs.latest }} == true ]]; then + TAGS=$TAGS,${DOCKERHUB_IMAGE}:slim${DEBIAN},${GHCR_IMAGE}:slim${DEBIAN} + fi + echo "tags=${TAGS}" >> $GITHUB_OUTPUT + echo "created=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT + + - name: Prepare (full) + id: full + run: | + DOCKERHUB_IMAGE=${{ secrets.DOCKER_USERNAME }}/crowdsec + GHCR_IMAGE=ghcr.io/${{ github.repository_owner }}/crowdsec + VERSION=${{ inputs.image_version }} + DEBIAN=${{ inputs.debian && '-debian' || '' }} + TAGS="${DOCKERHUB_IMAGE}:${VERSION}${DEBIAN},${GHCR_IMAGE}:${VERSION}${DEBIAN}" + if [[ ${{ inputs.latest }} == true ]]; then + TAGS=$TAGS,${DOCKERHUB_IMAGE}:latest${DEBIAN},${GHCR_IMAGE}:latest${DEBIAN} + fi + echo "tags=${TAGS}" >> $GITHUB_OUTPUT + echo "created=$(date 
-u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT + + - name: Build and push image (slim) + if: ${{ inputs.slim }} + uses: docker/build-push-action@v6 + with: + context: . + file: ./Dockerfile${{ inputs.debian && '.debian' || '' }} + push: ${{ inputs.push }} + tags: ${{ steps.slim.outputs.tags }} + target: slim + platforms: ${{ inputs.platform }} + labels: | + org.opencontainers.image.source=${{ github.event.repository.html_url }} + org.opencontainers.image.created=${{ steps.slim.outputs.created }} + org.opencontainers.image.revision=${{ github.sha }} + build-args: | + BUILD_VERSION=${{ inputs.crowdsec_version }} + + - name: Build and push image (full) + uses: docker/build-push-action@v6 + with: + context: . + file: ./Dockerfile${{ inputs.debian && '.debian' || '' }} + push: ${{ inputs.push }} + tags: ${{ steps.full.outputs.tags }} + target: full + platforms: ${{ inputs.platform }} + labels: | + org.opencontainers.image.source=${{ github.event.repository.html_url }} + org.opencontainers.image.created=${{ steps.full.outputs.created }} + org.opencontainers.image.revision=${{ github.sha }} + build-args: | + BUILD_VERSION=${{ inputs.crowdsec_version }} diff --git a/.github/workflows/release_publish-package.yml b/.github/workflows/publish-tarball-release.yml similarity index 66% rename from .github/workflows/release_publish-package.yml rename to .github/workflows/publish-tarball-release.yml index c38e0812ca4..eeefb801719 100644 --- a/.github/workflows/release_publish-package.yml +++ b/.github/workflows/publish-tarball-release.yml @@ -1,5 +1,5 @@ # .github/workflows/build-docker-image.yml -name: build +name: Release on: release: @@ -12,25 +12,20 @@ permissions: jobs: build: - strategy: - matrix: - go-version: ["1.20.6"] - name: Build and upload binary package runs-on: ubuntu-latest steps: - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - - name: "Set up Go ${{ matrix.go-version }}" - uses: actions/setup-go@v4 + - name: "Set up Go" + uses: actions/setup-go@v5 with: - go-version: ${{ matrix.go-version }} - cache-dependency-path: "**/go.sum" + go-version: "1.22" - name: Build the binaries run: | @@ -42,4 +37,4 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | tag_name="${GITHUB_REF##*/}" - hub release edit -a crowdsec-release.tgz -a vendor.tgz -m "" "$tag_name" + gh release upload "$tag_name" crowdsec-release.tgz vendor.tgz *-vendor.tar.xz diff --git a/.github/workflows/publish_docker-image_on_master-debian.yml b/.github/workflows/publish_docker-image_on_master-debian.yml deleted file mode 100644 index 88076157c33..00000000000 --- a/.github/workflows/publish_docker-image_on_master-debian.yml +++ /dev/null @@ -1,70 +0,0 @@ -name: Publish Debian Docker image on Push to Master - -on: - push: - branches: [ master ] - paths: - - 'pkg/**' - - 'cmd/**' - - 'plugins/**' - - 'docker/docker_start.sh' - - 'docker/config.yaml' - - '.github/workflows/publish_docker-image_on_master-debian.yml' - - 'Dockerfile.debian' - - 'go.mod' - - 'go.sum' - - 'Makefile' - -jobs: - push_to_registry: - name: Push Debian Docker image to Docker Hub - runs-on: ubuntu-latest - steps: - - - name: Check out the repo - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - - name: Prepare - id: prep - run: | - DOCKER_IMAGE=crowdsecurity/crowdsec - GHCR_IMAGE=ghcr.io/${{ github.repository_owner }}/crowdsec - VERSION=dev-debian - TAGS="${DOCKER_IMAGE}:${VERSION},${GHCR_IMAGE}:${VERSION}" - echo "tags=${TAGS}" >> $GITHUB_OUTPUT - echo 
"created=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - with: - config: .github/buildkit.toml - - - name: Login to DockerHub - uses: docker/login-action@v2 - with: - username: ${{ secrets.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_PASSWORD }} - - - name: Login to GitHub Container Registry - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Build and push full image - uses: docker/build-push-action@v4 - with: - context: . - file: ./Dockerfile.debian - push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.prep.outputs.tags }} - platforms: linux/amd64 - labels: | - org.opencontainers.image.source=${{ github.event.repository.html_url }} - org.opencontainers.image.created=${{ steps.prep.outputs.created }} - org.opencontainers.image.revision=${{ github.sha }} - cache-from: type=gha - cache-to: type=gha,mode=min diff --git a/.github/workflows/publish_docker-image_on_master.yml b/.github/workflows/publish_docker-image_on_master.yml deleted file mode 100644 index 6cab486b051..00000000000 --- a/.github/workflows/publish_docker-image_on_master.yml +++ /dev/null @@ -1,70 +0,0 @@ -name: Publish Docker image on Push to Master - -on: - push: - branches: [ master ] - paths: - - 'pkg/**' - - 'cmd/**' - - 'plugins/**' - - 'docker/docker_start.sh' - - 'docker/config.yaml' - - '.github/workflows/publish_docker-image_on_master.yml' - - 'Dockerfile' - - 'go.mod' - - 'go.sum' - - 'Makefile' - -jobs: - push_to_registry: - name: Push Docker image to Docker Hub - runs-on: ubuntu-latest - steps: - - - name: Check out the repo - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - - name: Prepare - id: prep - run: | - DOCKER_IMAGE=crowdsecurity/crowdsec - GHCR_IMAGE=ghcr.io/${{ github.repository_owner }}/crowdsec - VERSION=dev - TAGS="${DOCKER_IMAGE}:${VERSION},${GHCR_IMAGE}:${VERSION}" - echo "tags=${TAGS}" >> $GITHUB_OUTPUT - echo "created=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - with: - config: .github/buildkit.toml - - - name: Login to DockerHub - uses: docker/login-action@v2 - with: - username: ${{ secrets.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_PASSWORD }} - - - name: Login to GitHub Container Registry - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Build and push full image - uses: docker/build-push-action@v4 - with: - context: . 
- file: ./Dockerfile - push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.prep.outputs.tags }} - platforms: linux/amd64 - labels: | - org.opencontainers.image.source=${{ github.event.repository.html_url }} - org.opencontainers.image.created=${{ steps.prep.outputs.created }} - org.opencontainers.image.revision=${{ github.sha }} - cache-from: type=gha - cache-to: type=gha,mode=min diff --git a/.github/workflows/release_publish_docker-image-debian.yml b/.github/workflows/release_publish_docker-image-debian.yml deleted file mode 100644 index e766dae0966..00000000000 --- a/.github/workflows/release_publish_docker-image-debian.yml +++ /dev/null @@ -1,61 +0,0 @@ -name: Publish Docker Debian image - -on: - release: - types: - - released - - prereleased - workflow_dispatch: - -jobs: - push_to_registry: - name: Push Docker debian image to Docker Hub - runs-on: ubuntu-latest - steps: - - name: Check out the repo - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - name: Prepare - id: prep - run: | - DOCKER_IMAGE=crowdsecurity/crowdsec - VERSION=bullseye - if [[ $GITHUB_REF == refs/tags/* ]]; then - VERSION=${GITHUB_REF#refs/tags/} - elif [[ $GITHUB_REF == refs/heads/* ]]; then - VERSION=$(echo ${GITHUB_REF#refs/heads/} | sed -E 's#/+#-#g') - elif [[ $GITHUB_REF == refs/pull/* ]]; then - VERSION=pr-${{ github.event.number }} - fi - TAGS="${DOCKER_IMAGE}:${VERSION}-debian" - if [[ "${{ github.event.action }}" == "released" ]]; then - TAGS=$TAGS,${DOCKER_IMAGE}:latest-debian - fi - echo "version=${VERSION}" >> $GITHUB_OUTPUT - echo "tags=${TAGS}" >> $GITHUB_OUTPUT - echo "created=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT - - name: Set up QEMU - uses: docker/setup-qemu-action@v2 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - with: - config: .github/buildkit.toml - - - name: Login to DockerHub - uses: docker/login-action@v2 - with: - username: ${{ secrets.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_PASSWORD }} - - name: Build and push - uses: docker/build-push-action@v4 - with: - context: . 
- file: ./Dockerfile.debian - push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.prep.outputs.tags }} - platforms: linux/amd64,linux/arm64,linux/386 - labels: | - org.opencontainers.image.source=${{ github.event.repository.html_url }} - org.opencontainers.image.created=${{ steps.prep.outputs.created }} - org.opencontainers.image.revision=${{ github.sha }} diff --git a/.github/workflows/release_publish_docker-image.yml b/.github/workflows/release_publish_docker-image.yml deleted file mode 100644 index db344f54930..00000000000 --- a/.github/workflows/release_publish_docker-image.yml +++ /dev/null @@ -1,86 +0,0 @@ -name: Publish Docker image - -on: - release: - types: - - released - - prereleased - -jobs: - push_to_registry: - name: Push Docker image to Docker Hub - runs-on: ubuntu-latest - steps: - - name: Check out the repo - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - name: Prepare - id: prep - run: | - DOCKER_IMAGE=crowdsecurity/crowdsec - GHCR_IMAGE=ghcr.io/${{ github.repository_owner }}/crowdsec - VERSION=edge - if [[ $GITHUB_REF == refs/tags/* ]]; then - VERSION=${GITHUB_REF#refs/tags/} - elif [[ $GITHUB_REF == refs/heads/* ]]; then - VERSION=$(echo ${GITHUB_REF#refs/heads/} | sed -E 's#/+#-#g') - elif [[ $GITHUB_REF == refs/pull/* ]]; then - VERSION=pr-${{ github.event.number }} - fi - TAGS="${DOCKER_IMAGE}:${VERSION},${GHCR_IMAGE}:${VERSION}" - TAGS_SLIM="${DOCKER_IMAGE}:${VERSION}-slim" - if [[ ${{ github.event.action }} == released ]]; then - TAGS=$TAGS,${DOCKER_IMAGE}:latest,${GHCR_IMAGE}:latest - TAGS_SLIM=$TAGS_SLIM,${DOCKER_IMAGE}:slim - fi - echo "version=${VERSION}" >> $GITHUB_OUTPUT - echo "tags=${TAGS}" >> $GITHUB_OUTPUT - echo "tags_slim=${TAGS_SLIM}" >> $GITHUB_OUTPUT - echo "created=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT - - name: Set up QEMU - uses: docker/setup-qemu-action@v2 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - with: - config: .github/buildkit.toml - - - name: Login to DockerHub - uses: docker/login-action@v2 - with: - username: ${{ secrets.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_PASSWORD }} - - - name: Login to GitHub Container Registry - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Build and push slim image - uses: docker/build-push-action@v4 - with: - context: . - file: ./Dockerfile - push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.prep.outputs.tags_slim }} - target: slim - platforms: linux/amd64,linux/arm64,linux/arm/v7,linux/arm/v6,linux/386 - labels: | - org.opencontainers.image.source=${{ github.event.repository.html_url }} - org.opencontainers.image.created=${{ steps.prep.outputs.created }} - org.opencontainers.image.revision=${{ github.sha }} - - - name: Build and push full image - uses: docker/build-push-action@v4 - with: - context: . 
- file: ./Dockerfile - push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.prep.outputs.tags }} - platforms: linux/amd64,linux/arm64,linux/arm/v7,linux/arm/v6,linux/386 - labels: | - org.opencontainers.image.source=${{ github.event.repository.html_url }} - org.opencontainers.image.created=${{ steps.prep.outputs.created }} - org.opencontainers.image.revision=${{ github.sha }} diff --git a/.github/workflows/update_docker_hub_doc.yml b/.github/workflows/update_docker_hub_doc.yml index 0a5047ddcf1..5c5f76acca4 100644 --- a/.github/workflows/update_docker_hub_doc.yml +++ b/.github/workflows/update_docker_hub_doc.yml @@ -1,4 +1,4 @@ -name: Update Docker Hub README +name: (push-master) Update Docker Hub README on: push: @@ -13,7 +13,7 @@ jobs: steps: - name: Check out the repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 if: ${{ github.repository_owner == 'crowdsecurity' }} - name: Update docker hub README diff --git a/.gitignore b/.gitignore index 8fe1778baec..d76efcbfc48 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,10 @@ *.dylib *~ .pc + +# IDEs .vscode +.idea # If vendor is included, allow prebuilt (wasm?) libraries. !vendor/**/*.so @@ -34,18 +37,14 @@ test/coverage/* *.swo # Dependencies are not vendored by default, but a tarball is created by "make vendor" -# and provided in the release. Used by freebsd, gentoo, etc. +# and provided in the release. Used by gentoo, etc. vendor/ vendor.tgz # crowdsec binaries cmd/crowdsec-cli/cscli cmd/crowdsec/crowdsec -plugins/notifications/http/notification-http -plugins/notifications/slack/notification-slack -plugins/notifications/splunk/notification-splunk -plugins/notifications/email/notification-email -plugins/notifications/dummy/notification-dummy +cmd/notification-*/notification-* # Test cache (downloaded files) .cache @@ -61,3 +60,6 @@ msi __pycache__ *.py[cod] *.egg-info + +# automatically generated before running codecov +.github/codecov.yml diff --git a/.golangci.yml b/.golangci.yml index faa67c4bb80..4909d3e60c0 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,38 +1,38 @@ # https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml -run: - skip-dirs: - - pkg/time/rate - skip-files: - - pkg/database/ent/generate.go - - pkg/yamlpatch/merge.go - - pkg/yamlpatch/merge_test.go - linters-settings: - gocyclo: - min-complexity: 30 - - funlen: - # Checks the number of lines in a function. - # If lower than 0, disable the check. - # Default: 60 - lines: -1 - # Checks the number of statements in a function. - # If lower than 0, disable the check. - # Default: 40 - statements: -1 + gci: + sections: + - standard + - default + - prefix(github.com/crowdsecurity) + - prefix(github.com/crowdsecurity/crowdsec) + + gomoddirectives: + replace-allow-list: + - golang.org/x/time/rate govet: - check-shadowing: true + enable-all: true + disable: + - reflectvaluecompare + - fieldalignment - lll: - line-length: 140 + maintidx: + # raise this after refactoring + under: 15 misspell: locale: US + nestif: + # lower this after refactoring + min-complexity: 16 + + nlreturn: + block-size: 5 + nolintlint: - allow-leading-space: true # don't require machine-readable nolint directives (i.e. 
with no leading space) allow-unused: false # report any unused nolint directives require-explanation: false # don't require an explanation for nolint directives require-specific: false # don't require nolint directives to be specific about which linter is being skipped @@ -40,104 +40,226 @@ linters-settings: interfacebloat: max: 12 + depguard: + rules: + wrap: + deny: + - pkg: "github.com/pkg/errors" + desc: "errors.Wrap() is deprecated in favor of fmt.Errorf()" + files: + - "!**/pkg/database/*.go" + yaml: + files: + - "!**/pkg/acquisition/acquisition.go" + - "!**/pkg/acquisition/acquisition_test.go" + - "!**/pkg/acquisition/modules/appsec/appsec.go" + - "!**/pkg/acquisition/modules/cloudwatch/cloudwatch.go" + - "!**/pkg/acquisition/modules/docker/docker.go" + - "!**/pkg/acquisition/modules/file/file.go" + - "!**/pkg/acquisition/modules/journalctl/journalctl.go" + - "!**/pkg/acquisition/modules/kafka/kafka.go" + - "!**/pkg/acquisition/modules/kinesis/kinesis.go" + - "!**/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go" + - "!**/pkg/acquisition/modules/loki/loki.go" + - "!**/pkg/acquisition/modules/loki/timestamp_test.go" + - "!**/pkg/acquisition/modules/s3/s3.go" + - "!**/pkg/acquisition/modules/syslog/syslog.go" + - "!**/pkg/acquisition/modules/wineventlog/wineventlog_windows.go" + - "!**/pkg/appsec/appsec.go" + - "!**/pkg/appsec/loader.go" + - "!**/pkg/csplugin/broker.go" + - "!**/pkg/leakybucket/buckets_test.go" + - "!**/pkg/leakybucket/manager_load.go" + - "!**/pkg/parser/node.go" + - "!**/pkg/parser/node_test.go" + - "!**/pkg/parser/parsing_test.go" + - "!**/pkg/parser/stage.go" + deny: + - pkg: "gopkg.in/yaml.v2" + desc: "yaml.v2 is deprecated for new code in favor of yaml.v3" + + stylecheck: + checks: + - all + - -ST1003 # should not use underscores in Go names; ... + - -ST1005 # error strings should not be capitalized + - -ST1012 # error var ... should have name of the form ErrFoo + - -ST1016 # methods on the same type should have the same receiver name + - -ST1022 # comment on exported var ... should be of the form ... 
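The linting workflows pin golangci-lint v1.61 through golangci/golangci-lint-action@v6, and version skew matters here: the set of checks behind options like the stylecheck lists above varies across linter releases. One way to reproduce the CI lint run locally, assuming Go is installed, is:

    # install the same linter release the workflows use, then run with the CI flags
    go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.61.0
    golangci-lint run --issues-exit-code=1 --timeout 10m
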
+ + revive: + ignore-generated-header: true + severity: error + enable-all-rules: true + rules: + - name: add-constant + disabled: true + - name: cognitive-complexity + # lower this after refactoring + arguments: [119] + - name: comment-spacings + disabled: true + - name: confusing-results + disabled: true + - name: cyclomatic + # lower this after refactoring + arguments: [39] + - name: defer + disabled: true + - name: empty-block + disabled: true + - name: empty-lines + disabled: true + - name: error-naming + disabled: true + - name: flag-parameter + disabled: true + - name: function-result-limit + arguments: [6] + - name: function-length + # lower this after refactoring + arguments: [110, 237] + - name: get-return + disabled: true + - name: increment-decrement + disabled: true + - name: import-alias-naming + disabled: true + - name: import-shadowing + disabled: true + - name: line-length-limit + # lower this after refactoring + arguments: [221] + - name: max-control-nesting + # lower this after refactoring + arguments: [7] + - name: max-public-structs + disabled: true + - name: nested-structs + disabled: true + - name: package-comments + disabled: true + - name: redundant-import-alias + disabled: true + - name: time-equal + disabled: true + - name: var-naming + disabled: true + - name: unchecked-type-assertion + disabled: true + - name: exported + disabled: true + - name: unexported-naming + disabled: true + - name: unexported-return + disabled: true + - name: unhandled-error + disabled: true + arguments: + - "fmt.Print" + - "fmt.Printf" + - "fmt.Println" + - name: unnecessary-stmt + disabled: true + - name: unused-parameter + disabled: true + - name: unused-receiver + disabled: true + - name: use-any + disabled: true + - name: useless-break + disabled: true + + wsl: + # Allow blocks to end with comments + allow-trailing-comment: true + + gocritic: + enable-all: true + disabled-checks: + - typeDefFirst + - paramTypeCombine + - httpNoBody + - ifElseChain + - importShadow + - hugeParam + - rangeValCopy + - commentedOutCode + - commentedOutImport + - unnamedResult + - sloppyReassign + - appendCombine + - captLocal + - typeUnparen + - commentFormatting + - deferInLoop # + - sprintfQuotedString # + - whyNoLint + - equalFold # + - unnecessaryBlock # + - ptrToRefParam # + - stringXbytes # + - appendAssign # + - tooManyResultsChecker + - unnecessaryDefer + - docStub + - preferFprint + linters: enable-all: true disable: # # DEPRECATED by golangi-lint # - - deadcode # The owner seems to have abandoned the linter. Replaced by unused. - - exhaustivestruct # The owner seems to have abandoned the linter. Replaced by exhaustruct. - - golint # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes - - ifshort # Checks that your code uses short syntax for if-statements whenever possible - - interfacer # Linter that suggests narrower interface types - - maligned # Tool to detect Go structs that would take less memory if their fields were sorted - - nosnakecase # nosnakecase is a linter that detects snake case of variable naming and function name. - - scopelint # Scopelint checks for unpinned variables in go programs - - structcheck # The owner seems to have abandoned the linter. Replaced by unused. - - varcheck # The owner seems to have abandoned the linter. Replaced by unused. 
- - # - # Enabled - # - - # - asasalint # check for pass []any as any in variadic func(...any) - # - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - # - bidichk # Checks for dangerous unicode character sequences - # - decorder # check declaration order and count of types, constants, variables and functions - # - dupword # checks for duplicate words in the source code - # - durationcheck # check for two durations multiplied together - # - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - # - exportloopref # checks for pointers to enclosing loop variables - # - funlen # Tool for detection of long functions - # - ginkgolinter # enforces standards of using ginkgo and gomega - # - gochecknoinits # Checks that no init functions are present in Go code - # - gocritic # Provides diagnostics that check for bugs, performance and style issues. - # - goheader # Checks is file header matches to pattern - # - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - # - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - # - goprintffuncname # Checks that printf-like functions are named with `f` at the end - # - gosimple # (megacheck): Linter for Go source code that specializes in simplifying a code - # - govet # (vet, vetshadow): Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - # - grouper # An analyzer to analyze expression groups. - # - importas # Enforces consistent import aliases - # - ineffassign # Detects when assignments to existing variables are not used - # - interfacebloat # A linter that checks the number of methods inside an interface. - # - logrlint # Check logr arguments. - # - makezero # Finds slice declarations with non-zero initial length - # - misspell # Finds commonly misspelled English words in comments - # - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - # - nolintlint # Reports ill-formed or insufficient nolint directives - # - predeclared # find code that shadows one of Go's predeclared identifiers - # - reassign # Checks that package variables are not reassigned - # - rowserrcheck # checks whether Err of rows is checked successfully - # - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - # - staticcheck # (megacheck): Staticcheck is a go vet on steroids, applying a ton of static analysis checks - # - testableexamples # linter checks if examples are testable (have an expected output) - # - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - # - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - # - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - # - unconvert # Remove unnecessary type conversions - # - unused # (megacheck): Checks Go code for unused constants, variables, functions and types - # - usestdlibvars # A linter that detect the possibility to use variables/constants from the Go standard library. 
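With `enable-all: true` in effect, the long hand-maintained list of enabled linters above is redundant comment noise, which is why it is dropped in favor of the explicit disable list that follows. To inspect the effective set that a given config and linter version produce:

    # prints every linter with its enabled/disabled status under .golangci.yml
    golangci-lint linters
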
+ - execinquery + - exportloopref + - gomnd + + # + # Redundant + # + + - gocyclo # revive + - cyclop # revive + - lll # revive + - funlen # revive + - gocognit # revive + + # Disabled atm + + - intrange # intrange is a linter to find places where for loops could make use of an integer range. # # Recommended? (easy) # - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and reports occasions, where the check for the returned error can be omitted. - exhaustive # check exhaustiveness of enum switch statements - gci # Gci control golang package import order and make it always deterministic. - godot # Check if comments end in a period - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. + - goimports # Check import statements are formatted according to the 'goimport' command. Reformat imports in autofix mode. - gosec # (gas): Inspects source code for security problems - - lll # Reports long lines + - inamedparam # reports interfaces with unnamed method parameters - musttag # enforce field tags in (un)marshaled structs - - nakedret # Finds naked returns in functions greater than a specified function length - - nonamedreturns # Reports all named returns - - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - promlinter # Check Prometheus metrics naming via promlint - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - - wastedassign # wastedassign finds wasted assignment statements. + - protogetter # Reports direct reads from proto message fields when getters should be used + - tagalign # check that struct tags are well aligned + - thelper # thelper detects tests helpers which is not start with t.Helper() method. - wrapcheck # Checks that errors returned from external packages are wrapped - - depguard # Go linter that checks if package imports are in a list of acceptable packages # # Recommended? (requires some work) # - - bodyclose # checks whether HTTP response body is closed successfully - containedctx # containedctx is a linter that detects struct contained context.Context field - - contextcheck # check the function whether use a non-inherited context + - contextcheck # check whether the function uses a non-inherited context - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - - gomnd # An analyzer to detect magic numbers. - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. 
- - noctx # noctx finds sending http request without context.Context + - noctx # Finds sending http request without context.Context - unparam # Reports unused function parameters # @@ -146,33 +268,26 @@ linters: - gofumpt # Gofumpt checks whether code was gofumpt-ed. - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - - whitespace # Tool for detection of leading and trailing whitespace - - wsl # Whitespace Linter - Forces you to use empty lines! + - whitespace # Whitespace is a linter that checks for unnecessary newlines at the start and end of functions, if, for, etc. + - wsl # add or remove empty lines # # Well intended, but not ready for this # - - cyclop # checks function and package cyclomatic complexity - dupl # Tool for code clone detection - forcetypeassert # finds forced type assertions - - gocognit # Computes and checks the cognitive complexity of functions - - gocyclo # Computes and checks the cyclomatic complexity of functions - godox # Tool for detection of FIXME, TODO and other comment keywords - - goerr113 # Golang linter to check the errors handling expressions - - maintidx # maintidx measures the maintainability index of each function. - - nestif # Reports deeply nested if statements - - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - err113 # Go linter to check the errors handling expressions + - paralleltest # Detects missing usage of t.Parallel() method in your Go test - testpackage # linter that makes you use a separate _test package # # Too strict / too many false positives (for now?) # - - execinquery # execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds - exhaustruct # Checks if all structure fields are initialized - forbidigo # Forbids identifiers - - gochecknoglobals # check that no global variables exist + - gochecknoglobals # Check that no global variables exist. - goconst # Finds repeated strings that could be replaced by a constant - - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. - varnamelen # checks that the length of a variable's name matches its scope @@ -187,45 +302,162 @@ issues: # “Look, that’s why there’s rules, understand? So that you think before you # break ‘em.” ― Terry Pratchett + exclude-dirs: + - pkg/time/rate + - pkg/metabase + + exclude-files: + - pkg/yamlpatch/merge.go + - pkg/yamlpatch/merge_test.go + + exclude-generated: strict + max-issues-per-linter: 0 - max-same-issues: 10 + max-same-issues: 0 exclude-rules: - - path: go.mod - text: "replacement are not allowed: golang.org/x/time/rate" + + # Won't fix: # `err` is often shadowed, we may continue to do it - linters: - govet - text: "shadow: declaration of \"err\" shadows declaration" - - # - # errcheck - # + text: "shadow: declaration of \"(err|ctx)\" shadows declaration" - linters: - errcheck text: "Error return value of `.*` is not checked" + # Will fix, trivial - just beware of merge conflicts + + - linters: + - perfsprint + text: "fmt.Sprintf can be replaced .*" + # - # gocritic + # Will fix, easy but some neurons required # - linters: - - gocritic - text: "ifElseChain: rewrite if-else to switch statement" + - errorlint + text: "non-wrapping format verb for fmt.Errorf. Use `%w` to format errors" + + - linters: + - errorlint + text: "type assertion on error will fail on wrapped errors. 
Use errors.As to check for specific errors" + + - linters: + - errorlint + text: "type switch on error will fail on wrapped errors. Use errors.As to check for specific errors" + + - linters: + - errorlint + text: "comparing with .* will fail on wrapped errors. Use errors.Is to check for a specific error" + + - linters: + - nosprintfhostport + text: "host:port in url should be constructed with net.JoinHostPort and not directly with fmt.Sprintf" + + # https://github.com/timakin/bodyclose + - linters: + - bodyclose + text: "response body must be closed" + + # named/naked returns are evil, with a single exception + # https://go.dev/wiki/CodeReviewComments#named-result-parameters + - linters: + - nonamedreturns + text: "named return .* with type .* found" + + - linters: + - revive + path: pkg/leakybucket/manager_load.go + text: "confusing-naming: Field '.*' differs only by capitalization to other field in the struct type BucketFactory" + + - linters: + - revive + path: pkg/exprhelpers/helpers.go + text: "confusing-naming: Method 'flatten' differs only by capitalization to function 'Flatten' in the same source file" + + - linters: + - revive + path: pkg/appsec/query_utils.go + text: "confusing-naming: Method 'parseQuery' differs only by capitalization to function 'ParseQuery' in the same source file" + + - linters: + - revive + path: pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go + text: "confusing-naming: Method 'QueryRange' differs only by capitalization to method 'queryRange' in the same source file" + + - linters: + - revive + path: cmd/crowdsec-cli/copyfile.go + + - linters: + - revive + path: pkg/hubtest/hubtest_item.go + text: "cyclomatic: .*RunWithLogFile" + + # tolerate complex functions in tests for now + - linters: + - maintidx + path: "(.+)_test.go" + + # tolerate long functions in tests + - linters: + - revive + path: "pkg/(.+)_test.go" + text: "function-length: .*" + + # tolerate long lines in tests + - linters: + - revive + path: "pkg/(.+)_test.go" + text: "line-length-limit: .*" + + # tolerate deep exit in tests, for now + - linters: + - revive + path: "pkg/(.+)_test.go" + text: "deep-exit: .*" + + # we use t,ctx instead of ctx,t in tests + - linters: + - revive + path: "pkg/(.+)_test.go" + text: "context-as-argument: context.Context should be the first parameter of a function" + + # tolerate deep exit in cobra's OnInitialize, for now + - linters: + - revive + path: "cmd/crowdsec-cli/main.go" + text: "deep-exit: .*" + + - linters: + - revive + path: "cmd/crowdsec-cli/clihub/item_metrics.go" + text: "deep-exit: .*" + + - linters: + - revive + path: "cmd/crowdsec-cli/idgen/password.go" + text: "deep-exit: .*" - linters: - - gocritic - text: "captLocal: `.*' should not be capitalized" + - revive + path: "pkg/leakybucket/overflows.go" + text: "deep-exit: .*" - linters: - - gocritic - text: "appendAssign: append result not assigned to the same slice" + - revive + path: "cmd/crowdsec/crowdsec.go" + text: "deep-exit: .*" - linters: - - gocritic - text: "commentFormatting: put a space between `//` and comment text" + - revive + path: "cmd/crowdsec/api.go" + text: "deep-exit: .*" - linters: - - staticcheck - text: "x509.ParseCRL has been deprecated since Go 1.19: Use ParseRevocationList instead" + - revive + path: "cmd/crowdsec/win_service.go" + text: "deep-exit: .*" diff --git a/Dockerfile b/Dockerfile index f43eda4a09b..450ea69017f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,67 +1,63 @@ # vim: set ft=dockerfile: -ARG GOVERSION=1.20.6 +FROM golang:1.22-alpine3.20 AS 
build -FROM golang:${GOVERSION}-alpine AS build +ARG BUILD_VERSION WORKDIR /go/src/crowdsec # We like to choose the release of re2 to use, and Alpine does not ship a static version anyway. ENV RE2_VERSION=2023-03-01 +ENV BUILD_VERSION=${BUILD_VERSION} # wizard.sh requires GNU coreutils RUN apk add --no-cache git g++ gcc libc-dev make bash gettext binutils-gold coreutils pkgconfig && \ - wget https://github.com/google/re2/archive/refs/tags/${RE2_VERSION}.tar.gz && \ + wget -q https://github.com/google/re2/archive/refs/tags/${RE2_VERSION}.tar.gz && \ tar -xzf ${RE2_VERSION}.tar.gz && \ cd re2-${RE2_VERSION} && \ make install && \ echo "githubciXXXXXXXXXXXXXXXXXXXXXXXX" > /etc/machine-id && \ - go install github.com/mikefarah/yq/v4@v4.34.1 + go install github.com/mikefarah/yq/v4@v4.44.3 COPY . . -RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 && \ +RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 CGO_CFLAGS="-D_LARGEFILE64_SOURCE" && \ cd crowdsec-v* && \ ./wizard.sh --docker-mode && \ cd - >/dev/null && \ - cscli hub update && \ + cscli hub update --with-content && \ cscli collections install crowdsecurity/linux && \ cscli parsers install crowdsecurity/whitelists # In case we need to remove agents here.. # cscli machines list -o json | yq '.[].machineId' | xargs -r cscli machines delete -FROM alpine:latest as slim +FROM alpine:latest AS slim -RUN apk add --no-cache --repository=http://dl-cdn.alpinelinux.org/alpine/edge/community tzdata bash && \ +RUN apk add --no-cache --repository=http://dl-cdn.alpinelinux.org/alpine/edge/community tzdata bash rsync && \ mkdir -p /staging/etc/crowdsec && \ mkdir -p /staging/etc/crowdsec/acquis.d && \ mkdir -p /staging/var/lib/crowdsec && \ mkdir -p /var/lib/crowdsec/data -COPY --from=build /go/bin/yq /usr/local/bin/yq +COPY --from=build /go/bin/yq /usr/local/bin/crowdsec /usr/local/bin/cscli /usr/local/bin/ COPY --from=build /etc/crowdsec /staging/etc/crowdsec -COPY --from=build /usr/local/bin/crowdsec /usr/local/bin/crowdsec -COPY --from=build /usr/local/bin/cscli /usr/local/bin/cscli COPY --from=build /go/src/crowdsec/docker/docker_start.sh / COPY --from=build /go/src/crowdsec/docker/config.yaml /staging/etc/crowdsec/config.yaml +COPY --from=build /var/lib/crowdsec /staging/var/lib/crowdsec RUN yq -n '.url="http://0.0.0.0:8080"' | install -m 0600 /dev/stdin /staging/etc/crowdsec/local_api_credentials.yaml -ENTRYPOINT /bin/bash docker_start.sh +ENTRYPOINT ["/bin/bash", "/docker_start.sh"] -FROM slim as plugins +FROM slim AS full # Due to the wizard using cp -n, we have to copy the config files directly from the source as -n does not exist in busybox cp # The files are here for reference, as users will need to mount a new version to be actually able to use notifications -COPY --from=build /go/src/crowdsec/plugins/notifications/email/email.yaml /staging/etc/crowdsec/notifications/email.yaml -COPY --from=build /go/src/crowdsec/plugins/notifications/http/http.yaml /staging/etc/crowdsec/notifications/http.yaml -COPY --from=build /go/src/crowdsec/plugins/notifications/slack/slack.yaml /staging/etc/crowdsec/notifications/slack.yaml -COPY --from=build /go/src/crowdsec/plugins/notifications/splunk/splunk.yaml /staging/etc/crowdsec/notifications/splunk.yaml -COPY --from=build /usr/local/lib/crowdsec/plugins /usr/local/lib/crowdsec/plugins - -FROM slim as geoip - -COPY --from=build /var/lib/crowdsec /staging/var/lib/crowdsec - -FROM plugins as full +COPY --from=build \ + /go/src/crowdsec/cmd/notification-email/email.yaml \ + 
/go/src/crowdsec/cmd/notification-http/http.yaml \ + /go/src/crowdsec/cmd/notification-slack/slack.yaml \ + /go/src/crowdsec/cmd/notification-splunk/splunk.yaml \ + /go/src/crowdsec/cmd/notification-sentinel/sentinel.yaml \ + /staging/etc/crowdsec/notifications/ -COPY --from=build /var/lib/crowdsec /staging/var/lib/crowdsec +COPY --from=build /usr/local/lib/crowdsec/plugins /usr/local/lib/crowdsec/plugins diff --git a/Dockerfile.debian b/Dockerfile.debian index 75405431654..8bf2698c786 100644 --- a/Dockerfile.debian +++ b/Dockerfile.debian @@ -1,7 +1,7 @@ # vim: set ft=dockerfile: -ARG GOVERSION=1.20.6 +FROM golang:1.22-bookworm AS build -FROM golang:${GOVERSION}-bookworm AS build +ARG BUILD_VERSION WORKDIR /go/src/crowdsec @@ -10,6 +10,7 @@ ENV DEBCONF_NOWARNINGS="yes" # We like to choose the release of re2 to use, the debian version is usually older. ENV RE2_VERSION=2023-03-01 +ENV BUILD_VERSION=${BUILD_VERSION} # wizard.sh requires GNU coreutils RUN apt-get update && \ @@ -20,7 +21,7 @@ RUN apt-get update && \ make && \ make install && \ echo "githubciXXXXXXXXXXXXXXXXXXXXXXXX" > /etc/machine-id && \ - go install github.com/mikefarah/yq/v4@v4.34.1 + go install github.com/mikefarah/yq/v4@v4.44.3 COPY . . @@ -28,14 +29,14 @@ RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 && \ cd crowdsec-v* && \ ./wizard.sh --docker-mode && \ cd - >/dev/null && \ - cscli hub update && \ + cscli hub update --with-content && \ cscli collections install crowdsecurity/linux && \ cscli parsers install crowdsecurity/whitelists # In case we need to remove agents here.. # cscli machines list -o json | yq '.[].machineId' | xargs -r cscli machines delete -FROM debian:bookworm-slim as slim +FROM debian:bookworm-slim AS slim ENV DEBIAN_FRONTEND=noninteractive ENV DEBCONF_NOWARNINGS="yes" @@ -47,37 +48,40 @@ RUN apt-get update && \ iproute2 \ ca-certificates \ bash \ - tzdata && \ + tzdata \ + rsync && \ mkdir -p /staging/etc/crowdsec && \ mkdir -p /staging/etc/crowdsec/acquis.d && \ mkdir -p /staging/var/lib/crowdsec && \ mkdir -p /var/lib/crowdsec/data -COPY --from=build /go/bin/yq /usr/local/bin/yq +COPY --from=build /go/bin/yq /usr/local/bin/crowdsec /usr/local/bin/cscli /usr/local/bin/ COPY --from=build /etc/crowdsec /staging/etc/crowdsec -COPY --from=build /usr/local/bin/crowdsec /usr/local/bin/crowdsec -COPY --from=build /usr/local/bin/cscli /usr/local/bin/cscli COPY --from=build /go/src/crowdsec/docker/docker_start.sh / COPY --from=build /go/src/crowdsec/docker/config.yaml /staging/etc/crowdsec/config.yaml RUN yq -n '.url="http://0.0.0.0:8080"' | install -m 0600 /dev/stdin /staging/etc/crowdsec/local_api_credentials.yaml && \ yq eval -i ".plugin_config.group = \"nogroup\"" /staging/etc/crowdsec/config.yaml -ENTRYPOINT /bin/bash docker_start.sh +ENTRYPOINT ["/bin/bash", "docker_start.sh"] -FROM slim as plugins +FROM slim AS plugins # Due to the wizard using cp -n, we have to copy the config files directly from the source as -n does not exist in busybox cp # The files are here for reference, as users will need to mount a new version to be actually able to use notifications -COPY --from=build /go/src/crowdsec/plugins/notifications/email/email.yaml /staging/etc/crowdsec/notifications/email.yaml -COPY --from=build /go/src/crowdsec/plugins/notifications/http/http.yaml /staging/etc/crowdsec/notifications/http.yaml -COPY --from=build /go/src/crowdsec/plugins/notifications/slack/slack.yaml /staging/etc/crowdsec/notifications/slack.yaml -COPY --from=build /go/src/crowdsec/plugins/notifications/splunk/splunk.yaml 
/staging/etc/crowdsec/notifications/splunk.yaml +COPY --from=build \ + /go/src/crowdsec/cmd/notification-email/email.yaml \ + /go/src/crowdsec/cmd/notification-http/http.yaml \ + /go/src/crowdsec/cmd/notification-slack/slack.yaml \ + /go/src/crowdsec/cmd/notification-splunk/splunk.yaml \ + /go/src/crowdsec/cmd/notification-sentinel/sentinel.yaml \ + /staging/etc/crowdsec/notifications/ + COPY --from=build /usr/local/lib/crowdsec/plugins /usr/local/lib/crowdsec/plugins -FROM slim as geoip +FROM slim AS geoip COPY --from=build /var/lib/crowdsec /staging/var/lib/crowdsec -FROM plugins as full +FROM plugins AS full COPY --from=build /var/lib/crowdsec /staging/var/lib/crowdsec diff --git a/Makefile b/Makefile index 543f6017715..bbfa4bbee94 100644 --- a/Makefile +++ b/Makefile @@ -23,22 +23,18 @@ BUILD_RE2_WASM ?= 0 BUILD_STATIC ?= 0 # List of plugins to build -PLUGINS ?= $(patsubst ./plugins/notifications/%,%,$(wildcard ./plugins/notifications/*)) - -# Can be overriden, if you can deal with the consequences -BUILD_REQUIRE_GO_MAJOR ?= 1 -BUILD_REQUIRE_GO_MINOR ?= 20 +PLUGINS ?= $(patsubst ./cmd/notification-%,%,$(wildcard ./cmd/notification-*)) #-------------------------------------- -GOCMD = go -GOTEST = $(GOCMD) test +GO = go +GOTEST = $(GO) test BUILD_CODENAME ?= alphaga CROWDSEC_FOLDER = ./cmd/crowdsec CSCLI_FOLDER = ./cmd/crowdsec-cli/ -PLUGINS_DIR = ./plugins/notifications +PLUGINS_DIR_PREFIX = ./cmd/notification- CROWDSEC_BIN = crowdsec$(EXT) CSCLI_BIN = cscli$(EXT) @@ -64,24 +60,25 @@ bool = $(if $(filter $(call lc, $1),1 yes true),1,0) #-------------------------------------- # -# Define MAKE_FLAGS and LD_OPTS for the sub-makefiles in cmd/ and plugins/ +# Define MAKE_FLAGS and LD_OPTS for the sub-makefiles in cmd/ # MAKE_FLAGS = --no-print-directory GOARCH=$(GOARCH) GOOS=$(GOOS) RM="$(RM)" WIN_IGNORE_ERR="$(WIN_IGNORE_ERR)" CP="$(CP)" CPR="$(CPR)" MKDIR="$(MKDIR)" LD_OPTS_VARS= \ --X 'github.com/crowdsecurity/go-cs-lib/pkg/version.Version=$(BUILD_VERSION)' \ --X 'github.com/crowdsecurity/go-cs-lib/pkg/version.BuildDate=$(BUILD_TIMESTAMP)' \ --X 'github.com/crowdsecurity/go-cs-lib/pkg/version.Tag=$(BUILD_TAG)' \ +-X 'github.com/crowdsecurity/go-cs-lib/version.Version=$(BUILD_VERSION)' \ +-X 'github.com/crowdsecurity/go-cs-lib/version.BuildDate=$(BUILD_TIMESTAMP)' \ +-X 'github.com/crowdsecurity/go-cs-lib/version.Tag=$(BUILD_TAG)' \ -X '$(GO_MODULE_NAME)/pkg/cwversion.Codename=$(BUILD_CODENAME)' \ -X '$(GO_MODULE_NAME)/pkg/csconfig.defaultConfigDir=$(DEFAULT_CONFIGDIR)' \ -X '$(GO_MODULE_NAME)/pkg/csconfig.defaultDataDir=$(DEFAULT_DATADIR)' ifneq (,$(DOCKER_BUILD)) -LD_OPTS_VARS += -X '$(GO_MODULE_NAME)/pkg/cwversion.System=docker' +LD_OPTS_VARS += -X 'github.com/crowdsecurity/go-cs-lib/version.System=docker' endif -GO_TAGS := netgo,osusergo,sqlite_omit_load_extension +#expr_debug tag is required to enable the debug mode in expr +GO_TAGS := netgo,osusergo,sqlite_omit_load_extension,expr_debug # this will be used by Go in the make target, some distributions require it export PKG_CONFIG_PATH:=/usr/local/lib/pkgconfig:$(PKG_CONFIG_PATH) @@ -92,7 +89,6 @@ ifeq ($(PKG_CONFIG),) endif ifeq ($(RE2_CHECK),) -# we could detect the platform and suggest the command to install RE2_FAIL := "libre2-dev is not installed, please install it or set BUILD_RE2_WASM=1 to use the WebAssembly version" else # += adds a space that we don't want @@ -101,6 +97,7 @@ LD_OPTS_VARS += -X '$(GO_MODULE_NAME)/pkg/cwversion.Libre2=C++' endif endif +# Build static to avoid the runtime dependency on libre2.so ifeq ($(call 
bool,$(BUILD_STATIC)),1) BUILD_TYPE = static EXTLDFLAGS := -extldflags '-static' @@ -109,22 +106,94 @@ BUILD_TYPE = dynamic EXTLDFLAGS := endif -export LD_OPTS=-ldflags "-s -w $(EXTLDFLAGS) $(LD_OPTS_VARS)" \ - -trimpath -tags $(GO_TAGS) +# Build with debug symbols, and disable optimizations + inlining, to use Delve +ifeq ($(call bool,$(DEBUG)),1) +STRIP_SYMBOLS := +DISABLE_OPTIMIZATION := -gcflags "-N -l" +else +STRIP_SYMBOLS := -s -w +DISABLE_OPTIMIZATION := +endif + +#-------------------------------------- + +# Handle optional components and build profiles, to save space on the final binaries. + +# Keep it safe for now until we decide how to expand on the idea. Either choose a profile or exclude components manually. +# For example if we want to disable some component by default, or have opt-in components (INCLUDE?). + +ifeq ($(and $(BUILD_PROFILE),$(EXCLUDE)),1) +$(error "Cannot specify both BUILD_PROFILE and EXCLUDE") +endif + +COMPONENTS := \ + datasource_appsec \ + datasource_cloudwatch \ + datasource_docker \ + datasource_file \ + datasource_k8saudit \ + datasource_kafka \ + datasource_journalctl \ + datasource_kinesis \ + datasource_loki \ + datasource_s3 \ + datasource_syslog \ + datasource_wineventlog \ + cscli_setup + +comma := , +space := $(empty) $(empty) + +# Predefined profiles + +# keep only datasource-file +EXCLUDE_MINIMAL := $(subst $(space),$(comma),$(filter-out datasource_file,,$(COMPONENTS))) + +# example +# EXCLUDE_MEDIUM := datasource_kafka,datasource_kinesis,datasource_s3 + +BUILD_PROFILE ?= default + +# Set the EXCLUDE_LIST based on the chosen profile, unless EXCLUDE is already set +ifeq ($(BUILD_PROFILE),minimal) +EXCLUDE ?= $(EXCLUDE_MINIMAL) +else ifneq ($(BUILD_PROFILE),default) +$(error Invalid build profile specified: $(BUILD_PROFILE). Valid profiles are: minimal, default) +endif + +# Create list of excluded components from the EXCLUDE variable +EXCLUDE_LIST := $(subst $(comma),$(space),$(EXCLUDE)) + +INVALID_COMPONENTS := $(filter-out $(COMPONENTS),$(EXCLUDE_LIST)) +ifneq ($(INVALID_COMPONENTS),) +$(error Invalid optional components specified in EXCLUDE: $(INVALID_COMPONENTS). Valid components are: $(COMPONENTS)) +endif + +# Convert the excluded components to "no_" form +COMPONENT_TAGS := $(foreach component,$(EXCLUDE_LIST),no_$(component)) -ifneq (,$(TEST_COVERAGE)) +ifneq ($(COMPONENT_TAGS),) +GO_TAGS := $(GO_TAGS),$(subst $(space),$(comma),$(COMPONENT_TAGS)) +endif + +#-------------------------------------- + +export LD_OPTS=-ldflags "$(STRIP_SYMBOLS) $(EXTLDFLAGS) $(LD_OPTS_VARS)" \ + -trimpath -tags $(GO_TAGS) $(DISABLE_OPTIMIZATION) + +ifeq ($(call bool,$(TEST_COVERAGE)),1) LD_OPTS += -cover endif #-------------------------------------- .PHONY: build -build: pre-build goversion crowdsec cscli plugins +build: build-info crowdsec cscli plugins ## Build crowdsec, cscli and plugins -# Sanity checks and build information -.PHONY: pre-build -pre-build: +.PHONY: build-info +build-info: ## Print build information $(info Building $(BUILD_VERSION) ($(BUILD_TAG)) $(BUILD_TYPE) for $(GOOS)/$(GOARCH)) + $(info Excluded components: $(EXCLUDE_LIST)) ifneq (,$(RE2_FAIL)) $(error $(RE2_FAIL)) @@ -135,19 +204,47 @@ ifneq (,$(RE2_CHECK)) else $(info Fallback to WebAssembly regexp library. To use the C++ version, make sure you have installed libre2-dev and pkg-config.) 
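# (Illustrative aside, not part of the original Makefile.) The optional
# component logic above turns every entry in EXCLUDE into a no_<component>
# build tag. Assuming the tag names listed in COMPONENTS, hypothetical
# invocations would be:
#
#   make build BUILD_PROFILE=minimal
#   make build EXCLUDE=datasource_kafka,datasource_kinesis
#
# The second form would append no_datasource_kafka,no_datasource_kinesis
# to GO_TAGS before the binaries are compiled.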
endif + +ifeq ($(call bool,$(DEBUG)),1) + $(info Building with debug symbols and disabled optimizations) +endif + +ifeq ($(call bool,$(TEST_COVERAGE)),1) + $(info Test coverage collection enabled) +endif + +# intentional, empty line $(info ) .PHONY: all -all: clean test build +all: clean test build ## Clean, test and build (requires localstack) .PHONY: plugins -plugins: +plugins: ## Build notification plugins @$(foreach plugin,$(PLUGINS), \ - $(MAKE) -C $(PLUGINS_DIR)/$(plugin) build $(MAKE_FLAGS); \ + $(MAKE) -C $(PLUGINS_DIR_PREFIX)$(plugin) build $(MAKE_FLAGS); \ ) +# same as "$(MAKE) -f debian/rules clean" but without the dependency on debhelper +.PHONY: clean-debian +clean-debian: + @$(RM) -r debian/crowdsec + @$(RM) -r debian/crowdsec + @$(RM) -r debian/files + @$(RM) -r debian/.debhelper + @$(RM) -r debian/*.substvars + @$(RM) -r debian/*-stamp + +.PHONY: clean-rpm +clean-rpm: + @$(RM) -r rpm/BUILD + @$(RM) -r rpm/BUILDROOT + @$(RM) -r rpm/RPMS + @$(RM) -r rpm/SOURCES/*.tar.gz + @$(RM) -r rpm/SRPMS + .PHONY: clean -clean: testclean +clean: clean-debian clean-rpm testclean ## Remove build artifacts @$(MAKE) -C $(CROWDSEC_FOLDER) clean $(MAKE_FLAGS) @$(MAKE) -C $(CSCLI_FOLDER) clean $(MAKE_FLAGS) @$(RM) $(CROWDSEC_BIN) $(WIN_IGNORE_ERR) @@ -155,19 +252,19 @@ clean: testclean @$(RM) *.log $(WIN_IGNORE_ERR) @$(RM) crowdsec-release.tgz $(WIN_IGNORE_ERR) @$(foreach plugin,$(PLUGINS), \ - $(MAKE) -C $(PLUGINS_DIR)/$(plugin) clean $(MAKE_FLAGS); \ + $(MAKE) -C $(PLUGINS_DIR_PREFIX)$(plugin) clean $(MAKE_FLAGS); \ ) .PHONY: cscli -cscli: goversion +cscli: ## Build cscli @$(MAKE) -C $(CSCLI_FOLDER) build $(MAKE_FLAGS) .PHONY: crowdsec -crowdsec: goversion +crowdsec: ## Build crowdsec @$(MAKE) -C $(CROWDSEC_FOLDER) build $(MAKE_FLAGS) .PHONY: testclean -testclean: bats-clean +testclean: bats-clean ## Remove test artifacts @$(RM) pkg/apiserver/ent $(WIN_IGNORE_ERR) @$(RM) pkg/cwhub/hubdir $(WIN_IGNORE_ERR) @$(RM) pkg/cwhub/install $(WIN_IGNORE_ERR) @@ -175,53 +272,45 @@ testclean: bats-clean # for the tests with localstack export AWS_ENDPOINT_FORCE=http://localhost:4566 -export AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE -export AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY +export AWS_ACCESS_KEY_ID=test +export AWS_SECRET_ACCESS_KEY=test testenv: - @echo 'NOTE: You need Docker, docker-compose and run "make localstack" in a separate shell ("make localstack-stop" to terminate it)' + @echo 'NOTE: You need to run "make localstack" in a separate shell, "make localstack-stop" to terminate it' -# run the tests with localstack .PHONY: test -test: testenv goversion - $(GOTEST) $(LD_OPTS) ./... +test: testenv ## Run unit tests with localstack + $(GOTEST) --tags=$(GO_TAGS) $(LD_OPTS) ./... -# run the tests with localstack and coverage .PHONY: go-acc -go-acc: testenv goversion - go-acc ./... -o coverage.out --ignore database,notifications,protobufs,cwversion,cstest,models -- $(LD_OPTS) | \ - sed 's/ *coverage:.*of statements in.*//' +go-acc: testenv ## Run unit tests with localstack + coverage + go-acc ./... -o coverage.out --ignore database,notifications,protobufs,cwversion,cstest,models --tags $(GO_TAGS) -- $(LD_OPTS) + +check_docker: + @if ! 
docker info > /dev/null 2>&1; then \ + echo "Could not run 'docker info': check that docker is running, and if you need to run this command with sudo."; \ + fi # mock AWS services .PHONY: localstack -localstack: - docker-compose -f test/localstack/docker-compose.yml up +localstack: check_docker ## Run localstack containers (required for unit testing) + docker compose -f test/localstack/docker-compose.yml up .PHONY: localstack-stop -localstack-stop: - docker-compose -f test/localstack/docker-compose.yml down - -# list of plugins that contain go.mod -PLUGIN_VENDOR = $(foreach plugin,$(PLUGINS),$(shell if [ -f $(PLUGINS_DIR)/$(plugin)/go.mod ]; then echo $(PLUGINS_DIR)/$(plugin); fi)) +localstack-stop: check_docker ## Stop localstack containers + docker compose -f test/localstack/docker-compose.yml down # build vendor.tgz to be distributed with the release .PHONY: vendor -vendor: - $(foreach plugin_dir,$(PLUGIN_VENDOR), \ - cd $(plugin_dir) >/dev/null && \ - $(GOCMD) mod vendor && \ - cd - >/dev/null; \ - ) - $(GOCMD) mod vendor - tar -czf vendor.tgz vendor $(foreach plugin_dir,$(PLUGIN_VENDOR),$(plugin_dir)/vendor) +vendor: vendor-remove ## CI only - vendor dependencies and archive them for packaging + $(GO) mod vendor + tar czf vendor.tgz vendor + tar --create --auto-compress --file=$(RELDIR)-vendor.tar.xz vendor # remove vendor directories and vendor.tgz .PHONY: vendor-remove -vendor-remove: - $(foreach plugin_dir,$(PLUGIN_VENDOR), \ - $(RM) $(plugin_dir)/vendor; \ - ) - $(RM) vendor vendor.tgz +vendor-remove: ## Remove vendor dependencies and archives + $(RM) vendor vendor.tgz *-vendor.tar.xz .PHONY: package package: @@ -232,9 +321,9 @@ package: @$(CP) $(CSCLI_FOLDER)/$(CSCLI_BIN) $(RELDIR)/cmd/crowdsec-cli @$(foreach plugin,$(PLUGINS), \ - $(MKDIR) $(RELDIR)/$(PLUGINS_DIR)/$(plugin); \ - $(CP) $(PLUGINS_DIR)/$(plugin)/notification-$(plugin)$(EXT) $(RELDIR)/$(PLUGINS_DIR)/$(plugin); \ - $(CP) $(PLUGINS_DIR)/$(plugin)/$(plugin).yaml $(RELDIR)/$(PLUGINS_DIR)/$(plugin)/; \ + $(MKDIR) $(RELDIR)/$(PLUGINS_DIR_PREFIX)$(plugin); \ + $(CP) $(PLUGINS_DIR_PREFIX)$(plugin)/notification-$(plugin)$(EXT) $(RELDIR)/$(PLUGINS_DIR_PREFIX)$(plugin); \ + $(CP) $(PLUGINS_DIR_PREFIX)$(plugin)/$(plugin).yaml $(RELDIR)/$(PLUGINS_DIR_PREFIX)$(plugin)/; \ ) @$(CPR) ./config $(RELDIR) @@ -252,18 +341,15 @@ else @if (Test-Path -Path $(RELDIR)) { echo "$(RELDIR) already exists, abort" ; exit 1 ; } endif -# build a release tarball .PHONY: release -release: check_release build package +release: check_release build package ## Build a release tarball -# build the windows installer .PHONY: windows_installer -windows_installer: build +windows_installer: build ## Windows - build the installer @.\make_installer.ps1 -version $(BUILD_VERSION) -# build the chocolatey package .PHONY: chocolatey -chocolatey: windows_installer +chocolatey: windows_installer ## Windows - build the chocolatey package @.\make_chocolatey.ps1 -version $(BUILD_VERSION) # Include test/bats.mk only if it exists @@ -275,4 +361,4 @@ else include test/bats.mk endif -include mk/goversion.mk +include mk/help.mk diff --git a/README.md b/README.md index 6428c3a8053..a900f0ee514 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ +Go Reference diff --git a/azure-pipelines.yml b/azure-pipelines.yml index d4a2f5b2114..6051ca67393 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -15,19 +15,13 @@ pool: stages: - stage: Build jobs: - - job: + - job: Build displayName: "Build" steps: - - task: DotNetCoreCLI@2 - displayName: "Install SignClient" - 
inputs: - command: 'custom' - custom: 'tool' - arguments: 'install --global SignClient --version 1.3.155' - task: GoTool@0 - displayName: "Install Go 1.20" + displayName: "Install Go" inputs: - version: '1.20.6' + version: '1.22' - pwsh: | choco install -y make @@ -39,24 +33,14 @@ stages: #we are not calling make windows_installer because we want to sign the binaries before they are added to the MSI script: | make build BUILD_RE2_WASM=1 - - task: AzureKeyVault@2 - inputs: - azureSubscription: 'Azure subscription 1(8a93ab40-7e99-445e-ad47-0f6a3e2ef546)' - KeyVaultName: 'CodeSigningSecrets' - SecretsFilter: 'CodeSigningUser,CodeSigningPassword' - RunAsPreJob: false - - - task: DownloadSecureFile@1 - inputs: - secureFile: appsettings.json - - - pwsh: | - SignClient.exe Sign --name "crowdsec-binaries" ` - --input "**/*.exe" --config (Join-Path -Path $(Agent.TempDirectory) -ChildPath "appsettings.json") ` - --user $(CodeSigningUser) --secret '$(CodeSigningPassword)' - displayName: "Sign Crowdsec binaries + plugins" + - pwsh: | $build_version=$env:BUILD_SOURCEBRANCHNAME + #Override the version if it's set in the pipeline + if ( ${env:USERBUILDVERSION} -ne "") + { + $build_version = ${env:USERBUILDVERSION} + } if ($build_version.StartsWith("v")) { $build_version = $build_version.Substring(1) @@ -69,35 +53,112 @@ stages: displayName: GetCrowdsecVersion name: GetCrowdsecVersion - pwsh: | - .\make_installer.ps1 -version '$(GetCrowdsecVersion.BuildVersion)' + Get-ChildItem -Path .\cmd -Directory | ForEach-Object { + $dirName = $_.Name + Get-ChildItem -Path .\cmd\$dirName -File -Filter '*.exe' | ForEach-Object { + $fileName = $_.Name + $destDir = Join-Path $(Build.ArtifactStagingDirectory) cmd\$dirName + New-Item -ItemType Directory -Path $destDir -Force + Copy-Item -Path .\cmd\$dirName\$fileName -Destination $destDir + } + } + displayName: "Copy binaries to staging directory" + - task: PublishPipelineArtifact@1 + inputs: + targetPath: '$(Build.ArtifactStagingDirectory)' + artifact: 'unsigned_binaries' + displayName: "Upload binaries artifact" + + - stage: Sign + dependsOn: Build + variables: + - group: 'FOSS Build Variables' + - name: BuildVersion + value: $[ stageDependencies.Build.Build.outputs['GetCrowdsecVersion.BuildVersion'] ] + condition: succeeded() + jobs: + - job: Sign + displayName: "Sign" + steps: + - download: current + artifact: unsigned_binaries + displayName: "Download binaries artifact" + - task: CopyFiles@2 + inputs: + SourceFolder: '$(Pipeline.Workspace)/unsigned_binaries' + TargetFolder: '$(Build.SourcesDirectory)' + displayName: "Copy binaries to workspace" + - task: DotNetCoreCLI@2 + displayName: "Install SignTool tool" + inputs: + command: 'custom' + custom: 'tool' + arguments: install --global sign --version 0.9.0-beta.23127.3 + - task: AzureKeyVault@2 + displayName: "Get signing parameters" + inputs: + azureSubscription: "Azure subscription" + KeyVaultName: "$(KeyVaultName)" + SecretsFilter: "TenantId,ClientId,ClientSecret,Certificate,KeyVaultUrl" + - pwsh: | + sign code azure-key-vault ` + "**/*.exe" ` + --base-directory "$(Build.SourcesDirectory)/cmd/" ` + --publisher-name "CrowdSec" ` + --description "CrowdSec" ` + --description-url "https://github.com/crowdsecurity/crowdsec" ` + --azure-key-vault-tenant-id "$(TenantId)" ` + --azure-key-vault-client-id "$(ClientId)" ` + --azure-key-vault-client-secret "$(ClientSecret)" ` + --azure-key-vault-certificate "$(Certificate)" ` + --azure-key-vault-url "$(KeyVaultUrl)" + displayName: "Sign crowdsec binaries" + - pwsh: | + 
.\make_installer.ps1 -version '$(BuildVersion)' displayName: "Build Crowdsec MSI" name: BuildMSI - - pwsh: | - .\make_chocolatey.ps1 -version '$(GetCrowdsecVersion.BuildVersion)' + .\make_chocolatey.ps1 -version '$(BuildVersion)' displayName: "Build Chocolatey nupkg" - - pwsh: | - SignClient.exe Sign --name "crowdsec-msi" ` - --input "*.msi" --config (Join-Path -Path $(Agent.TempDirectory) -ChildPath "appsettings.json") ` - --user $(CodeSigningUser) --secret '$(CodeSigningPassword)' - displayName: "Sign Crowdsec MSI" - - - task: PublishBuildArtifacts@1 + sign code azure-key-vault ` + "*.msi" ` + --base-directory "$(Build.SourcesDirectory)" ` + --publisher-name "CrowdSec" ` + --description "CrowdSec" ` + --description-url "https://github.com/crowdsecurity/crowdsec" ` + --azure-key-vault-tenant-id "$(TenantId)" ` + --azure-key-vault-client-id "$(ClientId)" ` + --azure-key-vault-client-secret "$(ClientSecret)" ` + --azure-key-vault-certificate "$(Certificate)" ` + --azure-key-vault-url "$(KeyVaultUrl)" + displayName: "Sign MSI package" + - pwsh: | + sign code azure-key-vault ` + "*.nupkg" ` + --base-directory "$(Build.SourcesDirectory)" ` + --publisher-name "CrowdSec" ` + --description "CrowdSec" ` + --description-url "https://github.com/crowdsecurity/crowdsec" ` + --azure-key-vault-tenant-id "$(TenantId)" ` + --azure-key-vault-client-id "$(ClientId)" ` + --azure-key-vault-client-secret "$(ClientSecret)" ` + --azure-key-vault-certificate "$(Certificate)" ` + --azure-key-vault-url "$(KeyVaultUrl)" + displayName: "Sign nuget package" + - task: PublishPipelineArtifact@1 inputs: - PathtoPublish: '$(Build.Repository.LocalPath)\\crowdsec_$(GetCrowdsecVersion.BuildVersion).msi' - ArtifactName: 'crowdsec.msi' - publishLocation: 'Container' - displayName: "Upload MSI artifact" - - - task: PublishBuildArtifacts@1 + targetPath: '$(Build.SourcesDirectory)/crowdsec_$(BuildVersion).msi' + artifact: 'signed_msi_package' + displayName: "Upload signed MSI artifact" + - task: PublishPipelineArtifact@1 inputs: - PathtoPublish: '$(Build.Repository.LocalPath)\\windows\\Chocolatey\\crowdsec\\crowdsec.$(GetCrowdsecVersion.BuildVersion).nupkg' - ArtifactName: 'crowdsec.nupkg' - publishLocation: 'Container' - displayName: "Upload nupkg artifact" + targetPath: '$(Build.SourcesDirectory)/crowdsec.$(BuildVersion).nupkg' + artifact: 'signed_nuget_package' + displayName: "Upload signed nuget artifact" + - stage: Publish - dependsOn: Build + dependsOn: Sign jobs: - deployment: "Publish" displayName: "Publish to GitHub" @@ -119,8 +180,7 @@ stages: assetUploadMode: 'replace' addChangeLog: false isPreRelease: true #we force prerelease because the pipeline is invoked on tag creation, which happens when we do a prerelease - #the .. is an ugly hack, but I can't find the var that gives D:\a\1 ... 
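# (Illustrative aside, not part of the original pipeline.) In a deployment
# job, Azure Pipelines downloads published pipeline artifacts under
# $(Pipeline.Workspace)/<artifact-name>, so the globs below would match
# paths such as (version number hypothetical):
#   $(Pipeline.Workspace)/signed_msi_package/crowdsec_1.6.0.msi
#   $(Pipeline.Workspace)/signed_nuget_package/crowdsec.1.6.0.nupkg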
assets: | - $(Build.ArtifactStagingDirectory)\..\crowdsec.msi/*.msi - $(Build.ArtifactStagingDirectory)\..\crowdsec.nupkg/*.nupkg + $(Pipeline.Workspace)/signed_msi_package/*.msi + $(Pipeline.Workspace)/signed_nuget_package/*.nupkg condition: ne(variables['GetLatestPrelease.LatestPreRelease'], '') diff --git a/cmd/crowdsec-cli/Makefile b/cmd/crowdsec-cli/Makefile index f4d66157fd9..6d6e4da8dbd 100644 --- a/cmd/crowdsec-cli/Makefile +++ b/cmd/crowdsec-cli/Makefile @@ -4,32 +4,16 @@ ifeq ($(OS), Windows_NT) EXT = .exe endif -# Go parameters -GOCMD = go -GOBUILD = $(GOCMD) build -GOTEST = $(GOCMD) test +GO = go +GOBUILD = $(GO) build BINARY_NAME = cscli$(EXT) -PREFIX ?= "/" -BIN_PREFIX = $(PREFIX)"/usr/local/bin/" .PHONY: all all: clean build build: clean - $(GOBUILD) $(LD_OPTS) $(BUILD_VENDOR_FLAGS) -o $(BINARY_NAME) - -.PHONY: install -install: install-conf install-bin - -install-conf: - -install-bin: - @install -v -m 755 -D "$(BINARY_NAME)" "$(BIN_PREFIX)/$(BINARY_NAME)" || exit - -uninstall: - @$(RM) $(CSCLI_CONFIG) $(WIN_IGNORE_ERR) - @$(RM) $(BIN_PREFIX)$(BINARY_NAME) $(WIN_IGNORE_ERR) + $(GOBUILD) $(LD_OPTS) -o $(BINARY_NAME) clean: @$(RM) $(BINARY_NAME) $(WIN_IGNORE_ERR) diff --git a/cmd/crowdsec-cli/alerts.go b/cmd/crowdsec-cli/alerts.go deleted file mode 100644 index 6abe3db5afc..00000000000 --- a/cmd/crowdsec-cli/alerts.go +++ /dev/null @@ -1,551 +0,0 @@ -package main - -import ( - "context" - "encoding/csv" - "encoding/json" - "fmt" - "net/url" - "os" - "sort" - "strconv" - "strings" - "text/template" - "time" - - "github.com/fatih/color" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v2" - - "github.com/crowdsecurity/go-cs-lib/pkg/version" - - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -func DecisionsFromAlert(alert *models.Alert) string { - ret := "" - var decMap = make(map[string]int) - for _, decision := range alert.Decisions { - k := *decision.Type - if *decision.Simulated { - k = fmt.Sprintf("(simul)%s", k) - } - v := decMap[k] - decMap[k] = v + 1 - } - for k, v := range decMap { - if len(ret) > 0 { - ret += " " - } - ret += fmt.Sprintf("%s:%d", k, v) - } - return ret -} - -func DateFromAlert(alert *models.Alert) string { - ts, err := time.Parse(time.RFC3339, alert.CreatedAt) - if err != nil { - log.Infof("while parsing %s with %s : %s", alert.CreatedAt, time.RFC3339, err) - return alert.CreatedAt - } - return ts.Format(time.RFC822) -} - -func SourceFromAlert(alert *models.Alert) string { - - //more than one item, just number and scope - if len(alert.Decisions) > 1 { - return fmt.Sprintf("%d %ss (%s)", len(alert.Decisions), *alert.Decisions[0].Scope, *alert.Decisions[0].Origin) - } - - //fallback on single decision information - if len(alert.Decisions) == 1 { - return fmt.Sprintf("%s:%s", *alert.Decisions[0].Scope, *alert.Decisions[0].Value) - } - - //try to compose a human friendly version - if *alert.Source.Value != "" && *alert.Source.Scope != "" { - scope := "" - scope = fmt.Sprintf("%s:%s", *alert.Source.Scope, *alert.Source.Value) - extra := "" - if alert.Source.Cn != "" { - extra = alert.Source.Cn - } - if alert.Source.AsNumber != "" { - extra += fmt.Sprintf("/%s", alert.Source.AsNumber) - } - if alert.Source.AsName != "" { - extra += fmt.Sprintf("/%s", alert.Source.AsName) - } - - if extra != "" { - scope += " (" + extra + ")" - } - return scope 
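// (Illustrative aside, not part of the original file.) With hypothetical
// source fields Scope "ip", Value "192.168.1.1", Cn "US", AsNumber "1234"
// and AsName "ExampleNet", the composition above would return:
//   "ip:192.168.1.1 (US/1234/ExampleNet)"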
- } - return "" -} - -func AlertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { - - if csConfig.Cscli.Output == "raw" { - csvwriter := csv.NewWriter(os.Stdout) - header := []string{"id", "scope", "value", "reason", "country", "as", "decisions", "created_at"} - if printMachine { - header = append(header, "machine") - } - err := csvwriter.Write(header) - if err != nil { - return err - } - for _, alertItem := range *alerts { - row := []string{ - fmt.Sprintf("%d", alertItem.ID), - *alertItem.Source.Scope, - *alertItem.Source.Value, - *alertItem.Scenario, - alertItem.Source.Cn, - alertItem.Source.GetAsNumberName(), - DecisionsFromAlert(alertItem), - *alertItem.StartAt, - } - if printMachine { - row = append(row, alertItem.MachineID) - } - err := csvwriter.Write(row) - if err != nil { - return err - } - } - csvwriter.Flush() - } else if csConfig.Cscli.Output == "json" { - x, _ := json.MarshalIndent(alerts, "", " ") - fmt.Printf("%s", string(x)) - } else if csConfig.Cscli.Output == "human" { - if len(*alerts) == 0 { - fmt.Println("No active alerts") - return nil - } - alertsTable(color.Output, alerts, printMachine) - } - return nil -} - -var alertTemplate = ` -################################################################################################ - - - ID : {{.ID}} - - Date : {{.CreatedAt}} - - Machine : {{.MachineID}} - - Simulation : {{.Simulated}} - - Reason : {{.Scenario}} - - Events Count : {{.EventsCount}} - - Scope:Value : {{.Source.Scope}}{{if .Source.Value}}:{{.Source.Value}}{{end}} - - Country : {{.Source.Cn}} - - AS : {{.Source.AsName}} - - Begin : {{.StartAt}} - - End : {{.StopAt}} - - UUID : {{.UUID}} - -` - -func DisplayOneAlert(alert *models.Alert, withDetail bool) error { - if csConfig.Cscli.Output == "human" { - tmpl, err := template.New("alert").Parse(alertTemplate) - if err != nil { - return err - } - err = tmpl.Execute(os.Stdout, alert) - if err != nil { - return err - } - - alertDecisionsTable(color.Output, alert) - - if len(alert.Meta) > 0 { - fmt.Printf("\n - Context :\n") - sort.Slice(alert.Meta, func(i, j int) bool { - return alert.Meta[i].Key < alert.Meta[j].Key - }) - table := newTable(color.Output) - table.SetRowLines(false) - table.SetHeaders("Key", "Value") - for _, meta := range alert.Meta { - var valSlice []string - if err := json.Unmarshal([]byte(meta.Value), &valSlice); err != nil { - return fmt.Errorf("unknown context value type '%s' : %s", meta.Value, err) - } - for _, value := range valSlice { - table.AddRow( - meta.Key, - value, - ) - } - } - table.Render() - } - - if withDetail { - fmt.Printf("\n - Events :\n") - for _, event := range alert.Events { - alertEventTable(color.Output, event) - } - } - } - return nil -} - -func NewAlertsCmd() *cobra.Command { - var cmdAlerts = &cobra.Command{ - Use: "alerts [action]", - Short: "Manage alerts", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - var err error - if err := csConfig.LoadAPIClient(); err != nil { - return fmt.Errorf("loading api client: %w", err) - } - apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL) - if err != nil { - return fmt.Errorf("parsing api url %s: %w", apiURL, err) - } - Client, err = apiclient.NewClient(&apiclient.Config{ - MachineID: csConfig.API.Client.Credentials.Login, - Password: strfmt.Password(csConfig.API.Client.Credentials.Password), - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiURL, - VersionPrefix: "v1", - }) - - if err != nil { - return 
fmt.Errorf("new api client: %w", err) - } - return nil - }, - } - - cmdAlerts.AddCommand(NewAlertsListCmd()) - cmdAlerts.AddCommand(NewAlertsInspectCmd()) - cmdAlerts.AddCommand(NewAlertsFlushCmd()) - cmdAlerts.AddCommand(NewAlertsDeleteCmd()) - - return cmdAlerts -} - -func NewAlertsListCmd() *cobra.Command { - var alertListFilter = apiclient.AlertsListOpts{ - ScopeEquals: new(string), - ValueEquals: new(string), - ScenarioEquals: new(string), - IPEquals: new(string), - RangeEquals: new(string), - Since: new(string), - Until: new(string), - TypeEquals: new(string), - IncludeCAPI: new(bool), - OriginEquals: new(string), - } - var limit = new(int) - contained := new(bool) - var printMachine bool - var cmdAlertsList = &cobra.Command{ - Use: "list [filters]", - Short: "List alerts", - Example: `cscli alerts list -cscli alerts list --ip 1.2.3.4 -cscli alerts list --range 1.2.3.0/24 -cscli alerts list -s crowdsecurity/ssh-bf -cscli alerts list --type ban`, - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - var err error - - if err := manageCliDecisionAlerts(alertListFilter.IPEquals, alertListFilter.RangeEquals, - alertListFilter.ScopeEquals, alertListFilter.ValueEquals); err != nil { - printHelp(cmd) - return err - } - if limit != nil { - alertListFilter.Limit = limit - } - - if *alertListFilter.Until == "" { - alertListFilter.Until = nil - } else if strings.HasSuffix(*alertListFilter.Until, "d") { - /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ - realDuration := strings.TrimSuffix(*alertListFilter.Until, "d") - days, err := strconv.Atoi(realDuration) - if err != nil { - printHelp(cmd) - return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *alertListFilter.Until) - } - *alertListFilter.Until = fmt.Sprintf("%d%s", days*24, "h") - } - if *alertListFilter.Since == "" { - alertListFilter.Since = nil - } else if strings.HasSuffix(*alertListFilter.Since, "d") { - /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ - realDuration := strings.TrimSuffix(*alertListFilter.Since, "d") - days, err := strconv.Atoi(realDuration) - if err != nil { - printHelp(cmd) - return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *alertListFilter.Since) - } - *alertListFilter.Since = fmt.Sprintf("%d%s", days*24, "h") - } - - if *alertListFilter.IncludeCAPI { - *alertListFilter.Limit = 0 - } - - if *alertListFilter.TypeEquals == "" { - alertListFilter.TypeEquals = nil - } - if *alertListFilter.ScopeEquals == "" { - alertListFilter.ScopeEquals = nil - } - if *alertListFilter.ValueEquals == "" { - alertListFilter.ValueEquals = nil - } - if *alertListFilter.ScenarioEquals == "" { - alertListFilter.ScenarioEquals = nil - } - if *alertListFilter.IPEquals == "" { - alertListFilter.IPEquals = nil - } - if *alertListFilter.RangeEquals == "" { - alertListFilter.RangeEquals = nil - } - - if *alertListFilter.OriginEquals == "" { - alertListFilter.OriginEquals = nil - } - - if contained != nil && *contained { - alertListFilter.Contains = new(bool) - } - - alerts, _, err := Client.Alerts.List(context.Background(), alertListFilter) - if err != nil { - return fmt.Errorf("unable to list alerts: %v", err) - } - - err = AlertsToTable(alerts, printMachine) - if err != nil { - return fmt.Errorf("unable to list alerts: %v", err) - } - - return nil - }, - } - cmdAlertsList.Flags().SortFlags = false - cmdAlertsList.Flags().BoolVarP(alertListFilter.IncludeCAPI, "all", 
"a", false, "Include decisions from Central API") - cmdAlertsList.Flags().StringVar(alertListFilter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)") - cmdAlertsList.Flags().StringVar(alertListFilter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)") - cmdAlertsList.Flags().StringVarP(alertListFilter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value )") - cmdAlertsList.Flags().StringVarP(alertListFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") - cmdAlertsList.Flags().StringVarP(alertListFilter.RangeEquals, "range", "r", "", "restrict to alerts from this range (shorthand for --scope range --value )") - cmdAlertsList.Flags().StringVar(alertListFilter.TypeEquals, "type", "", "restrict to alerts with given decision type (ie. ban, captcha)") - cmdAlertsList.Flags().StringVar(alertListFilter.ScopeEquals, "scope", "", "restrict to alerts of this scope (ie. ip,range)") - cmdAlertsList.Flags().StringVarP(alertListFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") - cmdAlertsList.Flags().StringVar(alertListFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) - cmdAlertsList.Flags().BoolVar(contained, "contained", false, "query decisions contained by range") - cmdAlertsList.Flags().BoolVarP(&printMachine, "machine", "m", false, "print machines that sent alerts") - cmdAlertsList.Flags().IntVarP(limit, "limit", "l", 50, "limit size of alerts list table (0 to view all alerts)") - - return cmdAlertsList -} - -func NewAlertsDeleteCmd() *cobra.Command { - var ActiveDecision *bool - var AlertDeleteAll bool - var delAlertByID string - contained := new(bool) - var alertDeleteFilter = apiclient.AlertsDeleteOpts{ - ScopeEquals: new(string), - ValueEquals: new(string), - ScenarioEquals: new(string), - IPEquals: new(string), - RangeEquals: new(string), - } - var cmdAlertsDelete = &cobra.Command{ - Use: "delete [filters] [--all]", - Short: `Delete alerts -/!\ This command can be use only on the same machine than the local API.`, - Example: `cscli alerts delete --ip 1.2.3.4 -cscli alerts delete --range 1.2.3.0/24 -cscli alerts delete -s crowdsecurity/ssh-bf"`, - DisableAutoGenTag: true, - Aliases: []string{"remove"}, - Args: cobra.ExactArgs(0), - PreRunE: func(cmd *cobra.Command, args []string) error { - if AlertDeleteAll { - return nil - } - if *alertDeleteFilter.ScopeEquals == "" && *alertDeleteFilter.ValueEquals == "" && - *alertDeleteFilter.ScenarioEquals == "" && *alertDeleteFilter.IPEquals == "" && - *alertDeleteFilter.RangeEquals == "" && delAlertByID == "" { - _ = cmd.Usage() - return fmt.Errorf("at least one filter or --all must be specified") - } - - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - var err error - - if !AlertDeleteAll { - if err := manageCliDecisionAlerts(alertDeleteFilter.IPEquals, alertDeleteFilter.RangeEquals, - alertDeleteFilter.ScopeEquals, alertDeleteFilter.ValueEquals); err != nil { - printHelp(cmd) - return err - } - if ActiveDecision != nil { - alertDeleteFilter.ActiveDecisionEquals = ActiveDecision - } - - if *alertDeleteFilter.ScopeEquals == "" { - alertDeleteFilter.ScopeEquals = nil - } - if *alertDeleteFilter.ValueEquals == "" { - alertDeleteFilter.ValueEquals = nil - } - if *alertDeleteFilter.ScenarioEquals == "" { - alertDeleteFilter.ScenarioEquals = nil - } - if *alertDeleteFilter.IPEquals == 
"" { - alertDeleteFilter.IPEquals = nil - } - if *alertDeleteFilter.RangeEquals == "" { - alertDeleteFilter.RangeEquals = nil - } - if contained != nil && *contained { - alertDeleteFilter.Contains = new(bool) - } - limit := 0 - alertDeleteFilter.Limit = &limit - } else { - limit := 0 - alertDeleteFilter = apiclient.AlertsDeleteOpts{Limit: &limit} - } - - var alerts *models.DeleteAlertsResponse - if delAlertByID == "" { - alerts, _, err = Client.Alerts.Delete(context.Background(), alertDeleteFilter) - if err != nil { - return fmt.Errorf("unable to delete alerts : %v", err) - } - } else { - alerts, _, err = Client.Alerts.DeleteOne(context.Background(), delAlertByID) - if err != nil { - return fmt.Errorf("unable to delete alert: %v", err) - } - } - log.Infof("%s alert(s) deleted", alerts.NbDeleted) - - return nil - }, - } - cmdAlertsDelete.Flags().SortFlags = false - cmdAlertsDelete.Flags().StringVar(alertDeleteFilter.ScopeEquals, "scope", "", "the scope (ie. ip,range)") - cmdAlertsDelete.Flags().StringVarP(alertDeleteFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") - cmdAlertsDelete.Flags().StringVarP(alertDeleteFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") - cmdAlertsDelete.Flags().StringVarP(alertDeleteFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") - cmdAlertsDelete.Flags().StringVarP(alertDeleteFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value )") - cmdAlertsDelete.Flags().StringVar(&delAlertByID, "id", "", "alert ID") - cmdAlertsDelete.Flags().BoolVarP(&AlertDeleteAll, "all", "a", false, "delete all alerts") - cmdAlertsDelete.Flags().BoolVar(contained, "contained", false, "query decisions contained by range") - return cmdAlertsDelete -} - -func NewAlertsInspectCmd() *cobra.Command { - var details bool - var cmdAlertsInspect = &cobra.Command{ - Use: `inspect "alert_id"`, - Short: `Show info about an alert`, - Example: `cscli alerts inspect 123`, - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - if len(args) == 0 { - printHelp(cmd) - return fmt.Errorf("missing alert_id") - } - for _, alertID := range args { - id, err := strconv.Atoi(alertID) - if err != nil { - return fmt.Errorf("bad alert id %s", alertID) - } - alert, _, err := Client.Alerts.GetByID(context.Background(), id) - if err != nil { - return fmt.Errorf("can't find alert with id %s: %s", alertID, err) - } - switch csConfig.Cscli.Output { - case "human": - if err := DisplayOneAlert(alert, details); err != nil { - continue - } - case "json": - data, err := json.MarshalIndent(alert, "", " ") - if err != nil { - return fmt.Errorf("unable to marshal alert with id %s: %s", alertID, err) - } - fmt.Printf("%s\n", string(data)) - case "raw": - data, err := yaml.Marshal(alert) - if err != nil { - return fmt.Errorf("unable to marshal alert with id %s: %s", alertID, err) - } - fmt.Printf("%s\n", string(data)) - } - } - - return nil - }, - } - cmdAlertsInspect.Flags().SortFlags = false - cmdAlertsInspect.Flags().BoolVarP(&details, "details", "d", false, "show alerts with events") - - return cmdAlertsInspect -} - -func NewAlertsFlushCmd() *cobra.Command { - var maxItems int - var maxAge string - var cmdAlertsFlush = &cobra.Command{ - Use: `flush`, - Short: `Flush alerts -/!\ This command can be used only on the same machine than the local API`, - Example: `cscli alerts flush --max-items 1000 --max-age 7d`, - DisableAutoGenTag: true, - RunE: func(cmd 
*cobra.Command, args []string) error { - var err error - if err := csConfig.LoadAPIServer(); err != nil || csConfig.DisableAPI { - return fmt.Errorf("local API is disabled, please run this command on the local API machine") - } - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) - } - log.Info("Flushing alerts. !! This may take a long time !!") - err = dbClient.FlushAlerts(maxAge, maxItems) - if err != nil { - return fmt.Errorf("unable to flush alerts: %s", err) - } - log.Info("Alerts flushed") - - return nil - }, - } - - cmdAlertsFlush.Flags().SortFlags = false - cmdAlertsFlush.Flags().IntVar(&maxItems, "max-items", 5000, "Maximum number of alert items to keep in the database") - cmdAlertsFlush.Flags().StringVar(&maxAge, "max-age", "7d", "Maximum age of alert items to keep in the database") - - return cmdAlertsFlush -} diff --git a/cmd/crowdsec-cli/ask/ask.go b/cmd/crowdsec-cli/ask/ask.go new file mode 100644 index 00000000000..484ccb30c8a --- /dev/null +++ b/cmd/crowdsec-cli/ask/ask.go @@ -0,0 +1,20 @@ +package ask + +import ( + "github.com/AlecAivazis/survey/v2" +) + +func YesNo(message string, defaultAnswer bool) (bool, error) { + var answer bool + + prompt := &survey.Confirm{ + Message: message, + Default: defaultAnswer, + } + + if err := survey.AskOne(prompt, &answer); err != nil { + return defaultAnswer, err + } + + return answer, nil +} diff --git a/cmd/crowdsec-cli/bouncers.go b/cmd/crowdsec-cli/bouncers.go deleted file mode 100644 index 96fc8c5a206..00000000000 --- a/cmd/crowdsec-cli/bouncers.go +++ /dev/null @@ -1,219 +0,0 @@ -package main - -import ( - "encoding/csv" - "encoding/json" - "fmt" - "io" - "strings" - "time" - - "github.com/fatih/color" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "golang.org/x/exp/slices" - - middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" - "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -func getBouncers(out io.Writer, dbClient *database.Client) error { - bouncers, err := dbClient.ListBouncers() - if err != nil { - return fmt.Errorf("unable to list bouncers: %s", err) - } - if csConfig.Cscli.Output == "human" { - getBouncersTable(out, bouncers) - } else if csConfig.Cscli.Output == "json" { - enc := json.NewEncoder(out) - enc.SetIndent("", " ") - if err := enc.Encode(bouncers); err != nil { - return fmt.Errorf("failed to unmarshal: %w", err) - } - return nil - } else if csConfig.Cscli.Output == "raw" { - csvwriter := csv.NewWriter(out) - err := csvwriter.Write([]string{"name", "ip", "revoked", "last_pull", "type", "version", "auth_type"}) - if err != nil { - return fmt.Errorf("failed to write raw header: %w", err) - } - for _, b := range bouncers { - var revoked string - if !b.Revoked { - revoked = "validated" - } else { - revoked = "pending" - } - err := csvwriter.Write([]string{b.Name, b.IPAddress, revoked, b.LastPull.Format(time.RFC3339), b.Type, b.Version, b.AuthType}) - if err != nil { - return fmt.Errorf("failed to write raw: %w", err) - } - } - csvwriter.Flush() - } - return nil -} - -func NewBouncersListCmd() *cobra.Command { - cmdBouncersList := &cobra.Command{ - Use: "list", - Short: "List bouncers", - Long: `List bouncers`, - Example: `cscli bouncers list`, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, arg []string) error { - err := getBouncers(color.Output, dbClient) - if err != nil { - return 
fmt.Errorf("unable to list bouncers: %s", err) - } - return nil - }, - } - - return cmdBouncersList -} - -func runBouncersAdd(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - keyLength, err := flags.GetInt("length") - if err != nil { - return err - } - - key, err := flags.GetString("key") - if err != nil { - return err - } - - keyName := args[0] - var apiKey string - - if keyName == "" { - return fmt.Errorf("please provide a name for the api key") - } - apiKey = key - if key == "" { - apiKey, err = middlewares.GenerateAPIKey(keyLength) - } - if err != nil { - return fmt.Errorf("unable to generate api key: %s", err) - } - _, err = dbClient.CreateBouncer(keyName, "", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) - if err != nil { - return fmt.Errorf("unable to create bouncer: %s", err) - } - - if csConfig.Cscli.Output == "human" { - fmt.Printf("API key for '%s':\n\n", keyName) - fmt.Printf(" %s\n\n", apiKey) - fmt.Print("Please keep this key since you will not be able to retrieve it!\n") - } else if csConfig.Cscli.Output == "raw" { - fmt.Printf("%s", apiKey) - } else if csConfig.Cscli.Output == "json" { - j, err := json.Marshal(apiKey) - if err != nil { - return fmt.Errorf("unable to marshal api key") - } - fmt.Printf("%s", string(j)) - } - - return nil -} - -func NewBouncersAddCmd() *cobra.Command { - cmdBouncersAdd := &cobra.Command{ - Use: "add MyBouncerName [--length 16]", - Short: "add bouncer", - Long: `add bouncer`, - Example: `cscli bouncers add MyBouncerName -cscli bouncers add MyBouncerName -l 24 -cscli bouncers add MyBouncerName -k `, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: runBouncersAdd, - } - - flags := cmdBouncersAdd.Flags() - - flags.IntP("length", "l", 16, "length of the api key") - flags.StringP("key", "k", "", "api key for the bouncer") - - return cmdBouncersAdd -} - -func runBouncersDelete(cmd *cobra.Command, args []string) error { - for _, bouncerID := range args { - err := dbClient.DeleteBouncer(bouncerID) - if err != nil { - return fmt.Errorf("unable to delete bouncer '%s': %s", bouncerID, err) - } - log.Infof("bouncer '%s' deleted successfully", bouncerID) - } - - return nil -} - -func NewBouncersDeleteCmd() *cobra.Command { - cmdBouncersDelete := &cobra.Command{ - Use: "delete MyBouncerName", - Short: "delete bouncer", - Args: cobra.MinimumNArgs(1), - Aliases: []string{"remove"}, - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - var err error - dbClient, err = getDBClient() - if err != nil { - cobra.CompError("unable to create new database client: " + err.Error()) - return nil, cobra.ShellCompDirectiveNoFileComp - } - bouncers, err := dbClient.ListBouncers() - if err != nil { - cobra.CompError("unable to list bouncers " + err.Error()) - } - ret := make([]string, 0) - for _, bouncer := range bouncers { - if strings.Contains(bouncer.Name, toComplete) && !slices.Contains(args, bouncer.Name) { - ret = append(ret, bouncer.Name) - } - } - return ret, cobra.ShellCompDirectiveNoFileComp - }, - RunE: runBouncersDelete, - } - - return cmdBouncersDelete -} - -func NewBouncersCmd() *cobra.Command { - var cmdBouncers = &cobra.Command{ - Use: "bouncers [action]", - Short: "Manage bouncers [requires local API]", - Long: `To list/add/delete bouncers. -Note: This command requires database direct access, so is intended to be run on Local API/master. 
-`, - Args: cobra.MinimumNArgs(1), - Aliases: []string{"bouncer"}, - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - var err error - if err := csConfig.LoadAPIServer(); err != nil || csConfig.DisableAPI { - return fmt.Errorf("local API is disabled, please run this command on the local API machine") - } - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) - } - return nil - }, - } - - cmdBouncers.AddCommand(NewBouncersListCmd()) - cmdBouncers.AddCommand(NewBouncersAddCmd()) - cmdBouncers.AddCommand(NewBouncersDeleteCmd()) - - return cmdBouncers -} diff --git a/cmd/crowdsec-cli/bouncers_table.go b/cmd/crowdsec-cli/bouncers_table.go deleted file mode 100644 index 0ea725f5598..00000000000 --- a/cmd/crowdsec-cli/bouncers_table.go +++ /dev/null @@ -1,31 +0,0 @@ -package main - -import ( - "io" - "time" - - "github.com/aquasecurity/table" - "github.com/enescakir/emoji" - - "github.com/crowdsecurity/crowdsec/pkg/database/ent" -) - -func getBouncersTable(out io.Writer, bouncers []*ent.Bouncer) { - t := newLightTable(out) - t.SetHeaders("Name", "IP Address", "Valid", "Last API pull", "Type", "Version", "Auth Type") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - for _, b := range bouncers { - var revoked string - if !b.Revoked { - revoked = emoji.CheckMark.String() - } else { - revoked = emoji.Prohibited.String() - } - - t.AddRow(b.Name, b.IPAddress, revoked, b.LastPull.Format(time.RFC3339), b.Type, b.Version, b.AuthType) - } - - t.Render() -} diff --git a/cmd/crowdsec-cli/capi.go b/cmd/crowdsec-cli/capi.go deleted file mode 100644 index af6e9c2e86f..00000000000 --- a/cmd/crowdsec-cli/capi.go +++ /dev/null @@ -1,191 +0,0 @@ -package main - -import ( - "context" - "fmt" - "net/url" - "os" - - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v2" - - "github.com/crowdsecurity/go-cs-lib/pkg/version" - - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/fflag" - "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -const CAPIBaseURL string = "https://api.crowdsec.net/" -const CAPIURLPrefix = "v3" - -func NewCapiCmd() *cobra.Command { - var cmdCapi = &cobra.Command{ - Use: "capi [action]", - Short: "Manage interaction with Central API (CAPI)", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadAPIServer(); err != nil { - return fmt.Errorf("local API is disabled, please run this command on the local API machine: %w", err) - } - if csConfig.DisableAPI { - return nil - } - if csConfig.API.Server.OnlineClient == nil { - log.Fatalf("no configuration for Central API in '%s'", *csConfig.FilePath) - } - - return nil - }, - } - - cmdCapi.AddCommand(NewCapiRegisterCmd()) - cmdCapi.AddCommand(NewCapiStatusCmd()) - - return cmdCapi -} - -func NewCapiRegisterCmd() *cobra.Command { - var capiUserPrefix string - var outputFile string - - var cmdCapiRegister = &cobra.Command{ - Use: "register", - Short: "Register to Central API (CAPI)", - Args: 
cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - var err error - capiUser, err := generateID(capiUserPrefix) - if err != nil { - log.Fatalf("unable to generate machine id: %s", err) - } - password := strfmt.Password(generatePassword(passwordLength)) - apiurl, err := url.Parse(types.CAPIBaseURL) - if err != nil { - log.Fatalf("unable to parse api url %s : %s", types.CAPIBaseURL, err) - } - _, err = apiclient.RegisterClient(&apiclient.Config{ - MachineID: capiUser, - Password: password, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiurl, - VersionPrefix: CAPIURLPrefix, - }, nil) - - if err != nil { - log.Fatalf("api client register ('%s'): %s", types.CAPIBaseURL, err) - } - log.Printf("Successfully registered to Central API (CAPI)") - - var dumpFile string - - if outputFile != "" { - dumpFile = outputFile - } else if csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { - dumpFile = csConfig.API.Server.OnlineClient.CredentialsFilePath - } else { - dumpFile = "" - } - apiCfg := csconfig.ApiCredentialsCfg{ - Login: capiUser, - Password: password.String(), - URL: types.CAPIBaseURL, - } - if fflag.PapiClient.IsEnabled() { - apiCfg.PapiURL = types.PAPIBaseURL - } - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - log.Fatalf("unable to marshal api credentials: %s", err) - } - if dumpFile != "" { - err = os.WriteFile(dumpFile, apiConfigDump, 0600) - if err != nil { - log.Fatalf("write api credentials in '%s' failed: %s", dumpFile, err) - } - log.Printf("Central API credentials dumped to '%s'", dumpFile) - } else { - fmt.Printf("%s\n", string(apiConfigDump)) - } - - log.Warning(ReloadMessage()) - }, - } - cmdCapiRegister.Flags().StringVarP(&outputFile, "file", "f", "", "output file destination") - cmdCapiRegister.Flags().StringVar(&capiUserPrefix, "schmilblick", "", "set a schmilblick (use in tests only)") - if err := cmdCapiRegister.Flags().MarkHidden("schmilblick"); err != nil { - log.Fatalf("failed to hide flag: %s", err) - } - - return cmdCapiRegister -} - -func NewCapiStatusCmd() *cobra.Command { - var cmdCapiStatus = &cobra.Command{ - Use: "status", - Short: "Check status with the Central API (CAPI)", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - var err error - if csConfig.API.Server == nil { - log.Fatal("There is no configuration on 'api.server:'") - } - if csConfig.API.Server.OnlineClient == nil { - log.Fatalf("Please provide credentials for the Central API (CAPI) in '%s'", csConfig.API.Server.OnlineClient.CredentialsFilePath) - } - - if csConfig.API.Server.OnlineClient.Credentials == nil { - log.Fatalf("no credentials for Central API (CAPI) in '%s'", csConfig.API.Server.OnlineClient.CredentialsFilePath) - } - - password := strfmt.Password(csConfig.API.Server.OnlineClient.Credentials.Password) - apiurl, err := url.Parse(csConfig.API.Server.OnlineClient.Credentials.URL) - if err != nil { - log.Fatalf("parsing api url ('%s'): %s", csConfig.API.Server.OnlineClient.Credentials.URL, err) - } - - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Info("Run 'sudo cscli hub update' to get the hub index") - log.Fatalf("Failed to load hub index : %s", err) - } - scenarios, err := cwhub.GetInstalledScenariosAsString() - if err != nil { - log.Fatalf("failed to get scenarios : %s", err) - } - if len(scenarios) == 0 { - log.Fatalf("no scenarios installed, abort") - } - - 
Client, err = apiclient.NewDefaultClient(apiurl, CAPIURLPrefix, fmt.Sprintf("crowdsec/%s", version.String()), nil) - if err != nil { - log.Fatalf("init default client: %s", err) - } - t := models.WatcherAuthRequest{ - MachineID: &csConfig.API.Server.OnlineClient.Credentials.Login, - Password: &password, - Scenarios: scenarios, - } - log.Infof("Loaded credentials from %s", csConfig.API.Server.OnlineClient.CredentialsFilePath) - log.Infof("Trying to authenticate with username %s on %s", csConfig.API.Server.OnlineClient.Credentials.Login, apiurl) - _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) - if err != nil { - log.Fatalf("Failed to authenticate to Central API (CAPI) : %s", err) - } - log.Infof("You can successfully interact with Central API (CAPI)") - }, - } - - return cmdCapiStatus -} diff --git a/cmd/crowdsec-cli/clialert/alerts.go b/cmd/crowdsec-cli/clialert/alerts.go new file mode 100644 index 00000000000..75454e945f2 --- /dev/null +++ b/cmd/crowdsec-cli/clialert/alerts.go @@ -0,0 +1,603 @@ +package clialert + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "net/url" + "os" + "sort" + "strconv" + "strings" + "text/template" + + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func decisionsFromAlert(alert *models.Alert) string { + ret := "" + decMap := make(map[string]int) + + for _, decision := range alert.Decisions { + k := *decision.Type + if *decision.Simulated { + k = fmt.Sprintf("(simul)%s", k) + } + + v := decMap[k] + decMap[k] = v + 1 + } + + for _, key := range maptools.SortedKeys(decMap) { + if ret != "" { + ret += " " + } + + ret += fmt.Sprintf("%s:%d", key, decMap[key]) + } + + return ret +} + +func (cli *cliAlerts) alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { + cfg := cli.cfg() + switch cfg.Cscli.Output { + case "raw": + csvwriter := csv.NewWriter(os.Stdout) + header := []string{"id", "scope", "value", "reason", "country", "as", "decisions", "created_at"} + + if printMachine { + header = append(header, "machine") + } + + if err := csvwriter.Write(header); err != nil { + return err + } + + for _, alertItem := range *alerts { + row := []string{ + strconv.FormatInt(alertItem.ID, 10), + *alertItem.Source.Scope, + *alertItem.Source.Value, + *alertItem.Scenario, + alertItem.Source.Cn, + alertItem.Source.GetAsNumberName(), + decisionsFromAlert(alertItem), + *alertItem.StartAt, + } + if printMachine { + row = append(row, alertItem.MachineID) + } + + if err := csvwriter.Write(row); err != nil { + return err + } + } + + csvwriter.Flush() + case "json": + if *alerts == nil { + // avoid returning "null" in json + // could be cleaner if we used slice of alerts directly + fmt.Println("[]") + return nil + } + + x, _ := json.MarshalIndent(alerts, "", " ") + fmt.Print(string(x)) + case "human": + if len(*alerts) == 0 { + fmt.Println("No active alerts") + return nil + } + + alertsTable(color.Output, cfg.Cscli.Color, alerts, printMachine) + } + + return nil +} + +func (cli *cliAlerts) displayOneAlert(alert *models.Alert, withDetail bool) error 
{ + alertTemplate := ` +################################################################################################ + + - ID : {{.ID}} + - Date : {{.CreatedAt}} + - Machine : {{.MachineID}} + - Simulation : {{.Simulated}} + - Remediation : {{.Remediation}} + - Reason : {{.Scenario}} + - Events Count : {{.EventsCount}} + - Scope:Value : {{.Source.Scope}}{{if .Source.Value}}:{{.Source.Value}}{{end}} + - Country : {{.Source.Cn}} + - AS : {{.Source.AsName}} + - Begin : {{.StartAt}} + - End : {{.StopAt}} + - UUID : {{.UUID}} + +` + + tmpl, err := template.New("alert").Parse(alertTemplate) + if err != nil { + return err + } + + if err = tmpl.Execute(os.Stdout, alert); err != nil { + return err + } + + cfg := cli.cfg() + + alertDecisionsTable(color.Output, cfg.Cscli.Color, alert) + + if len(alert.Meta) > 0 { + fmt.Printf("\n - Context :\n") + sort.Slice(alert.Meta, func(i, j int) bool { + return alert.Meta[i].Key < alert.Meta[j].Key + }) + + table := cstable.New(color.Output, cfg.Cscli.Color) + table.SetRowLines(false) + table.SetHeaders("Key", "Value") + + for _, meta := range alert.Meta { + var valSlice []string + if err := json.Unmarshal([]byte(meta.Value), &valSlice); err != nil { + return fmt.Errorf("unknown context value type '%s': %w", meta.Value, err) + } + + for _, value := range valSlice { + table.AddRow( + meta.Key, + value, + ) + } + } + + table.Render() + } + + if withDetail { + fmt.Printf("\n - Events :\n") + + for _, event := range alert.Events { + alertEventTable(color.Output, cfg.Cscli.Color, event) + } + } + + return nil +} + +type configGetter func() *csconfig.Config + +type cliAlerts struct { + client *apiclient.ApiClient + cfg configGetter +} + +func New(getconfig configGetter) *cliAlerts { + return &cliAlerts{ + cfg: getconfig, + } +} + +func (cli *cliAlerts) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "alerts [action]", + Short: "Manage alerts", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + Aliases: []string{"alert"}, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := cfg.LoadAPIClient(); err != nil { + return fmt.Errorf("loading api client: %w", err) + } + apiURL, err := url.Parse(cfg.API.Client.Credentials.URL) + if err != nil { + return fmt.Errorf("parsing api url: %w", err) + } + + cli.client, err = apiclient.NewClient(&apiclient.Config{ + MachineID: cfg.API.Client.Credentials.Login, + Password: strfmt.Password(cfg.API.Client.Credentials.Password), + URL: apiURL, + VersionPrefix: "v1", + }) + if err != nil { + return fmt.Errorf("creating api client: %w", err) + } + + return nil + }, + } + + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newInspectCmd()) + cmd.AddCommand(cli.newFlushCmd()) + cmd.AddCommand(cli.newDeleteCmd()) + + return cmd +} + +func (cli *cliAlerts) list(ctx context.Context, alertListFilter apiclient.AlertsListOpts, limit *int, contained *bool, printMachine bool) error { + var err error + + *alertListFilter.ScopeEquals, err = SanitizeScope(*alertListFilter.ScopeEquals, *alertListFilter.IPEquals, *alertListFilter.RangeEquals) + if err != nil { + return err + } + + if limit != nil { + alertListFilter.Limit = limit + } + + if *alertListFilter.Until == "" { + alertListFilter.Until = nil + } else if strings.HasSuffix(*alertListFilter.Until, "d") { + /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ + realDuration := strings.TrimSuffix(*alertListFilter.Until, "d") + + days, err := strconv.Atoi(realDuration) + if err != nil { + 
return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *alertListFilter.Until) + } + + *alertListFilter.Until = fmt.Sprintf("%d%s", days*24, "h") + } + + if *alertListFilter.Since == "" { + alertListFilter.Since = nil + } else if strings.HasSuffix(*alertListFilter.Since, "d") { + // time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier + realDuration := strings.TrimSuffix(*alertListFilter.Since, "d") + + days, err := strconv.Atoi(realDuration) + if err != nil { + return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *alertListFilter.Since) + } + + *alertListFilter.Since = fmt.Sprintf("%d%s", days*24, "h") + } + + if *alertListFilter.IncludeCAPI { + *alertListFilter.Limit = 0 + } + + if *alertListFilter.TypeEquals == "" { + alertListFilter.TypeEquals = nil + } + + if *alertListFilter.ScopeEquals == "" { + alertListFilter.ScopeEquals = nil + } + + if *alertListFilter.ValueEquals == "" { + alertListFilter.ValueEquals = nil + } + + if *alertListFilter.ScenarioEquals == "" { + alertListFilter.ScenarioEquals = nil + } + + if *alertListFilter.IPEquals == "" { + alertListFilter.IPEquals = nil + } + + if *alertListFilter.RangeEquals == "" { + alertListFilter.RangeEquals = nil + } + + if *alertListFilter.OriginEquals == "" { + alertListFilter.OriginEquals = nil + } + + if contained != nil && *contained { + alertListFilter.Contains = new(bool) + } + + alerts, _, err := cli.client.Alerts.List(ctx, alertListFilter) + if err != nil { + return fmt.Errorf("unable to list alerts: %w", err) + } + + if err = cli.alertsToTable(alerts, printMachine); err != nil { + return fmt.Errorf("unable to list alerts: %w", err) + } + + return nil +} + +func (cli *cliAlerts) newListCmd() *cobra.Command { + alertListFilter := apiclient.AlertsListOpts{ + ScopeEquals: new(string), + ValueEquals: new(string), + ScenarioEquals: new(string), + IPEquals: new(string), + RangeEquals: new(string), + Since: new(string), + Until: new(string), + TypeEquals: new(string), + IncludeCAPI: new(bool), + OriginEquals: new(string), + } + + limit := new(int) + contained := new(bool) + + var printMachine bool + + cmd := &cobra.Command{ + Use: "list [filters]", + Short: "List alerts", + Example: `cscli alerts list +cscli alerts list --ip 1.2.3.4 +cscli alerts list --range 1.2.3.0/24 +cscli alerts list --origin lists +cscli alerts list -s crowdsecurity/ssh-bf +cscli alerts list --type ban`, + Long: `List alerts with optional filters`, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.list(cmd.Context(), alertListFilter, limit, contained, printMachine) + }, + } + + flags := cmd.Flags() + flags.SortFlags = false + flags.BoolVarP(alertListFilter.IncludeCAPI, "all", "a", false, "Include decisions from Central API") + flags.StringVar(alertListFilter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)") + flags.StringVar(alertListFilter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)") + flags.StringVarP(alertListFilter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value )") + flags.StringVarP(alertListFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") + flags.StringVarP(alertListFilter.RangeEquals, "range", "r", "", "restrict to alerts from this range (shorthand for --scope range --value )") + flags.StringVar(alertListFilter.TypeEquals, "type", "", "restrict to alerts with given decision type (ie. 
ban, captcha)") + flags.StringVar(alertListFilter.ScopeEquals, "scope", "", "restrict to alerts of this scope (ie. ip,range)") + flags.StringVarP(alertListFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") + flags.StringVar(alertListFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) + flags.BoolVar(contained, "contained", false, "query decisions contained by range") + flags.BoolVarP(&printMachine, "machine", "m", false, "print machines that sent alerts") + flags.IntVarP(limit, "limit", "l", 50, "limit size of alerts list table (0 to view all alerts)") + + return cmd +} + +func (cli *cliAlerts) delete(ctx context.Context, delFilter apiclient.AlertsDeleteOpts, activeDecision *bool, deleteAll bool, delAlertByID string, contained *bool) error { + var err error + + if !deleteAll { + *delFilter.ScopeEquals, err = SanitizeScope(*delFilter.ScopeEquals, *delFilter.IPEquals, *delFilter.RangeEquals) + if err != nil { + return err + } + + if activeDecision != nil { + delFilter.ActiveDecisionEquals = activeDecision + } + + if *delFilter.ScopeEquals == "" { + delFilter.ScopeEquals = nil + } + + if *delFilter.ValueEquals == "" { + delFilter.ValueEquals = nil + } + + if *delFilter.ScenarioEquals == "" { + delFilter.ScenarioEquals = nil + } + + if *delFilter.IPEquals == "" { + delFilter.IPEquals = nil + } + + if *delFilter.RangeEquals == "" { + delFilter.RangeEquals = nil + } + + if contained != nil && *contained { + delFilter.Contains = new(bool) + } + + limit := 0 + delFilter.Limit = &limit + } else { + limit := 0 + delFilter = apiclient.AlertsDeleteOpts{Limit: &limit} + } + + var alerts *models.DeleteAlertsResponse + if delAlertByID == "" { + alerts, _, err = cli.client.Alerts.Delete(ctx, delFilter) + if err != nil { + return fmt.Errorf("unable to delete alerts: %w", err) + } + } else { + alerts, _, err = cli.client.Alerts.DeleteOne(ctx, delAlertByID) + if err != nil { + return fmt.Errorf("unable to delete alert: %w", err) + } + } + + log.Infof("%s alert(s) deleted", alerts.NbDeleted) + + return nil +} + +func (cli *cliAlerts) newDeleteCmd() *cobra.Command { + var ( + activeDecision *bool + deleteAll bool + delAlertByID string + ) + + delFilter := apiclient.AlertsDeleteOpts{ + ScopeEquals: new(string), + ValueEquals: new(string), + ScenarioEquals: new(string), + IPEquals: new(string), + RangeEquals: new(string), + } + + contained := new(bool) + + cmd := &cobra.Command{ + Use: "delete [filters] [--all]", + Short: `Delete alerts +/!\ This command can be use only on the same machine than the local API.`, + Example: `cscli alerts delete --ip 1.2.3.4 +cscli alerts delete --range 1.2.3.0/24 +cscli alerts delete -s crowdsecurity/ssh-bf"`, + DisableAutoGenTag: true, + Aliases: []string{"remove"}, + Args: cobra.ExactArgs(0), + PreRunE: func(cmd *cobra.Command, _ []string) error { + if deleteAll { + return nil + } + if *delFilter.ScopeEquals == "" && *delFilter.ValueEquals == "" && + *delFilter.ScenarioEquals == "" && *delFilter.IPEquals == "" && + *delFilter.RangeEquals == "" && delAlertByID == "" { + _ = cmd.Usage() + return errors.New("at least one filter or --all must be specified") + } + + return nil + }, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.delete(cmd.Context(), delFilter, activeDecision, deleteAll, delAlertByID, contained) + }, + } + + flags := cmd.Flags() + flags.SortFlags = false + flags.StringVar(delFilter.ScopeEquals, "scope", "", "the scope (ie. 
ip,range)") + flags.StringVarP(delFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") + flags.StringVarP(delFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") + flags.StringVarP(delFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") + flags.StringVarP(delFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value )") + flags.StringVar(&delAlertByID, "id", "", "alert ID") + flags.BoolVarP(&deleteAll, "all", "a", false, "delete all alerts") + flags.BoolVar(contained, "contained", false, "query decisions contained by range") + + return cmd +} + +func (cli *cliAlerts) inspect(ctx context.Context, details bool, alertIDs ...string) error { + cfg := cli.cfg() + + for _, alertID := range alertIDs { + id, err := strconv.Atoi(alertID) + if err != nil { + return fmt.Errorf("bad alert id %s", alertID) + } + + alert, _, err := cli.client.Alerts.GetByID(ctx, id) + if err != nil { + return fmt.Errorf("can't find alert with id %s: %w", alertID, err) + } + + switch cfg.Cscli.Output { + case "human": + if err := cli.displayOneAlert(alert, details); err != nil { + log.Warnf("unable to display alert with id %s: %s", alertID, err) + continue + } + case "json": + data, err := json.MarshalIndent(alert, "", " ") + if err != nil { + return fmt.Errorf("unable to serialize alert with id %s: %w", alertID, err) + } + + fmt.Printf("%s\n", string(data)) + case "raw": + data, err := yaml.Marshal(alert) + if err != nil { + return fmt.Errorf("unable to serialize alert with id %s: %w", alertID, err) + } + + fmt.Println(string(data)) + } + } + + return nil +} + +func (cli *cliAlerts) newInspectCmd() *cobra.Command { + var details bool + + cmd := &cobra.Command{ + Use: `inspect "alert_id"`, + Short: `Show info about an alert`, + Example: `cscli alerts inspect 123`, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + _ = cmd.Help() + return errors.New("missing alert_id") + } + return cli.inspect(cmd.Context(), details, args...) + }, + } + + cmd.Flags().SortFlags = false + cmd.Flags().BoolVarP(&details, "details", "d", false, "show alerts with events") + + return cmd +} + +func (cli *cliAlerts) newFlushCmd() *cobra.Command { + var ( + maxItems int + maxAge string + ) + + cmd := &cobra.Command{ + Use: `flush`, + Short: `Flush alerts +/!\ This command can be used only on the same machine than the local API`, + Example: `cscli alerts flush --max-items 1000 --max-age 7d`, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + cfg := cli.cfg() + ctx := cmd.Context() + + if err := require.LAPI(cfg); err != nil { + return err + } + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + return err + } + log.Info("Flushing alerts. !! 
This may take a long time !!") + err = db.FlushAlerts(ctx, maxAge, maxItems) + if err != nil { + return fmt.Errorf("unable to flush alerts: %w", err) + } + log.Info("Alerts flushed") + + return nil + }, + } + + cmd.Flags().SortFlags = false + cmd.Flags().IntVar(&maxItems, "max-items", 5000, "Maximum number of alert items to keep in the database") + cmd.Flags().StringVar(&maxAge, "max-age", "7d", "Maximum age of alert items to keep in the database") + + return cmd +} diff --git a/cmd/crowdsec-cli/clialert/sanitize.go b/cmd/crowdsec-cli/clialert/sanitize.go new file mode 100644 index 00000000000..87b110649da --- /dev/null +++ b/cmd/crowdsec-cli/clialert/sanitize.go @@ -0,0 +1,26 @@ +package clialert + +import ( + "fmt" + "net" + + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +// SanitizeScope validates ip and range and sets the scope according to our case convention. +func SanitizeScope(scope, ip, ipRange string) (string, error) { + if ipRange != "" { + _, _, err := net.ParseCIDR(ipRange) + if err != nil { + return "", fmt.Errorf("%s is not a valid range", ipRange) + } + } + + if ip != "" { + if net.ParseIP(ip) == nil { + return "", fmt.Errorf("%s is not a valid ip", ip) + } + } + + return types.NormalizeScope(scope), nil +} diff --git a/cmd/crowdsec-cli/alerts_table.go b/cmd/crowdsec-cli/clialert/table.go similarity index 80% rename from cmd/crowdsec-cli/alerts_table.go rename to cmd/crowdsec-cli/clialert/table.go index ec457f3723e..1416e1e435c 100644 --- a/cmd/crowdsec-cli/alerts_table.go +++ b/cmd/crowdsec-cli/clialert/table.go @@ -1,4 +1,4 @@ -package main +package clialert import ( "fmt" @@ -9,16 +9,19 @@ import ( log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" "github.com/crowdsecurity/crowdsec/pkg/models" ) -func alertsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine bool) { - t := newTable(out) +func alertsTable(out io.Writer, wantColor string, alerts *models.GetAlertsResponse, printMachine bool) { + t := cstable.New(out, wantColor) t.SetRowLines(false) + header := []string{"ID", "value", "reason", "country", "as", "decisions", "created_at"} if printMachine { header = append(header, "machine") } + t.SetHeaders(header...)
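+ // each alert becomes one table row; when printMachine is set, the machine ID is appended last, matching the extra header column above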
for _, alertItem := range *alerts { @@ -35,7 +38,7 @@ func alertsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine b *alertItem.Scenario, alertItem.Source.Cn, alertItem.Source.GetAsNumberName(), - DecisionsFromAlert(alertItem), + decisionsFromAlert(alertItem), *alertItem.StartAt, } @@ -49,25 +52,30 @@ func alertsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine b t.Render() } -func alertDecisionsTable(out io.Writer, alert *models.Alert) { +func alertDecisionsTable(out io.Writer, wantColor string, alert *models.Alert) { foundActive := false - t := newTable(out) + t := cstable.New(out, wantColor) t.SetRowLines(false) t.SetHeaders("ID", "scope:value", "action", "expiration", "created_at") + for _, decision := range alert.Decisions { parsedDuration, err := time.ParseDuration(*decision.Duration) if err != nil { log.Error(err) } + expire := time.Now().UTC().Add(parsedDuration) if time.Now().UTC().After(expire) { continue } + foundActive = true scopeAndValue := *decision.Scope + if *decision.Value != "" { scopeAndValue += ":" + *decision.Value } + t.AddRow( strconv.Itoa(int(decision.ID)), scopeAndValue, @@ -76,16 +84,17 @@ func alertDecisionsTable(out io.Writer, alert *models.Alert) { alert.CreatedAt, ) } + if foundActive { fmt.Printf(" - Active Decisions :\n") t.Render() // Send output } } -func alertEventTable(out io.Writer, event *models.Event) { +func alertEventTable(out io.Writer, wantColor string, event *models.Event) { fmt.Fprintf(out, "\n- Date: %s\n", *event.Timestamp) - t := newTable(out) + t := cstable.New(out, wantColor) t.SetHeaders("Key", "Value") sort.Slice(event.Meta, func(i, j int) bool { return event.Meta[i].Key < event.Meta[j].Key diff --git a/cmd/crowdsec-cli/clibouncer/bouncers.go b/cmd/crowdsec-cli/clibouncer/bouncers.go new file mode 100644 index 00000000000..226fbb7e922 --- /dev/null +++ b/cmd/crowdsec-cli/clibouncer/bouncers.go @@ -0,0 +1,497 @@ +package clibouncer + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "slices" + "strings" + "time" + + "github.com/fatih/color" + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/ask" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clientinfo" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" + "github.com/crowdsecurity/crowdsec/pkg/emoji" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type configGetter = func() *csconfig.Config + +type cliBouncers struct { + db *database.Client + cfg configGetter +} + +func New(cfg configGetter) *cliBouncers { + return &cliBouncers{ + cfg: cfg, + } +} + +func (cli *cliBouncers) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "bouncers [action]", + Short: "Manage bouncers [requires local API]", + Long: `To list/add/delete/prune bouncers. +Note: This command requires direct database access, so it is intended to be run on Local API/master.
+`, + Args: cobra.MinimumNArgs(1), + Aliases: []string{"bouncer"}, + DisableAutoGenTag: true, + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + var err error + + cfg := cli.cfg() + + if err = require.LAPI(cfg); err != nil { + return err + } + + cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + if err != nil { + return err + } + + return nil + }, + } + + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newAddCmd()) + cmd.AddCommand(cli.newDeleteCmd()) + cmd.AddCommand(cli.newPruneCmd()) + cmd.AddCommand(cli.newInspectCmd()) + + return cmd +} + +func (cli *cliBouncers) listHuman(out io.Writer, bouncers ent.Bouncers) { + t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Name", "IP Address", "Valid", "Last API pull", "Type", "Version", "Auth Type"}) + + for _, b := range bouncers { + revoked := emoji.CheckMark + if b.Revoked { + revoked = emoji.Prohibited + } + + lastPull := "" + if b.LastPull != nil { + lastPull = b.LastPull.Format(time.RFC3339) + } + + t.AppendRow(table.Row{b.Name, b.IPAddress, revoked, lastPull, b.Type, b.Version, b.AuthType}) + } + + io.WriteString(out, t.Render()+"\n") +} + +// bouncerInfo contains only the data we want for inspect/list +type bouncerInfo struct { + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Name string `json:"name"` + Revoked bool `json:"revoked"` + IPAddress string `json:"ip_address"` + Type string `json:"type"` + Version string `json:"version"` + LastPull *time.Time `json:"last_pull"` + AuthType string `json:"auth_type"` + OS string `json:"os,omitempty"` + Featureflags []string `json:"featureflags,omitempty"` +} + +func newBouncerInfo(b *ent.Bouncer) bouncerInfo { + return bouncerInfo{ + CreatedAt: b.CreatedAt, + UpdatedAt: b.UpdatedAt, + Name: b.Name, + Revoked: b.Revoked, + IPAddress: b.IPAddress, + Type: b.Type, + Version: b.Version, + LastPull: b.LastPull, + AuthType: b.AuthType, + OS: clientinfo.GetOSNameAndVersion(b), + Featureflags: clientinfo.GetFeatureFlagList(b), + } +} + +func (cli *cliBouncers) listCSV(out io.Writer, bouncers ent.Bouncers) error { + csvwriter := csv.NewWriter(out) + + if err := csvwriter.Write([]string{"name", "ip", "revoked", "last_pull", "type", "version", "auth_type"}); err != nil { + return fmt.Errorf("failed to write raw header: %w", err) + } + + for _, b := range bouncers { + valid := "validated" + if b.Revoked { + valid = "pending" + } + + lastPull := "" + if b.LastPull != nil { + lastPull = b.LastPull.Format(time.RFC3339) + } + + if err := csvwriter.Write([]string{b.Name, b.IPAddress, valid, lastPull, b.Type, b.Version, b.AuthType}); err != nil { + return fmt.Errorf("failed to write raw: %w", err) + } + } + + csvwriter.Flush() + + return nil +} + +func (cli *cliBouncers) List(ctx context.Context, out io.Writer, db *database.Client) error { + // XXX: must use the provided db object, the one in the struct might be nil + // (calling List directly skips the PersistentPreRunE) + + bouncers, err := db.ListBouncers(ctx) + if err != nil { + return fmt.Errorf("unable to list bouncers: %w", err) + } + + switch cli.cfg().Cscli.Output { + case "human": + cli.listHuman(out, bouncers) + case "json": + info := make([]bouncerInfo, 0, len(bouncers)) + for _, b := range bouncers { + info = append(info, newBouncerInfo(b)) + } + + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(info); err != nil { + return errors.New("failed to serialize") + } + + return nil + case "raw": + return cli.listCSV(out, 
bouncers) + } + + return nil +} + +func (cli *cliBouncers) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "list all bouncers within the database", + Example: `cscli bouncers list`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.List(cmd.Context(), color.Output, cli.db) + }, + } + + return cmd +} + +func (cli *cliBouncers) add(ctx context.Context, bouncerName string, key string) error { + var err error + + keyLength := 32 + + if key == "" { + key, err = middlewares.GenerateAPIKey(keyLength) + if err != nil { + return fmt.Errorf("unable to generate api key: %w", err) + } + } + + _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) + if err != nil { + return fmt.Errorf("unable to create bouncer: %w", err) + } + + switch cli.cfg().Cscli.Output { + case "human": + fmt.Printf("API key for '%s':\n\n", bouncerName) + fmt.Printf(" %s\n\n", key) + fmt.Print("Please keep this key since you will not be able to retrieve it!\n") + case "raw": + fmt.Print(key) + case "json": + j, err := json.Marshal(key) + if err != nil { + return errors.New("unable to serialize api key") + } + + fmt.Print(string(j)) + } + + return nil +} + +func (cli *cliBouncers) newAddCmd() *cobra.Command { + var key string + + cmd := &cobra.Command{ + Use: "add MyBouncerName", + Short: "add a single bouncer to the database", + Example: `cscli bouncers add MyBouncerName +cscli bouncers add MyBouncerName --key <random-key>`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args[0], key) + }, + } + + flags := cmd.Flags() + flags.StringP("length", "l", "", "length of the api key") + _ = flags.MarkDeprecated("length", "use --key instead") + flags.StringVarP(&key, "key", "k", "", "api key for the bouncer") + + return cmd +} + +// validBouncerID returns a list of bouncer IDs for command completion +func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + var err error + + cfg := cli.cfg() + ctx := cmd.Context() + + // need to load config and db because PersistentPreRunE is not called for completions + + if err = require.LAPI(cfg); err != nil { + cobra.CompError("unable to list bouncers " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + cli.db, err = require.DBClient(ctx, cfg.DbConfig) + if err != nil { + cobra.CompError("unable to list bouncers " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + bouncers, err := cli.db.ListBouncers(ctx) + if err != nil { + cobra.CompError("unable to list bouncers " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + ret := []string{} + + for _, bouncer := range bouncers { + if strings.Contains(bouncer.Name, toComplete) && !slices.Contains(args, bouncer.Name) { + ret = append(ret, bouncer.Name) + } + } + + return ret, cobra.ShellCompDirectiveNoFileComp +} + +func (cli *cliBouncers) delete(ctx context.Context, bouncers []string, ignoreMissing bool) error { + for _, bouncerID := range bouncers { + if err := cli.db.DeleteBouncer(ctx, bouncerID); err != nil { + var notFoundErr *database.BouncerNotFoundError + if ignoreMissing && errors.As(err, &notFoundErr) { + return nil + } + + return fmt.Errorf("unable to delete bouncer: %w", err) + } + + log.Infof("bouncer '%s' deleted successfully", bouncerID) + } + + return nil +} + +func (cli
*cliBouncers) newDeleteCmd() *cobra.Command { + var ignoreMissing bool + + cmd := &cobra.Command{ + Use: "delete MyBouncerName", + Short: "delete bouncer(s) from the database", + Example: `cscli bouncers delete "bouncer1" "bouncer2"`, + Args: cobra.MinimumNArgs(1), + Aliases: []string{"remove"}, + DisableAutoGenTag: true, + ValidArgsFunction: cli.validBouncerID, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args, ignoreMissing) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&ignoreMissing, "ignore-missing", false, "don't print errors if one or more bouncers don't exist") + + return cmd +} + +func (cli *cliBouncers) prune(ctx context.Context, duration time.Duration, force bool) error { + if duration < 2*time.Minute { + if yes, err := ask.YesNo( + "The duration you provided is less than 2 minutes. "+ + "This may remove active bouncers. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + bouncers, err := cli.db.QueryBouncersInactiveSince(ctx, time.Now().UTC().Add(-duration)) + if err != nil { + return fmt.Errorf("unable to query bouncers: %w", err) + } + + if len(bouncers) == 0 { + fmt.Println("No bouncers to prune.") + return nil + } + + cli.listHuman(color.Output, bouncers) + + if !force { + if yes, err := ask.YesNo( + "You are about to PERMANENTLY remove the above bouncers from the database. "+ + "These will NOT be recoverable. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + deleted, err := cli.db.BulkDeleteBouncers(ctx, bouncers) + if err != nil { + return fmt.Errorf("unable to prune bouncers: %w", err) + } + + fmt.Fprintf(os.Stderr, "Successfully deleted %d bouncers\n", deleted) + + return nil +} + +func (cli *cliBouncers) newPruneCmd() *cobra.Command { + var ( + duration time.Duration + force bool + ) + + const defaultDuration = 60 * time.Minute + + cmd := &cobra.Command{ + Use: "prune", + Short: "prune multiple bouncers from the database", + Args: cobra.NoArgs, + DisableAutoGenTag: true, + Example: `cscli bouncers prune -d 45m +cscli bouncers prune -d 45m --force`, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, force) + }, + } + + flags := cmd.Flags() + flags.DurationVarP(&duration, "duration", "d", defaultDuration, "duration of time since last pull") + flags.BoolVar(&force, "force", false, "force prune without asking for confirmation") + + return cmd +} + +func (cli *cliBouncers) inspectHuman(out io.Writer, bouncer *ent.Bouncer) { + t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer + + t.SetTitle("Bouncer: " + bouncer.Name) + + t.SetColumnConfigs([]table.ColumnConfig{ + {Number: 1, AutoMerge: true}, + }) + + lastPull := "" + if bouncer.LastPull != nil { + lastPull = bouncer.LastPull.String() + } + + t.AppendRows([]table.Row{ + {"Created At", bouncer.CreatedAt}, + {"Last Update", bouncer.UpdatedAt}, + {"Revoked?", bouncer.Revoked}, + {"IP Address", bouncer.IPAddress}, + {"Type", bouncer.Type}, + {"Version", bouncer.Version}, + {"Last Pull", lastPull}, + {"Auth type", bouncer.AuthType}, + {"OS", clientinfo.GetOSNameAndVersion(bouncer)}, + }) + + for _, ff := range clientinfo.GetFeatureFlagList(bouncer) { + t.AppendRow(table.Row{"Feature Flags", ff}) + } + + io.WriteString(out, t.Render()+"\n") +} + +func (cli *cliBouncers) inspect(bouncer *ent.Bouncer) error { + out := color.Output + 
outputFormat := cli.cfg().Cscli.Output + + switch outputFormat { + case "human": + cli.inspectHuman(out, bouncer) + case "json": + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(newBouncerInfo(bouncer)); err != nil { + return errors.New("failed to serialize") + } + + return nil + default: + return fmt.Errorf("output format '%s' not supported for this command", outputFormat) + } + + return nil +} + +func (cli *cliBouncers) newInspectCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "inspect [bouncer_name]", + Short: "inspect a bouncer by name", + Example: `cscli bouncers inspect "bouncer1"`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + ValidArgsFunction: cli.validBouncerID, + RunE: func(cmd *cobra.Command, args []string) error { + bouncerName := args[0] + + b, err := cli.db.Ent.Bouncer.Query(). + Where(bouncer.Name(bouncerName)). + Only(cmd.Context()) + if err != nil { + return fmt.Errorf("unable to read bouncer data '%s': %w", bouncerName, err) + } + + return cli.inspect(b) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clicapi/capi.go b/cmd/crowdsec-cli/clicapi/capi.go new file mode 100644 index 00000000000..cba66f11104 --- /dev/null +++ b/cmd/crowdsec-cli/clicapi/capi.go @@ -0,0 +1,248 @@ +package clicapi + +import ( + "context" + "errors" + "fmt" + "io" + "net/url" + "os" + + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type configGetter = func() *csconfig.Config + +type cliCapi struct { + cfg configGetter +} + +func New(cfg configGetter) *cliCapi { + return &cliCapi{ + cfg: cfg, + } +} + +func (cli *cliCapi) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "capi [action]", + Short: "Manage interaction with Central API (CAPI)", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { + return err + } + + return require.CAPI(cfg) + }, + } + + cmd.AddCommand(cli.newRegisterCmd()) + cmd.AddCommand(cli.newStatusCmd()) + + return cmd +} + +func (cli *cliCapi) register(ctx context.Context, capiUserPrefix string, outputFile string) error { + cfg := cli.cfg() + + capiUser, err := idgen.GenerateMachineID(capiUserPrefix) + if err != nil { + return fmt.Errorf("unable to generate machine id: %w", err) + } + + password := strfmt.Password(idgen.GeneratePassword(idgen.PasswordLength)) + + apiurl, err := url.Parse(types.CAPIBaseURL) + if err != nil { + return fmt.Errorf("unable to parse api url %s: %w", types.CAPIBaseURL, err) + } + + _, err = apiclient.RegisterClient(ctx, &apiclient.Config{ + MachineID: capiUser, + Password: password, + URL: apiurl, + VersionPrefix: "v3", + }, nil) + if err != nil { + return fmt.Errorf("api client register ('%s'): %w", types.CAPIBaseURL, err) + } + + log.Infof("Successfully registered to Central API (CAPI)") + + var dumpFile string + + switch { + case outputFile != "": + dumpFile = outputFile + case 
cfg.API.Server.OnlineClient.CredentialsFilePath != "": + dumpFile = cfg.API.Server.OnlineClient.CredentialsFilePath + default: + dumpFile = "" + } + + apiCfg := csconfig.ApiCredentialsCfg{ + Login: capiUser, + Password: password.String(), + URL: types.CAPIBaseURL, + } + + apiConfigDump, err := yaml.Marshal(apiCfg) + if err != nil { + return fmt.Errorf("unable to serialize api credentials: %w", err) + } + + if dumpFile != "" { + err = os.WriteFile(dumpFile, apiConfigDump, 0o600) + if err != nil { + return fmt.Errorf("write api credentials in '%s' failed: %w", dumpFile, err) + } + + log.Infof("Central API credentials written to '%s'", dumpFile) + } else { + fmt.Println(string(apiConfigDump)) + } + + log.Warning(reload.Message) + + return nil +} + +func (cli *cliCapi) newRegisterCmd() *cobra.Command { + var ( + capiUserPrefix string + outputFile string + ) + + cmd := &cobra.Command{ + Use: "register", + Short: "Register to Central API (CAPI)", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.register(cmd.Context(), capiUserPrefix, outputFile) + }, + } + + cmd.Flags().StringVarP(&outputFile, "file", "f", "", "output file destination") + cmd.Flags().StringVar(&capiUserPrefix, "schmilblick", "", "set a schmilblick (use in tests only)") + + _ = cmd.Flags().MarkHidden("schmilblick") + + return cmd +} + +// queryCAPIStatus checks if the Central API is reachable, and if the credentials are correct. It then checks if the instance is enrolled in the console. +func queryCAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login string, password string) (bool, bool, error) { + apiURL, err := url.Parse(credURL) + if err != nil { + return false, false, err + } + + itemsForAPI := hub.GetInstalledListForAPI() + + if len(itemsForAPI) == 0 { + return false, false, errors.New("no scenarios or appsec-rules installed, abort") + } + + passwd := strfmt.Password(password) + + client, err := apiclient.NewClient(&apiclient.Config{ + MachineID: login, + Password: passwd, + Scenarios: itemsForAPI, + URL: apiURL, + // I don't believe papi is needed to check enrollment + // PapiURL: papiURL, + VersionPrefix: "v3", + UpdateScenario: func(_ context.Context) ([]string, error) { + return itemsForAPI, nil + }, + }) + if err != nil { + return false, false, err + } + + pw := strfmt.Password(password) + + t := models.WatcherAuthRequest{ + MachineID: &login, + Password: &pw, + Scenarios: itemsForAPI, + } + + authResp, _, err := client.Auth.AuthenticateWatcher(ctx, t) + if err != nil { + return false, false, err + } + + client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token + + if client.IsEnrolled() { + return true, true, nil + } + + return true, false, nil +} + +func (cli *cliCapi) Status(ctx context.Context, out io.Writer, hub *cwhub.Hub) error { + cfg := cli.cfg() + + if err := require.CAPIRegistered(cfg); err != nil { + return err + } + + cred := cfg.API.Server.OnlineClient.Credentials + + fmt.Fprintf(out, "Loaded credentials from %s\n", cfg.API.Server.OnlineClient.CredentialsFilePath) + fmt.Fprintf(out, "Trying to authenticate with username %s on %s\n", cred.Login, cred.URL) + + auth, enrolled, err := queryCAPIStatus(ctx, hub, cred.URL, cred.Login, cred.Password) + if err != nil { + return fmt.Errorf("failed to authenticate to Central API (CAPI): %w", err) + } + + if auth { + fmt.Fprint(out, "You can successfully interact with Central API (CAPI)\n") + } + + if enrolled { + fmt.Fprint(out, "Your instance is enrolled in
the console\n") + } + + return nil +} + +func (cli *cliCapi) newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "Check status with the Central API (CAPI)", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) + if err != nil { + return err + } + + return cli.Status(cmd.Context(), color.Output, hub) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/cliconsole/console.go b/cmd/crowdsec-cli/cliconsole/console.go new file mode 100644 index 00000000000..448ddcee7fa --- /dev/null +++ b/cmd/crowdsec-cli/cliconsole/console.go @@ -0,0 +1,417 @@ +package cliconsole + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "strconv" + "strings" + + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type configGetter func() *csconfig.Config + +type cliConsole struct { + cfg configGetter +} + +func New(cfg configGetter) *cliConsole { + return &cliConsole{ + cfg: cfg, + } +} + +func (cli *cliConsole) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "console [action]", + Short: "Manage interaction with Crowdsec console (https://app.crowdsec.net)", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { + return err + } + if err := require.CAPI(cfg); err != nil { + return err + } + + return require.CAPIRegistered(cfg) + }, + } + + cmd.AddCommand(cli.newEnrollCmd()) + cmd.AddCommand(cli.newEnableCmd()) + cmd.AddCommand(cli.newDisableCmd()) + cmd.AddCommand(cli.newStatusCmd()) + + return cmd +} + +func (cli *cliConsole) enroll(ctx context.Context, key string, name string, overwrite bool, tags []string, opts []string) error { + cfg := cli.cfg() + password := strfmt.Password(cfg.API.Server.OnlineClient.Credentials.Password) + + apiURL, err := url.Parse(cfg.API.Server.OnlineClient.Credentials.URL) + if err != nil { + return fmt.Errorf("could not parse CAPI URL: %w", err) + } + + enableOpts := []string{csconfig.SEND_MANUAL_SCENARIOS, csconfig.SEND_TAINTED_SCENARIOS} + + if len(opts) != 0 { + for _, opt := range opts { + valid := false + + if opt == "all" { + enableOpts = csconfig.CONSOLE_CONFIGS + break + } + + for _, availableOpt := range csconfig.CONSOLE_CONFIGS { + if opt != availableOpt { + continue + } + + valid = true + enable := true + + for _, enabledOpt := range enableOpts { + if opt == enabledOpt { + enable = false + continue + } + } + + if enable { + enableOpts = append(enableOpts, opt) + } + + break + } + + if !valid { + return fmt.Errorf("option %s doesn't exist", opt) + } + } + } + + hub, err := require.Hub(cfg, nil, nil) + if err != nil { + return err + } + + c, _ := apiclient.NewClient(&apiclient.Config{ + MachineID: cli.cfg().API.Server.OnlineClient.Credentials.Login, + Password: password, + Scenarios: hub.GetInstalledListForAPI(), + URL: apiURL, + VersionPrefix: "v3", + }) + + resp, err := c.Auth.EnrollWatcher(ctx, key, name, tags, overwrite) + if 
err != nil { + return fmt.Errorf("could not enroll instance: %w", err) + } + + if resp.Response.StatusCode == http.StatusOK && !overwrite { + log.Warning("Instance already enrolled. You can use '--overwrite' to force enroll") + return nil + } + + if err := cli.setConsoleOpts(enableOpts, true); err != nil { + return err + } + + for _, opt := range enableOpts { + log.Infof("Enabled %s : %s", opt, csconfig.CONSOLE_CONFIGS_HELP[opt]) + } + + log.Info("Watcher successfully enrolled. Visit https://app.crowdsec.net to accept it.") + log.Info("Please restart crowdsec after accepting the enrollment.") + + return nil +} + +func (cli *cliConsole) newEnrollCmd() *cobra.Command { + name := "" + overwrite := false + tags := []string{} + opts := []string{} + + cmd := &cobra.Command{ + Use: "enroll [enroll-key]", + Short: "Enroll this instance to https://app.crowdsec.net [requires local API]", + Long: ` +Enroll this instance to https://app.crowdsec.net + +You can get your enrollment key by creating an account on https://app.crowdsec.net. +After running this command you will need to validate the enrollment in the webapp.`, + Example: fmt.Sprintf(`cscli console enroll YOUR-ENROLL-KEY + cscli console enroll --name [instance_name] YOUR-ENROLL-KEY + cscli console enroll --name [instance_name] --tags [tag_1] --tags [tag_2] YOUR-ENROLL-KEY + cscli console enroll --enable context,manual YOUR-ENROLL-KEY + + valid options are: %s,all (see 'cscli console status' for details)`, strings.Join(csconfig.CONSOLE_CONFIGS, ",")), + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.enroll(cmd.Context(), args[0], name, overwrite, tags, opts) + }, + } + + flags := cmd.Flags() + flags.StringVarP(&name, "name", "n", "", "Name to display in the console") + flags.BoolVarP(&overwrite, "overwrite", "", false, "Force enroll the instance") + flags.StringSliceVarP(&tags, "tags", "t", tags, "Tags to display in the console") + flags.StringSliceVarP(&opts, "enable", "e", opts, "Enable console options") + + return cmd +} + +func (cli *cliConsole) newEnableCmd() *cobra.Command { + var enableAll bool + + cmd := &cobra.Command{ + Use: "enable [option]", + Short: "Enable a console option", + Example: "sudo cscli console enable tainted", + Long: ` +Enable pushing the given information to the central API.
This empowers the console.`, + ValidArgs: csconfig.CONSOLE_CONFIGS, + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + if enableAll { + if err := cli.setConsoleOpts(csconfig.CONSOLE_CONFIGS, true); err != nil { + return err + } + log.Infof("All features have been enabled successfully") + } else { + if len(args) == 0 { + return errors.New("you must specify at least one feature to enable") + } + if err := cli.setConsoleOpts(args, true); err != nil { + return err + } + log.Infof("%v have been enabled", args) + } + + log.Info(reload.Message) + + return nil + }, + } + cmd.Flags().BoolVarP(&enableAll, "all", "a", false, "Enable all console options") + + return cmd +} + +func (cli *cliConsole) newDisableCmd() *cobra.Command { + var disableAll bool + + cmd := &cobra.Command{ + Use: "disable [option]", + Short: "Disable a console option", + Example: "sudo cscli console disable tainted", + Long: ` +Disable pushing the given information to the central API.`, + ValidArgs: csconfig.CONSOLE_CONFIGS, + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + if disableAll { + if err := cli.setConsoleOpts(csconfig.CONSOLE_CONFIGS, false); err != nil { + return err + } + log.Infof("All features have been disabled") + } else { + if err := cli.setConsoleOpts(args, false); err != nil { + return err + } + log.Infof("%v have been disabled", args) + } + + log.Info(reload.Message) + + return nil + }, + } + cmd.Flags().BoolVarP(&disableAll, "all", "a", false, "Disable all console options") + + return cmd +} + +func (cli *cliConsole) newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "Shows status of the console options", + Example: `sudo cscli console status`, + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + consoleCfg := cfg.API.Server.ConsoleConfig + switch cfg.Cscli.Output { + case "human": + cmdConsoleStatusTable(color.Output, cfg.Cscli.Color, *consoleCfg) + case "json": + out := map[string](*bool){ + csconfig.SEND_MANUAL_SCENARIOS: consoleCfg.ShareManualDecisions, + csconfig.SEND_CUSTOM_SCENARIOS: consoleCfg.ShareCustomScenarios, + csconfig.SEND_TAINTED_SCENARIOS: consoleCfg.ShareTaintedScenarios, + csconfig.SEND_CONTEXT: consoleCfg.ShareContext, + csconfig.CONSOLE_MANAGEMENT: consoleCfg.ConsoleManagement, + } + data, err := json.MarshalIndent(out, "", " ") + if err != nil { + return fmt.Errorf("failed to serialize configuration: %w", err) + } + fmt.Println(string(data)) + case "raw": + csvwriter := csv.NewWriter(os.Stdout) + err := csvwriter.Write([]string{"option", "enabled"}) + if err != nil { + return err + } + + rows := [][]string{ + {csconfig.SEND_MANUAL_SCENARIOS, strconv.FormatBool(*consoleCfg.ShareManualDecisions)}, + {csconfig.SEND_CUSTOM_SCENARIOS, strconv.FormatBool(*consoleCfg.ShareCustomScenarios)}, + {csconfig.SEND_TAINTED_SCENARIOS, strconv.FormatBool(*consoleCfg.ShareTaintedScenarios)}, + {csconfig.SEND_CONTEXT, strconv.FormatBool(*consoleCfg.ShareContext)}, + {csconfig.CONSOLE_MANAGEMENT, strconv.FormatBool(*consoleCfg.ConsoleManagement)}, + } + for _, row := range rows { + err = csvwriter.Write(row) + if err != nil { + return err + } + } + csvwriter.Flush() + } + + return nil + }, + } + + return cmd +} + +func (cli *cliConsole) dumpConfig() error { + serverCfg := cli.cfg().API.Server + + out, err := yaml.Marshal(serverCfg.ConsoleConfig) + if err != nil { + return fmt.Errorf("while serializing ConsoleConfig (for %s): %w", serverCfg.ConsoleConfigPath, err) + }
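+ // persist the console configuration, falling back to the default path when console_path is not set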
+ + if serverCfg.ConsoleConfigPath == "" { + serverCfg.ConsoleConfigPath = csconfig.DefaultConsoleConfigFilePath + log.Debugf("Empty console_path, defaulting to %s", serverCfg.ConsoleConfigPath) + } + + if err := os.WriteFile(serverCfg.ConsoleConfigPath, out, 0o600); err != nil { + return fmt.Errorf("while dumping console config to %s: %w", serverCfg.ConsoleConfigPath, err) + } + + return nil +} + +func (cli *cliConsole) setConsoleOpts(args []string, wanted bool) error { + cfg := cli.cfg() + consoleCfg := cfg.API.Server.ConsoleConfig + + for _, arg := range args { + switch arg { + case csconfig.CONSOLE_MANAGEMENT: + /*for each flag check if it's already set before setting it*/ + if consoleCfg.ConsoleManagement != nil && *consoleCfg.ConsoleManagement == wanted { + log.Debugf("%s already set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) + } else { + log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) + consoleCfg.ConsoleManagement = ptr.Of(wanted) + } + + if cfg.API.Server.OnlineClient.Credentials != nil { + changed := false + if wanted && cfg.API.Server.OnlineClient.Credentials.PapiURL == "" { + changed = true + cfg.API.Server.OnlineClient.Credentials.PapiURL = types.PAPIBaseURL + } else if !wanted && cfg.API.Server.OnlineClient.Credentials.PapiURL != "" { + changed = true + cfg.API.Server.OnlineClient.Credentials.PapiURL = "" + } + + if changed { + fileContent, err := yaml.Marshal(cfg.API.Server.OnlineClient.Credentials) + if err != nil { + return fmt.Errorf("cannot serialize credentials: %w", err) + } + + log.Infof("Updating credentials file: %s", cfg.API.Server.OnlineClient.CredentialsFilePath) + + err = os.WriteFile(cfg.API.Server.OnlineClient.CredentialsFilePath, fileContent, 0o600) + if err != nil { + return fmt.Errorf("cannot write credentials file: %w", err) + } + } + } + case csconfig.SEND_CUSTOM_SCENARIOS: + /*for each flag check if it's already set before setting it*/ + if consoleCfg.ShareCustomScenarios != nil && *consoleCfg.ShareCustomScenarios == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) + } else { + log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) + consoleCfg.ShareCustomScenarios = ptr.Of(wanted) + } + case csconfig.SEND_TAINTED_SCENARIOS: + /*for each flag check if it's already set before setting it*/ + if consoleCfg.ShareTaintedScenarios != nil && *consoleCfg.ShareTaintedScenarios == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) + } else { + log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) + consoleCfg.ShareTaintedScenarios = ptr.Of(wanted) + } + case csconfig.SEND_MANUAL_SCENARIOS: + /*for each flag check if it's already set before setting it*/ + if consoleCfg.ShareManualDecisions != nil && *consoleCfg.ShareManualDecisions == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) + } else { + log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) + consoleCfg.ShareManualDecisions = ptr.Of(wanted) + } + case csconfig.SEND_CONTEXT: + /*for each flag check if it's already set before setting it*/ + if consoleCfg.ShareContext != nil && *consoleCfg.ShareContext == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_CONTEXT, wanted) + } else { + log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted) + consoleCfg.ShareContext = ptr.Of(wanted) + } + default: + return fmt.Errorf("unknown flag %s", arg) + } + } + + if err := cli.dumpConfig(); err != nil { + return fmt.Errorf("failed writing console 
config: %w", err) + } + + return nil +} diff --git a/cmd/crowdsec-cli/cliconsole/console_table.go b/cmd/crowdsec-cli/cliconsole/console_table.go new file mode 100644 index 00000000000..8f17b97860a --- /dev/null +++ b/cmd/crowdsec-cli/cliconsole/console_table.go @@ -0,0 +1,50 @@ +package cliconsole + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/text" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/emoji" +) + +func cmdConsoleStatusTable(out io.Writer, wantColor string, consoleCfg csconfig.ConsoleConfig) { + t := cstable.New(out, wantColor) + t.SetRowLines(false) + + t.SetHeaders("Option Name", "Activated", "Description") + t.SetHeaderAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) + + for _, option := range csconfig.CONSOLE_CONFIGS { + activated := emoji.CrossMark + + switch option { + case csconfig.SEND_CUSTOM_SCENARIOS: + if *consoleCfg.ShareCustomScenarios { + activated = emoji.CheckMarkButton + } + case csconfig.SEND_MANUAL_SCENARIOS: + if *consoleCfg.ShareManualDecisions { + activated = emoji.CheckMarkButton + } + case csconfig.SEND_TAINTED_SCENARIOS: + if *consoleCfg.ShareTaintedScenarios { + activated = emoji.CheckMarkButton + } + case csconfig.SEND_CONTEXT: + if *consoleCfg.ShareContext { + activated = emoji.CheckMarkButton + } + case csconfig.CONSOLE_MANAGEMENT: + if *consoleCfg.ConsoleManagement { + activated = emoji.CheckMarkButton + } + } + + t.AddRow(option, activated, csconfig.CONSOLE_CONFIGS_HELP[option]) + } + + t.Render() +} diff --git a/cmd/crowdsec-cli/clidecision/decisions.go b/cmd/crowdsec-cli/clidecision/decisions.go new file mode 100644 index 00000000000..1f8781a3716 --- /dev/null +++ b/cmd/crowdsec-cli/clidecision/decisions.go @@ -0,0 +1,565 @@ +package clidecision + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clialert" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func (cli *cliDecisions) decisionsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { + /*here we cheat a bit : to make it more readable for the user, we dedup some entries*/ + spamLimit := make(map[string]bool) + skipped := 0 + + for aIdx := range len(*alerts) { + alertItem := (*alerts)[aIdx] + newDecisions := make([]*models.Decision, 0) + + for _, decisionItem := range alertItem.Decisions { + spamKey := fmt.Sprintf("%t:%s:%s:%s", *decisionItem.Simulated, *decisionItem.Type, *decisionItem.Scope, *decisionItem.Value) + if _, ok := spamLimit[spamKey]; ok { + skipped++ + continue + } + + spamLimit[spamKey] = true + + newDecisions = append(newDecisions, decisionItem) + } + + alertItem.Decisions = newDecisions + } + + switch cli.cfg().Cscli.Output { + case "raw": + csvwriter := csv.NewWriter(os.Stdout) + header := []string{"id", "source", "ip", "reason", "action", "country", "as", "events_count", "expiration", "simulated", "alert_id"} + + if printMachine { + header = append(header, "machine") + } + + err := csvwriter.Write(header) + if err != nil { + return err + } + + for _, alertItem := range *alerts { + for _, decisionItem := range 
alertItem.Decisions {
+				raw := []string{
+					fmt.Sprintf("%d", decisionItem.ID),
+					*decisionItem.Origin,
+					*decisionItem.Scope + ":" + *decisionItem.Value,
+					*decisionItem.Scenario,
+					*decisionItem.Type,
+					alertItem.Source.Cn,
+					alertItem.Source.GetAsNumberName(),
+					fmt.Sprintf("%d", *alertItem.EventsCount),
+					*decisionItem.Duration,
+					fmt.Sprintf("%t", *decisionItem.Simulated),
+					fmt.Sprintf("%d", alertItem.ID),
+				}
+				if printMachine {
+					raw = append(raw, alertItem.MachineID)
+				}
+
+				err := csvwriter.Write(raw)
+				if err != nil {
+					return err
+				}
+			}
+		}
+
+		csvwriter.Flush()
+	case "json":
+		if *alerts == nil {
+			// avoid returning "null" in json output
+			// could be cleaner if we used a slice of alerts directly
+			fmt.Println("[]")
+			return nil
+		}
+
+		x, _ := json.MarshalIndent(alerts, "", " ")
+		fmt.Printf("%s", string(x))
+	case "human":
+		if len(*alerts) == 0 {
+			fmt.Println("No active decisions")
+			return nil
+		}
+
+		cli.decisionsTable(color.Output, alerts, printMachine)
+
+		if skipped > 0 {
+			fmt.Printf("%d duplicated entries skipped\n", skipped)
+		}
+	}
+
+	return nil
+}
+
+type configGetter func() *csconfig.Config
+
+type cliDecisions struct {
+	client *apiclient.ApiClient
+	cfg    configGetter
+}
+
+func New(cfg configGetter) *cliDecisions {
+	return &cliDecisions{
+		cfg: cfg,
+	}
+}
+
+func (cli *cliDecisions) NewCommand() *cobra.Command {
+	cmd := &cobra.Command{
+		Use:     "decisions [action]",
+		Short:   "Manage decisions",
+		Long:    `Add/List/Delete/Import decisions from LAPI`,
+		Example: `cscli decisions [action] [filter]`,
+		Aliases: []string{"decision"},
+		/*TBD example*/
+		Args:              cobra.MinimumNArgs(1),
+		DisableAutoGenTag: true,
+		PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
+			cfg := cli.cfg()
+			if err := cfg.LoadAPIClient(); err != nil {
+				return fmt.Errorf("loading api client: %w", err)
+			}
+			apiURL, err := url.Parse(cfg.API.Client.Credentials.URL)
+			if err != nil {
+				return fmt.Errorf("parsing api url: %w", err)
+			}
+
+			cli.client, err = apiclient.NewClient(&apiclient.Config{
+				MachineID:     cfg.API.Client.Credentials.Login,
+				Password:      strfmt.Password(cfg.API.Client.Credentials.Password),
+				URL:           apiURL,
+				VersionPrefix: "v1",
+			})
+			if err != nil {
+				return fmt.Errorf("creating api client: %w", err)
+			}
+
+			return nil
+		},
+	}
+
+	cmd.AddCommand(cli.newListCmd())
+	cmd.AddCommand(cli.newAddCmd())
+	cmd.AddCommand(cli.newDeleteCmd())
+	cmd.AddCommand(cli.newImportCmd())
+
+	return cmd
+}
+
+func (cli *cliDecisions) list(ctx context.Context, filter apiclient.AlertsListOpts, NoSimu *bool, contained *bool, printMachine bool) error {
+	var err error
+
+	*filter.ScopeEquals, err = clialert.SanitizeScope(*filter.ScopeEquals, *filter.IPEquals, *filter.RangeEquals)
+	if err != nil {
+		return err
+	}
+
+	filter.ActiveDecisionEquals = new(bool)
+	*filter.ActiveDecisionEquals = true
+
+	if NoSimu != nil && *NoSimu {
+		filter.IncludeSimulated = new(bool)
+	}
+	/* nullify the empty entries to avoid bad filter */
+	if *filter.Until == "" {
+		filter.Until = nil
+	} else if strings.HasSuffix(*filter.Until, "d") {
+		/* time.ParseDuration supports 'h' (hours) as its biggest unit, so accept days and convert them ourselves */
+		realDuration := strings.TrimSuffix(*filter.Until, "d")
+
+		days, err := strconv.Atoi(realDuration)
+		if err != nil {
+			return fmt.Errorf("can't parse duration %s, valid duration formats: 1d, 4h, 4h15m", *filter.Until)
+		}
+
+		*filter.Until = fmt.Sprintf("%d%s", days*24, "h")
+	}
+
+	if *filter.Since == "" {
+		filter.Since = nil
+	} else if strings.HasSuffix(*filter.Since, "d") {
+		/* time.ParseDuration supports 'h' (hours) as its biggest unit, so accept days and convert them ourselves */
+		realDuration := strings.TrimSuffix(*filter.Since, "d")
+
+		days, err := strconv.Atoi(realDuration)
+		if err != nil {
+			return fmt.Errorf("can't parse duration %s, valid duration formats: 1d, 4h, 4h15m", *filter.Since)
+		}
+
+		*filter.Since = fmt.Sprintf("%d%s", days*24, "h")
+	}
+
+	if *filter.IncludeCAPI {
+		*filter.Limit = 0
+	}
+
+	if *filter.TypeEquals == "" {
+		filter.TypeEquals = nil
+	}
+
+	if *filter.ValueEquals == "" {
+		filter.ValueEquals = nil
+	}
+
+	if *filter.ScopeEquals == "" {
+		filter.ScopeEquals = nil
+	}
+
+	if *filter.ScenarioEquals == "" {
+		filter.ScenarioEquals = nil
+	}
+
+	if *filter.IPEquals == "" {
+		filter.IPEquals = nil
+	}
+
+	if *filter.RangeEquals == "" {
+		filter.RangeEquals = nil
+	}
+
+	if *filter.OriginEquals == "" {
+		filter.OriginEquals = nil
+	}
+
+	if contained != nil && *contained {
+		filter.Contains = new(bool)
+	}
+
+	alerts, _, err := cli.client.Alerts.List(ctx, filter)
+	if err != nil {
+		return fmt.Errorf("unable to retrieve decisions: %w", err)
+	}
+
+	err = cli.decisionsToTable(alerts, printMachine)
+	if err != nil {
+		return fmt.Errorf("unable to print decisions: %w", err)
+	}
+
+	return nil
+}
+
+func (cli *cliDecisions) newListCmd() *cobra.Command {
+	filter := apiclient.AlertsListOpts{
+		ValueEquals:    new(string),
+		ScopeEquals:    new(string),
+		ScenarioEquals: new(string),
+		OriginEquals:   new(string),
+		IPEquals:       new(string),
+		RangeEquals:    new(string),
+		Since:          new(string),
+		Until:          new(string),
+		TypeEquals:     new(string),
+		IncludeCAPI:    new(bool),
+		Limit:          new(int),
+	}
+
+	NoSimu := new(bool)
+	contained := new(bool)
+
+	var printMachine bool
+
+	cmd := &cobra.Command{
+		Use:   "list [options]",
+		Short: "List decisions from LAPI",
+		Example: `cscli decisions list -i 1.2.3.4
+cscli decisions list -r 1.2.3.0/24
+cscli decisions list -s crowdsecurity/ssh-bf
+cscli decisions list --origin lists --scenario list_name
+`,
+		Args:              cobra.ExactArgs(0),
+		DisableAutoGenTag: true,
+		RunE: func(cmd *cobra.Command, _ []string) error {
+			return cli.list(cmd.Context(), filter, NoSimu, contained, printMachine)
+		},
+	}
+
+	flags := cmd.Flags()
+	flags.SortFlags = false
+	flags.BoolVarP(filter.IncludeCAPI, "all", "a", false, "Include decisions from Central API")
+	flags.StringVar(filter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)")
+	flags.StringVar(filter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)")
+	flags.StringVarP(filter.TypeEquals, "type", "t", "", "restrict to this decision type (ie. ban,captcha)")
+	flags.StringVar(filter.ScopeEquals, "scope", "", "restrict to this scope (ie. ip,range,session)")
+	flags.StringVar(filter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ",")))
+	flags.StringVarP(filter.ValueEquals, "value", "v", "", "restrict to this value (ie. 1.2.3.4,userName)")
+	flags.StringVarP(filter.ScenarioEquals, "scenario", "s", "", "restrict to this scenario (ie. 
crowdsecurity/ssh-bf)") + flags.StringVarP(filter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value )") + flags.StringVarP(filter.RangeEquals, "range", "r", "", "restrict to alerts from this source range (shorthand for --scope range --value )") + flags.IntVarP(filter.Limit, "limit", "l", 100, "number of alerts to get (use 0 to remove the limit)") + flags.BoolVar(NoSimu, "no-simu", false, "exclude decisions in simulation mode") + flags.BoolVarP(&printMachine, "machine", "m", false, "print machines that triggered decisions") + flags.BoolVar(contained, "contained", false, "query decisions contained by range") + + return cmd +} + +func (cli *cliDecisions) add(ctx context.Context, addIP, addRange, addDuration, addValue, addScope, addReason, addType string) error { + alerts := models.AddAlertsRequest{} + origin := types.CscliOrigin + capacity := int32(0) + leakSpeed := "0" + eventsCount := int32(1) + empty := "" + simulated := false + startAt := time.Now().UTC().Format(time.RFC3339) + stopAt := time.Now().UTC().Format(time.RFC3339) + createdAt := time.Now().UTC().Format(time.RFC3339) + + var err error + + addScope, err = clialert.SanitizeScope(addScope, addIP, addRange) + if err != nil { + return err + } + + if addIP != "" { + addValue = addIP + addScope = types.Ip + } else if addRange != "" { + addValue = addRange + addScope = types.Range + } else if addValue == "" { + return errors.New("missing arguments, a value is required (--ip, --range or --scope and --value)") + } + + if addReason == "" { + addReason = fmt.Sprintf("manual '%s' from '%s'", addType, cli.cfg().API.Client.Credentials.Login) + } + + decision := models.Decision{ + Duration: &addDuration, + Scope: &addScope, + Value: &addValue, + Type: &addType, + Scenario: &addReason, + Origin: &origin, + } + alert := models.Alert{ + Capacity: &capacity, + Decisions: []*models.Decision{&decision}, + Events: []*models.Event{}, + EventsCount: &eventsCount, + Leakspeed: &leakSpeed, + Message: &addReason, + ScenarioHash: &empty, + Scenario: &addReason, + ScenarioVersion: &empty, + Simulated: &simulated, + // setting empty scope/value broke plugins, and it didn't seem to be needed anymore w/ latest papi changes + Source: &models.Source{ + AsName: "", + AsNumber: "", + Cn: "", + IP: addValue, + Range: "", + Scope: &addScope, + Value: &addValue, + }, + StartAt: &startAt, + StopAt: &stopAt, + CreatedAt: createdAt, + Remediation: true, + } + alerts = append(alerts, &alert) + + _, _, err = cli.client.Alerts.Add(ctx, alerts) + if err != nil { + return err + } + + log.Info("Decision successfully added") + + return nil +} + +func (cli *cliDecisions) newAddCmd() *cobra.Command { + var ( + addIP string + addRange string + addDuration string + addValue string + addScope string + addReason string + addType string + ) + + cmd := &cobra.Command{ + Use: "add [options]", + Short: "Add decision to LAPI", + Example: `cscli decisions add --ip 1.2.3.4 +cscli decisions add --range 1.2.3.0/24 +cscli decisions add --ip 1.2.3.4 --duration 24h --type captcha +cscli decisions add --scope username --value foobar +`, + /*TBD : fix long and example*/ + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.add(cmd.Context(), addIP, addRange, addDuration, addValue, addScope, addReason, addType) + }, + } + + flags := cmd.Flags() + flags.SortFlags = false + flags.StringVarP(&addIP, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") + 
flags.StringVarP(&addRange, "range", "r", "", "Range source ip (shorthand for --scope range --value )") + flags.StringVarP(&addDuration, "duration", "d", "4h", "Decision duration (ie. 1h,4h,30m)") + flags.StringVarP(&addValue, "value", "v", "", "The value (ie. --scope username --value foobar)") + flags.StringVar(&addScope, "scope", types.Ip, "Decision scope (ie. ip,range,username)") + flags.StringVarP(&addReason, "reason", "R", "", "Decision reason (ie. scenario-name)") + flags.StringVarP(&addType, "type", "t", "ban", "Decision type (ie. ban,captcha,throttle)") + + return cmd +} + +func (cli *cliDecisions) delete(ctx context.Context, delFilter apiclient.DecisionsDeleteOpts, delDecisionID string, contained *bool) error { + var err error + + /*take care of shorthand options*/ + *delFilter.ScopeEquals, err = clialert.SanitizeScope(*delFilter.ScopeEquals, *delFilter.IPEquals, *delFilter.RangeEquals) + if err != nil { + return err + } + + if *delFilter.ScopeEquals == "" { + delFilter.ScopeEquals = nil + } + + if *delFilter.OriginEquals == "" { + delFilter.OriginEquals = nil + } + + if *delFilter.ValueEquals == "" { + delFilter.ValueEquals = nil + } + + if *delFilter.ScenarioEquals == "" { + delFilter.ScenarioEquals = nil + } + + if *delFilter.TypeEquals == "" { + delFilter.TypeEquals = nil + } + + if *delFilter.IPEquals == "" { + delFilter.IPEquals = nil + } + + if *delFilter.RangeEquals == "" { + delFilter.RangeEquals = nil + } + + if contained != nil && *contained { + delFilter.Contains = new(bool) + } + + var decisions *models.DeleteDecisionResponse + + if delDecisionID == "" { + decisions, _, err = cli.client.Decisions.Delete(ctx, delFilter) + if err != nil { + return fmt.Errorf("unable to delete decisions: %w", err) + } + } else { + if _, err = strconv.Atoi(delDecisionID); err != nil { + return fmt.Errorf("id '%s' is not an integer: %w", delDecisionID, err) + } + + decisions, _, err = cli.client.Decisions.DeleteOne(ctx, delDecisionID) + if err != nil { + return fmt.Errorf("unable to delete decision: %w", err) + } + } + + log.Infof("%s decision(s) deleted", decisions.NbDeleted) + + return nil +} + +func (cli *cliDecisions) newDeleteCmd() *cobra.Command { + delFilter := apiclient.DecisionsDeleteOpts{ + ScopeEquals: new(string), + ValueEquals: new(string), + TypeEquals: new(string), + IPEquals: new(string), + RangeEquals: new(string), + ScenarioEquals: new(string), + OriginEquals: new(string), + } + + var delDecisionID string + + var delDecisionAll bool + + contained := new(bool) + + cmd := &cobra.Command{ + Use: "delete [options]", + Short: "Delete decisions", + DisableAutoGenTag: true, + Aliases: []string{"remove"}, + Example: `cscli decisions delete -r 1.2.3.0/24 +cscli decisions delete -i 1.2.3.4 +cscli decisions delete --id 42 +cscli decisions delete --type captcha +cscli decisions delete --origin lists --scenario list_name +`, + /*TBD : refaire le Long/Example*/ + PreRunE: func(cmd *cobra.Command, _ []string) error { + if delDecisionAll { + return nil + } + if *delFilter.ScopeEquals == "" && *delFilter.ValueEquals == "" && + *delFilter.TypeEquals == "" && *delFilter.IPEquals == "" && + *delFilter.RangeEquals == "" && *delFilter.ScenarioEquals == "" && + *delFilter.OriginEquals == "" && delDecisionID == "" { + _ = cmd.Usage() + return errors.New("at least one filter or --all must be specified") + } + + return nil + }, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.delete(cmd.Context(), delFilter, delDecisionID, contained) + }, + } + + flags := cmd.Flags() + 
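+	// keep flags in declaration order (not sorted) so related filters stay grouped in the help output
+	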
flags.SortFlags = false
+	flags.StringVarP(delFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value <IP>)")
+	flags.StringVarP(delFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value <RANGE>)")
+	flags.StringVarP(delFilter.TypeEquals, "type", "t", "", "the decision type (ie. ban,captcha)")
+	flags.StringVarP(delFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope")
+	flags.StringVarP(delFilter.ScenarioEquals, "scenario", "s", "", "the scenario name (ie. crowdsecurity/ssh-bf)")
+	flags.StringVar(delFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ",")))
+
+	flags.StringVar(&delDecisionID, "id", "", "decision id")
+	flags.BoolVar(&delDecisionAll, "all", false, "delete all decisions")
+	flags.BoolVar(contained, "contained", false, "query decisions contained by range")
+
+	return cmd
+}
diff --git a/cmd/crowdsec-cli/decisions_import.go b/cmd/crowdsec-cli/clidecision/decisions_import.go
similarity index 83%
rename from cmd/crowdsec-cli/decisions_import.go
rename to cmd/crowdsec-cli/clidecision/decisions_import.go
index 6a47a96b3ea..10d92f88876 100644
--- a/cmd/crowdsec-cli/decisions_import.go
+++ b/cmd/crowdsec-cli/clidecision/decisions_import.go
@@ -1,10 +1,11 @@
-package main
+package clidecision
 
 import (
 	"bufio"
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"os"
@@ -15,8 +16,8 @@ import (
 	log "github.com/sirupsen/logrus"
 	"github.com/spf13/cobra"
 
-	"github.com/crowdsecurity/go-cs-lib/pkg/ptr"
-	"github.com/crowdsecurity/go-cs-lib/pkg/slicetools"
+	"github.com/crowdsecurity/go-cs-lib/ptr"
+	"github.com/crowdsecurity/go-cs-lib/slicetools"
 
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
@@ -37,23 +38,27 @@ func parseDecisionList(content []byte, format string) ([]decisionRaw, error) {
 	switch format {
 	case "values":
 		log.Infof("Parsing values")
+
 		scanner := bufio.NewScanner(bytes.NewReader(content))
 		for scanner.Scan() {
 			value := strings.TrimSpace(scanner.Text())
 			ret = append(ret, decisionRaw{Value: value})
 		}
+
 		if err := scanner.Err(); err != nil {
-			return nil, fmt.Errorf("unable to parse values: '%s'", err)
+			return nil, fmt.Errorf("unable to parse values: '%w'", err)
 		}
 	case "json":
 		log.Infof("Parsing json")
+
 		if err := json.Unmarshal(content, &ret); err != nil {
 			return nil, err
 		}
 	case "csv":
 		log.Infof("Parsing csv")
+
 		if err := csvutil.Unmarshal(content, &ret); err != nil {
-			return nil, fmt.Errorf("unable to parse csv: '%s'", err)
+			return nil, fmt.Errorf("unable to parse csv: '%w'", err)
 		}
 	default:
 		return nil, fmt.Errorf("invalid format '%s', expected one of 'json', 'csv', 'values'", format)
@@ -62,8 +67,7 @@ func parseDecisionList(content []byte, format string) ([]decisionRaw, error) {
 	return ret, nil
 }
 
-
-func runDecisionsImport(cmd *cobra.Command, args []string) error {
+func (cli *cliDecisions) runImport(cmd *cobra.Command, args []string) error {
 	flags := cmd.Flags()
 
 	input, err := flags.GetString("input")
@@ -75,32 +79,36 @@ func runDecisionsImport(cmd *cobra.Command, args []string) error {
 	if err != nil {
 		return err
 	}
+
 	if defaultDuration == "" {
-		return fmt.Errorf("--duration cannot be empty")
+		return errors.New("--duration cannot be empty")
 	}
 
 	defaultScope, err := flags.GetString("scope")
 	if err != nil {
 		return err
 	}
+
 	if defaultScope == "" {
-		return fmt.Errorf("--scope cannot be empty")
+		return errors.New("--scope cannot be empty")
 	}
 
 	defaultReason, err := 
flags.GetString("reason") if err != nil { return err } + if defaultReason == "" { - return fmt.Errorf("--reason cannot be empty") + return errors.New("--reason cannot be empty") } defaultType, err := flags.GetString("type") if err != nil { return err } + if defaultType == "" { - return fmt.Errorf("--type cannot be empty") + return errors.New("--type cannot be empty") } batchSize, err := flags.GetInt("batch") @@ -115,7 +123,7 @@ func runDecisionsImport(cmd *cobra.Command, args []string) error { var ( content []byte - fin *os.File + fin *os.File ) // set format if the file has a json or csv extension @@ -128,7 +136,7 @@ func runDecisionsImport(cmd *cobra.Command, args []string) error { } if format == "" { - return fmt.Errorf("unable to guess format from file extension, please provide a format with --format flag") + return errors.New("unable to guess format from file extension, please provide a format with --format flag") } if input == "-" { @@ -137,13 +145,13 @@ func runDecisionsImport(cmd *cobra.Command, args []string) error { } else { fin, err = os.Open(input) if err != nil { - return fmt.Errorf("unable to open %s: %s", input, err) + return fmt.Errorf("unable to open %s: %w", input, err) } } content, err = io.ReadAll(fin) if err != nil { - return fmt.Errorf("unable to read from %s: %s", input, err) + return fmt.Errorf("unable to read from %s: %w", input, err) } decisionsListRaw, err := parseDecisionList(content, format) @@ -152,6 +160,7 @@ func runDecisionsImport(cmd *cobra.Command, args []string) error { } decisions := make([]*models.Decision, len(decisionsListRaw)) + for i, d := range decisionsListRaw { if d.Value == "" { return fmt.Errorf("item %d: missing 'value'", i) @@ -188,7 +197,9 @@ func runDecisionsImport(cmd *cobra.Command, args []string) error { } } - alerts := models.AddAlertsRequest{} + if len(decisions) > 1000 { + log.Infof("You are about to add %d decisions, this may take a while", len(decisions)) + } for _, chunk := range slicetools.Chunks(decisions, batchSize) { log.Debugf("Processing chunk of %d decisions", len(chunk)) @@ -212,30 +223,26 @@ func runDecisionsImport(cmd *cobra.Command, args []string) error { ScenarioVersion: ptr.Of(""), Decisions: chunk, } - alerts = append(alerts, &importAlert) - } - - if len(decisions) > 1000 { - log.Infof("You are about to add %d decisions, this may take a while", len(decisions)) - } - _, _, err = Client.Alerts.Add(context.Background(), alerts) - if err != nil { - return err + _, _, err = cli.client.Alerts.Add(context.Background(), models.AddAlertsRequest{&importAlert}) + if err != nil { + return err + } } log.Infof("Imported %d decisions", len(decisions)) + return nil } - -func NewDecisionsImportCmd() *cobra.Command { - var cmdDecisionsImport = &cobra.Command{ +func (cli *cliDecisions) newImportCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "import [options]", Short: "Import decisions from a file or pipe", Long: "expected format:\n" + "csv : any of duration,reason,scope,type,value, with a header line\n" + - `json : {"duration" : "24h", "reason" : "my_scenario", "scope" : "ip", "type" : "ban", "value" : "x.y.z.z"}`, + "json :" + "`{" + `"duration" : "24h", "reason" : "my_scenario", "scope" : "ip", "type" : "ban", "value" : "x.y.z.z"` + "}`", + Args: cobra.NoArgs, DisableAutoGenTag: true, Example: `decisions.csv: duration,scope,value @@ -253,10 +260,10 @@ Raw values, standard input: $ echo "1.2.3.4" | cscli decisions import -i - --format values `, - RunE: runDecisionsImport, + RunE: cli.runImport, } - flags := 
cmdDecisionsImport.Flags() + flags := cmd.Flags() flags.SortFlags = false flags.StringP("input", "i", "", "Input file") flags.StringP("duration", "d", "4h", "Decision duration: 1h,4h,30m") @@ -266,7 +273,7 @@ $ echo "1.2.3.4" | cscli decisions import -i - --format values flags.Int("batch", 0, "Split import in batches of N decisions") flags.String("format", "", "Input format: 'json', 'csv' or 'values' (each line is a value, no headers)") - cmdDecisionsImport.MarkFlagRequired("input") + _ = cmd.MarkFlagRequired("input") - return cmdDecisionsImport + return cmd } diff --git a/cmd/crowdsec-cli/decisions_table.go b/cmd/crowdsec-cli/clidecision/decisions_table.go similarity index 80% rename from cmd/crowdsec-cli/decisions_table.go rename to cmd/crowdsec-cli/clidecision/decisions_table.go index d8d5e032594..90a0ae1176b 100644 --- a/cmd/crowdsec-cli/decisions_table.go +++ b/cmd/crowdsec-cli/clidecision/decisions_table.go @@ -1,20 +1,23 @@ -package main +package clidecision import ( "fmt" "io" "strconv" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" "github.com/crowdsecurity/crowdsec/pkg/models" ) -func decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine bool) { - t := newTable(out) +func (cli *cliDecisions) decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine bool) { + t := cstable.New(out, cli.cfg().Cscli.Color) t.SetRowLines(false) + header := []string{"ID", "Source", "Scope:Value", "Reason", "Action", "Country", "AS", "Events", "expiration", "Alert ID"} if printMachine { header = append(header, "Machine") } + t.SetHeaders(header...) for _, alertItem := range *alerts { @@ -22,6 +25,7 @@ func decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachin if *alertItem.Simulated { *decisionItem.Type = fmt.Sprintf("(simul)%s", *decisionItem.Type) } + row := []string{ strconv.Itoa(int(decisionItem.ID)), *decisionItem.Origin, @@ -42,5 +46,6 @@ func decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachin t.AddRow(row...) } } + t.Render() } diff --git a/cmd/crowdsec-cli/clientinfo/clientinfo.go b/cmd/crowdsec-cli/clientinfo/clientinfo.go new file mode 100644 index 00000000000..0bf1d98804f --- /dev/null +++ b/cmd/crowdsec-cli/clientinfo/clientinfo.go @@ -0,0 +1,39 @@ +package clientinfo + +import ( + "strings" +) + +type featureflagProvider interface { + GetFeatureflags() string +} + +type osProvider interface { + GetOsname() string + GetOsversion() string +} + +func GetOSNameAndVersion(o osProvider) string { + ret := o.GetOsname() + if o.GetOsversion() != "" { + if ret != "" { + ret += "/" + } + + ret += o.GetOsversion() + } + + if ret == "" { + return "?" 
+ } + + return ret +} + +func GetFeatureFlagList(o featureflagProvider) []string { + if o.GetFeatureflags() == "" { + return nil + } + + return strings.Split(o.GetFeatureflags(), ",") +} diff --git a/cmd/crowdsec-cli/cliexplain/explain.go b/cmd/crowdsec-cli/cliexplain/explain.go new file mode 100644 index 00000000000..182e34a12a5 --- /dev/null +++ b/cmd/crowdsec-cli/cliexplain/explain.go @@ -0,0 +1,254 @@ +package cliexplain + +import ( + "bufio" + "errors" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/dumps" + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +func getLineCountForFile(filepath string) (int, error) { + f, err := os.Open(filepath) + if err != nil { + return 0, err + } + defer f.Close() + + lc := 0 + fs := bufio.NewReader(f) + + for { + input, err := fs.ReadBytes('\n') + if len(input) > 1 { + lc++ + } + + if err != nil && err == io.EOF { + break + } + } + + return lc, nil +} + +type configGetter func() *csconfig.Config + +type cliExplain struct { + cfg configGetter + configFilePath string + flags struct { + logFile string + dsn string + logLine string + logType string + details bool + skipOk bool + onlySuccessfulParsers bool + noClean bool + crowdsec string + labels string + } +} + +func New(cfg configGetter, configFilePath string) *cliExplain { + return &cliExplain{ + cfg: cfg, + configFilePath: configFilePath, + } +} + +func (cli *cliExplain) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "explain", + Short: "Explain log pipeline", + Long: ` +Explain log pipeline + `, + Example: ` +cscli explain --file ./myfile.log --type nginx +cscli explain --log "Sep 19 18:33:22 scw-d95986 sshd[24347]: pam_unix(sshd:auth): authentication failure; logname= uid=0 euid=0 tty=ssh ruser= rhost=1.2.3.4" --type syslog +cscli explain --dsn "file://myfile.log" --type nginx +tail -n 5 myfile.log | cscli explain --type nginx -f - + `, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.run() + }, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + fileInfo, _ := os.Stdin.Stat() + if cli.flags.logFile == "-" && ((fileInfo.Mode() & os.ModeCharDevice) == os.ModeCharDevice) { + return errors.New("the option -f - is intended to work with pipes") + } + + return nil + }, + } + + flags := cmd.Flags() + + flags.StringVarP(&cli.flags.logFile, "file", "f", "", "Log file to test") + flags.StringVarP(&cli.flags.dsn, "dsn", "d", "", "DSN to test") + flags.StringVarP(&cli.flags.logLine, "log", "l", "", "Log line to test") + flags.StringVarP(&cli.flags.logType, "type", "t", "", "Type of the acquisition to test") + flags.StringVar(&cli.flags.labels, "labels", "", "Additional labels to add to the acquisition format (key:value,key2:value2)") + flags.BoolVarP(&cli.flags.details, "verbose", "v", false, "Display individual changes") + flags.BoolVar(&cli.flags.skipOk, "failures", false, "Only show failed lines") + flags.BoolVar(&cli.flags.onlySuccessfulParsers, "only-successful-parsers", false, "Only show successful parsers") + flags.StringVar(&cli.flags.crowdsec, "crowdsec", "crowdsec", "Path to crowdsec") + flags.BoolVar(&cli.flags.noClean, "no-clean", false, "Don't clean runtime environment after tests") + + _ = cmd.MarkFlagRequired("type") + cmd.MarkFlagsOneRequired("log", "file", "dsn") + + return cmd +} + +func (cli *cliExplain) run() error { + logFile := 
cli.flags.logFile + logLine := cli.flags.logLine + logType := cli.flags.logType + dsn := cli.flags.dsn + labels := cli.flags.labels + crowdsec := cli.flags.crowdsec + + opts := dumps.DumpOpts{ + Details: cli.flags.details, + SkipOk: cli.flags.skipOk, + ShowNotOkParsers: !cli.flags.onlySuccessfulParsers, + } + + var f *os.File + + // using empty string fallback to /tmp + dir, err := os.MkdirTemp("", "cscli_explain") + if err != nil { + return fmt.Errorf("couldn't create a temporary directory to store cscli explain result: %w", err) + } + + defer func() { + if cli.flags.noClean { + return + } + + if _, err := os.Stat(dir); !os.IsNotExist(err) { + if err := os.RemoveAll(dir); err != nil { + log.Errorf("unable to delete temporary directory '%s': %s", dir, err) + } + } + }() + + // we create a temporary log file if a log line/stdin has been provided + if logLine != "" || logFile == "-" { + tmpFile := filepath.Join(dir, "cscli_test_tmp.log") + + f, err = os.Create(tmpFile) + if err != nil { + return err + } + + if logLine != "" { + _, err = f.WriteString(logLine) + if err != nil { + return err + } + } else if logFile == "-" { + reader := bufio.NewReader(os.Stdin) + errCount := 0 + + for { + input, err := reader.ReadBytes('\n') + if err != nil && errors.Is(err, io.EOF) { + break + } + + if len(input) > 1 { + _, err = f.Write(input) + } + + if err != nil || len(input) <= 1 { + errCount++ + } + } + + if errCount > 0 { + log.Warnf("Failed to write %d lines to %s", errCount, tmpFile) + } + } + + f.Close() + // this is the file that was going to be read by crowdsec anyway + logFile = tmpFile + } + + if logFile != "" { + absolutePath, err := filepath.Abs(logFile) + if err != nil { + return fmt.Errorf("unable to get absolute path of '%s', exiting", logFile) + } + + dsn = fmt.Sprintf("file://%s", absolutePath) + + lineCount, err := getLineCountForFile(absolutePath) + if err != nil { + return err + } + + log.Debugf("file %s has %d lines", absolutePath, lineCount) + + if lineCount == 0 { + return fmt.Errorf("the log file is empty: %s", absolutePath) + } + + if lineCount > 100 { + log.Warnf("%s contains %d lines. This may take a lot of resources.", absolutePath, lineCount) + } + } + + if dsn == "" { + return errors.New("no acquisition (--file or --dsn) provided, can't run cscli test") + } + + cmdArgs := []string{"-c", cli.configFilePath, "-type", logType, "-dsn", dsn, "-dump-data", dir, "-no-api"} + + if labels != "" { + log.Debugf("adding labels %s", labels) + cmdArgs = append(cmdArgs, "-label", labels) + } + + crowdsecCmd := exec.Command(crowdsec, cmdArgs...) 
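
+
+	// stdout and stderr are captured together; the output is only printed
+	// back to the user if the one-shot crowdsec run fails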

+	output, err := crowdsecCmd.CombinedOutput()
+	if err != nil {
+		fmt.Println(string(output))
+
+		return fmt.Errorf("failed to run crowdsec for test: %w", err)
+	}
+
+	parserDumpFile := filepath.Join(dir, hubtest.ParserResultFileName)
+	bucketStateDumpFile := filepath.Join(dir, hubtest.BucketPourResultFileName)
+
+	parserDump, err := dumps.LoadParserDump(parserDumpFile)
+	if err != nil {
+		return fmt.Errorf("unable to load parser dump result: %w", err)
+	}
+
+	bucketStateDump, err := dumps.LoadBucketPourDump(bucketStateDumpFile)
+	if err != nil {
+		return fmt.Errorf("unable to load bucket dump result: %w", err)
+	}
+
+	dumps.DumpTree(*parserDump, *bucketStateDump, opts)
+
+	return nil
+}
diff --git a/cmd/crowdsec-cli/clihub/hub.go b/cmd/crowdsec-cli/clihub/hub.go
new file mode 100644
index 00000000000..22568355546
--- /dev/null
+++ b/cmd/crowdsec-cli/clihub/hub.go
@@ -0,0 +1,246 @@
+package clihub
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+
+	"github.com/fatih/color"
+	log "github.com/sirupsen/logrus"
+	"github.com/spf13/cobra"
+	"gopkg.in/yaml.v3"
+
+	"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
+	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
+	"github.com/crowdsecurity/crowdsec/pkg/cwhub"
+)
+
+type configGetter = func() *csconfig.Config
+
+type cliHub struct {
+	cfg configGetter
+}
+
+func New(cfg configGetter) *cliHub {
+	return &cliHub{
+		cfg: cfg,
+	}
+}
+
+func (cli *cliHub) NewCommand() *cobra.Command {
+	cmd := &cobra.Command{
+		Use:   "hub [action]",
+		Short: "Manage hub index",
+		Long: `Hub management
+
+List/update parsers/scenarios/postoverflows/collections from [Crowdsec Hub](https://hub.crowdsec.net).
+The Hub is managed by cscli; to get the latest hub files from [Crowdsec Hub](https://hub.crowdsec.net), you need to update.`,
+		Example: `cscli hub list
+cscli hub update
+cscli hub upgrade`,
+		Args:              cobra.ExactArgs(0),
+		DisableAutoGenTag: true,
+	}
+
+	cmd.AddCommand(cli.newListCmd())
+	cmd.AddCommand(cli.newUpdateCmd())
+	cmd.AddCommand(cli.newUpgradeCmd())
+	cmd.AddCommand(cli.newTypesCmd())
+
+	return cmd
+}
+
+func (cli *cliHub) List(out io.Writer, hub *cwhub.Hub, all bool) error {
+	cfg := cli.cfg()
+
+	for _, v := range hub.Warnings {
+		log.Info(v)
+	}
+
+	for _, line := range hub.ItemStats() {
+		log.Info(line)
+	}
+
+	items := make(map[string][]*cwhub.Item)
+
+	var err error
+
+	for _, itemType := range cwhub.ItemTypes {
+		items[itemType], err = SelectItems(hub, itemType, nil, !all)
+		if err != nil {
+			return err
+		}
+	}
+
+	err = ListItems(out, cfg.Cscli.Color, cwhub.ItemTypes, items, true, cfg.Cscli.Output)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (cli *cliHub) newListCmd() *cobra.Command {
+	var all bool
+
+	cmd := &cobra.Command{
+		Use:               "list [-a]",
+		Short:             "List all installed configurations",
+		Args:              cobra.ExactArgs(0),
+		DisableAutoGenTag: true,
+		RunE: func(_ *cobra.Command, _ []string) error {
+			hub, err := require.Hub(cli.cfg(), nil, log.StandardLogger())
+			if err != nil {
+				return err
+			}
+
+			return cli.List(color.Output, hub, all)
+		},
+	}
+
+	flags := cmd.Flags()
+	flags.BoolVarP(&all, "all", "a", false, "List disabled items as well")
+
+	return cmd
+}
+
+func (cli *cliHub) update(ctx context.Context, withContent bool) error {
+	local := cli.cfg().Hub
+	remote := require.RemoteHub(ctx, cli.cfg())
+	remote.EmbedItemContent = withContent
+
+	// don't use require.Hub because if there is no index file, it would fail
+	hub, err := cwhub.NewHub(local, remote, log.StandardLogger())
+	if err != nil {
+		return 
err + } + + if err := hub.Update(ctx); err != nil { + return fmt.Errorf("failed to update hub: %w", err) + } + + if err := hub.Load(); err != nil { + return fmt.Errorf("failed to load hub: %w", err) + } + + for _, v := range hub.Warnings { + log.Info(v) + } + + return nil +} + +func (cli *cliHub) newUpdateCmd() *cobra.Command { + withContent := false + + cmd := &cobra.Command{ + Use: "update", + Short: "Download the latest index (catalog of available configurations)", + Long: ` +Fetches the .index.json file from the hub, containing the list of available configs. +`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.update(cmd.Context(), withContent) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&withContent, "with-content", false, "Download index with embedded item content") + + return cmd +} + +func (cli *cliHub) upgrade(ctx context.Context, force bool) error { + hub, err := require.Hub(cli.cfg(), require.RemoteHub(ctx, cli.cfg()), log.StandardLogger()) + if err != nil { + return err + } + + for _, itemType := range cwhub.ItemTypes { + updated := 0 + + log.Infof("Upgrading %s", itemType) + + for _, item := range hub.GetInstalledByType(itemType, true) { + didUpdate, err := item.Upgrade(ctx, force) + if err != nil { + return err + } + + if didUpdate { + updated++ + } + } + + log.Infof("Upgraded %d %s", updated, itemType) + } + + return nil +} + +func (cli *cliHub) newUpgradeCmd() *cobra.Command { + var force bool + + cmd := &cobra.Command{ + Use: "upgrade", + Short: "Upgrade all configurations to their latest version", + Long: ` +Upgrade all configs installed from Crowdsec Hub. Run 'sudo cscli hub update' if you want the latest versions available. +`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.upgrade(cmd.Context(), force) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&force, "force", false, "Force upgrade: overwrite tainted and outdated files") + + return cmd +} + +func (cli *cliHub) types() error { + switch cli.cfg().Cscli.Output { + case "human": + s, err := yaml.Marshal(cwhub.ItemTypes) + if err != nil { + return err + } + + fmt.Print(string(s)) + case "json": + jsonStr, err := json.Marshal(cwhub.ItemTypes) + if err != nil { + return err + } + + fmt.Println(string(jsonStr)) + case "raw": + for _, itemType := range cwhub.ItemTypes { + fmt.Println(itemType) + } + } + + return nil +} + +func (cli *cliHub) newTypesCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "types", + Short: "List supported item types", + Long: ` +List the types of supported hub items. 
+`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.types() + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clihub/item_metrics.go b/cmd/crowdsec-cli/clihub/item_metrics.go new file mode 100644 index 00000000000..f4af8f635db --- /dev/null +++ b/cmd/crowdsec-cli/clihub/item_metrics.go @@ -0,0 +1,291 @@ +package clihub + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/fatih/color" + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/prom2json" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func showMetrics(prometheusURL string, hubItem *cwhub.Item, wantColor string) error { + switch hubItem.Type { + case cwhub.PARSERS: + metrics := getParserMetric(prometheusURL, hubItem.Name) + parserMetricsTable(color.Output, wantColor, hubItem.Name, metrics) + case cwhub.SCENARIOS: + metrics := getScenarioMetric(prometheusURL, hubItem.Name) + scenarioMetricsTable(color.Output, wantColor, hubItem.Name, metrics) + case cwhub.COLLECTIONS: + for _, sub := range hubItem.SubItems() { + if err := showMetrics(prometheusURL, sub, wantColor); err != nil { + return err + } + } + case cwhub.APPSEC_RULES: + metrics := getAppsecRuleMetric(prometheusURL, hubItem.Name) + appsecMetricsTable(color.Output, wantColor, hubItem.Name, metrics) + default: // no metrics for this item type + } + + return nil +} + +// getParserMetric is a complete rip from prom2json +func getParserMetric(url string, itemName string) map[string]map[string]int { + stats := make(map[string]map[string]int) + + result := getPrometheusMetric(url) + for idx, fam := range result { + if !strings.HasPrefix(fam.Name, "cs_") { + continue + } + + log.Tracef("round %d", idx) + + for _, m := range fam.Metrics { + metric, ok := m.(prom2json.Metric) + if !ok { + log.Debugf("failed to convert metric to prom2json.Metric") + continue + } + + name, ok := metric.Labels["name"] + if !ok { + log.Debugf("no name in Metric %v", metric.Labels) + } + + if name != itemName { + continue + } + + source, ok := metric.Labels["source"] + + if !ok { + log.Debugf("no source in Metric %v", metric.Labels) + } else { + if srctype, ok := metric.Labels["type"]; ok { + source = srctype + ":" + source + } + } + + value := m.(prom2json.Metric).Value + + fval, err := strconv.ParseFloat(value, 32) + if err != nil { + log.Errorf("Unexpected int value %s : %s", value, err) + continue + } + + ival := int(fval) + + switch fam.Name { + case "cs_reader_hits_total": + if _, ok := stats[source]; !ok { + stats[source] = make(map[string]int) + stats[source]["parsed"] = 0 + stats[source]["reads"] = 0 + stats[source]["unparsed"] = 0 + stats[source]["hits"] = 0 + } + stats[source]["reads"] += ival + case "cs_parser_hits_ok_total": + if _, ok := stats[source]; !ok { + stats[source] = make(map[string]int) + } + stats[source]["parsed"] += ival + case "cs_parser_hits_ko_total": + if _, ok := stats[source]; !ok { + stats[source] = make(map[string]int) + } + stats[source]["unparsed"] += ival + case "cs_node_hits_total": + if _, ok := stats[source]; !ok { + stats[source] = make(map[string]int) + } + stats[source]["hits"] += ival + case "cs_node_hits_ok_total": + if _, ok := stats[source]; !ok { + stats[source] = make(map[string]int) + } + stats[source]["parsed"] += ival + case "cs_node_hits_ko_total": + if _, ok := stats[source]; !ok { + stats[source] = make(map[string]int) + } + stats[source]["unparsed"] 
+= ival + default: + continue + } + } + } + + return stats +} + +func getScenarioMetric(url string, itemName string) map[string]int { + stats := make(map[string]int) + + stats["instantiation"] = 0 + stats["curr_count"] = 0 + stats["overflow"] = 0 + stats["pour"] = 0 + stats["underflow"] = 0 + + result := getPrometheusMetric(url) + for idx, fam := range result { + if !strings.HasPrefix(fam.Name, "cs_") { + continue + } + + log.Tracef("round %d", idx) + + for _, m := range fam.Metrics { + metric, ok := m.(prom2json.Metric) + if !ok { + log.Debugf("failed to convert metric to prom2json.Metric") + continue + } + + name, ok := metric.Labels["name"] + + if !ok { + log.Debugf("no name in Metric %v", metric.Labels) + } + + if name != itemName { + continue + } + + value := m.(prom2json.Metric).Value + + fval, err := strconv.ParseFloat(value, 32) + if err != nil { + log.Errorf("Unexpected int value %s : %s", value, err) + continue + } + + ival := int(fval) + + switch fam.Name { + case "cs_bucket_created_total": + stats["instantiation"] += ival + case "cs_buckets": + stats["curr_count"] += ival + case "cs_bucket_overflowed_total": + stats["overflow"] += ival + case "cs_bucket_poured_total": + stats["pour"] += ival + case "cs_bucket_underflowed_total": + stats["underflow"] += ival + default: + continue + } + } + } + + return stats +} + +func getAppsecRuleMetric(url string, itemName string) map[string]int { + stats := make(map[string]int) + + stats["inband_hits"] = 0 + stats["outband_hits"] = 0 + + results := getPrometheusMetric(url) + for idx, fam := range results { + if !strings.HasPrefix(fam.Name, "cs_") { + continue + } + + log.Tracef("round %d", idx) + + for _, m := range fam.Metrics { + metric, ok := m.(prom2json.Metric) + if !ok { + log.Debugf("failed to convert metric to prom2json.Metric") + continue + } + + name, ok := metric.Labels["rule_name"] + + if !ok { + log.Debugf("no rule_name in Metric %v", metric.Labels) + } + + if name != itemName { + continue + } + + band, ok := metric.Labels["type"] + if !ok { + log.Debugf("no type in Metric %v", metric.Labels) + } + + value := m.(prom2json.Metric).Value + + fval, err := strconv.ParseFloat(value, 32) + if err != nil { + log.Errorf("Unexpected int value %s : %s", value, err) + continue + } + + ival := int(fval) + + switch fam.Name { + case "cs_appsec_rule_hits": + switch band { + case "inband": + stats["inband_hits"] += ival + case "outband": + stats["outband_hits"] += ival + default: + continue + } + default: + continue + } + } + } + + return stats +} + +func getPrometheusMetric(url string) []*prom2json.Family { + mfChan := make(chan *dto.MetricFamily, 1024) + + // Start with the DefaultTransport for sane defaults. + transport := http.DefaultTransport.(*http.Transport).Clone() + // Conservatively disable HTTP keep-alives as this program will only + // ever need a single HTTP request. + transport.DisableKeepAlives = true + // Timeout early if the server doesn't even return the headers. 
+ transport.ResponseHeaderTimeout = time.Minute + + go func() { + defer trace.CatchPanic("crowdsec/GetPrometheusMetric") + + err := prom2json.FetchMetricFamilies(url, mfChan, transport) + if err != nil { + log.Fatalf("failed to fetch prometheus metrics : %v", err) + } + }() + + result := []*prom2json.Family{} + for mf := range mfChan { + result = append(result, prom2json.NewFamily(mf)) + } + + log.Debugf("Finished reading prometheus output, %d entries", len(result)) + + return result +} diff --git a/cmd/crowdsec-cli/clihub/items.go b/cmd/crowdsec-cli/clihub/items.go new file mode 100644 index 00000000000..f86fe65a2a1 --- /dev/null +++ b/cmd/crowdsec-cli/clihub/items.go @@ -0,0 +1,186 @@ +package clihub + +import ( + "encoding/csv" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "slices" + "strings" + + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +// selectItems returns a slice of items of a given type, selected by name and sorted by case-insensitive name +func SelectItems(hub *cwhub.Hub, itemType string, args []string, installedOnly bool) ([]*cwhub.Item, error) { + allItems := hub.GetItemsByType(itemType, true) + + itemNames := make([]string, len(allItems)) + for idx, item := range allItems { + itemNames[idx] = item.Name + } + + notExist := []string{} + + if len(args) > 0 { + for _, arg := range args { + if !slices.Contains(itemNames, arg) { + notExist = append(notExist, arg) + } + } + } + + if len(notExist) > 0 { + return nil, fmt.Errorf("item(s) '%s' not found in %s", strings.Join(notExist, ", "), itemType) + } + + if len(args) > 0 { + itemNames = args + installedOnly = false + } + + wantedItems := make([]*cwhub.Item, 0, len(itemNames)) + + for _, itemName := range itemNames { + item := hub.GetItem(itemType, itemName) + if installedOnly && !item.State.Installed { + continue + } + + wantedItems = append(wantedItems, item) + } + + return wantedItems, nil +} + +func ListItems(out io.Writer, wantColor string, itemTypes []string, items map[string][]*cwhub.Item, omitIfEmpty bool, output string) error { + switch output { + case "human": + nothingToDisplay := true + + for _, itemType := range itemTypes { + if omitIfEmpty && len(items[itemType]) == 0 { + continue + } + + listHubItemTable(out, wantColor, "\n"+strings.ToUpper(itemType), items[itemType]) + + nothingToDisplay = false + } + + if nothingToDisplay { + fmt.Println("No items to display") + } + case "json": + type itemHubStatus struct { + Name string `json:"name"` + LocalVersion string `json:"local_version"` + LocalPath string `json:"local_path"` + Description string `json:"description"` + UTF8Status string `json:"utf8_status"` + Status string `json:"status"` + } + + hubStatus := make(map[string][]itemHubStatus) + for _, itemType := range itemTypes { + // empty slice in case there are no items of this type + hubStatus[itemType] = make([]itemHubStatus, len(items[itemType])) + + for i, item := range items[itemType] { + status := item.State.Text() + statusEmo := item.State.Emoji() + hubStatus[itemType][i] = itemHubStatus{ + Name: item.Name, + LocalVersion: item.State.LocalVersion, + LocalPath: item.State.LocalPath, + Description: item.Description, + Status: status, + UTF8Status: fmt.Sprintf("%v %s", statusEmo, status), + } + } + } + + x, err := json.MarshalIndent(hubStatus, "", " ") + if err != nil { + return fmt.Errorf("failed to parse: %w", err) + } + + out.Write(x) + case "raw": + csvwriter := csv.NewWriter(out) + + header := []string{"name", "status", "version", "description"} + if len(itemTypes) 
> 1 { + header = append(header, "type") + } + + if err := csvwriter.Write(header); err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + + for _, itemType := range itemTypes { + for _, item := range items[itemType] { + row := []string{ + item.Name, + item.State.Text(), + item.State.LocalVersion, + item.Description, + } + if len(itemTypes) > 1 { + row = append(row, itemType) + } + + if err := csvwriter.Write(row); err != nil { + return fmt.Errorf("failed to write raw output: %w", err) + } + } + } + + csvwriter.Flush() + } + + return nil +} + +func InspectItem(item *cwhub.Item, wantMetrics bool, output string, prometheusURL string, wantColor string) error { + switch output { + case "human", "raw": + enc := yaml.NewEncoder(os.Stdout) + enc.SetIndent(2) + + if err := enc.Encode(item); err != nil { + return fmt.Errorf("unable to encode item: %w", err) + } + case "json": + b, err := json.MarshalIndent(*item, "", " ") + if err != nil { + return fmt.Errorf("unable to serialize item: %w", err) + } + + fmt.Print(string(b)) + } + + if output != "human" { + return nil + } + + if item.State.Tainted { + fmt.Println() + fmt.Printf(`This item is tainted. Use "%s %s inspect --diff %s" to see why.`, filepath.Base(os.Args[0]), item.Type, item.Name) + fmt.Println() + } + + if wantMetrics { + fmt.Printf("\nCurrent metrics: \n") + + if err := showMetrics(prometheusURL, item, wantColor); err != nil { + return err + } + } + + return nil +} diff --git a/cmd/crowdsec-cli/clihub/utils_table.go b/cmd/crowdsec-cli/clihub/utils_table.go new file mode 100644 index 00000000000..98f14341b10 --- /dev/null +++ b/cmd/crowdsec-cli/clihub/utils_table.go @@ -0,0 +1,85 @@ +package clihub + +import ( + "fmt" + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/emoji" +) + +func listHubItemTable(out io.Writer, wantColor string, title string, items []*cwhub.Item) { + t := cstable.NewLight(out, wantColor).Writer + t.AppendHeader(table.Row{"Name", fmt.Sprintf("%v Status", emoji.Package), "Version", "Local Path"}) + + for _, item := range items { + status := fmt.Sprintf("%v %s", item.State.Emoji(), item.State.Text()) + t.AppendRow(table.Row{item.Name, status, item.State.LocalVersion, item.State.LocalPath}) + } + + io.WriteString(out, title+"\n") + io.WriteString(out, t.Render()+"\n") +} + +func appsecMetricsTable(out io.Writer, wantColor string, itemName string, metrics map[string]int) { + t := cstable.NewLight(out, wantColor).Writer + t.AppendHeader(table.Row{"Inband Hits", "Outband Hits"}) + + t.AppendRow(table.Row{ + strconv.Itoa(metrics["inband_hits"]), + strconv.Itoa(metrics["outband_hits"]), + }) + + io.WriteString(out, fmt.Sprintf("\n - (AppSec Rule) %s:\n", itemName)) + io.WriteString(out, t.Render()+"\n") +} + +func scenarioMetricsTable(out io.Writer, wantColor string, itemName string, metrics map[string]int) { + if metrics["instantiation"] == 0 { + return + } + + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Current Count", "Overflows", "Instantiated", "Poured", "Expired"}) + + t.AppendRow(table.Row{ + strconv.Itoa(metrics["curr_count"]), + strconv.Itoa(metrics["overflow"]), + strconv.Itoa(metrics["instantiation"]), + strconv.Itoa(metrics["pour"]), + strconv.Itoa(metrics["underflow"]), + }) + + io.WriteString(out, fmt.Sprintf("\n - (Scenario) %s:\n", itemName)) + io.WriteString(out, t.Render()+"\n") +} + +func 
parserMetricsTable(out io.Writer, wantColor string, itemName string, metrics map[string]map[string]int) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Parsers", "Hits", "Parsed", "Unparsed"}) + + // don't show table if no hits + showTable := false + + for source, stats := range metrics { + if stats["hits"] > 0 { + t.AppendRow(table.Row{ + source, + strconv.Itoa(stats["hits"]), + strconv.Itoa(stats["parsed"]), + strconv.Itoa(stats["unparsed"]), + }) + + showTable = true + } + } + + if showTable { + io.WriteString(out, fmt.Sprintf("\n - (Parser) %s:\n", itemName)) + io.WriteString(out, t.Render()+"\n") + } +} diff --git a/cmd/crowdsec-cli/clihubtest/clean.go b/cmd/crowdsec-cli/clihubtest/clean.go new file mode 100644 index 00000000000..e3b40b6bd57 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/clean.go @@ -0,0 +1,31 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func (cli *cliHubTest) newCleanCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "clean", + Short: "clean [test_name]", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + test, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("unable to load test '%s': %w", testName, err) + } + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + + return nil + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/coverage.go b/cmd/crowdsec-cli/clihubtest/coverage.go new file mode 100644 index 00000000000..5a4f231caf5 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/coverage.go @@ -0,0 +1,166 @@ +package clihubtest + +import ( + "encoding/json" + "errors" + "fmt" + "math" + + "github.com/fatih/color" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +// getCoverage returns the coverage and the percentage of tests that passed +func getCoverage(show bool, getCoverageFunc func() ([]hubtest.Coverage, error)) ([]hubtest.Coverage, int, error) { + if !show { + return nil, 0, nil + } + + coverage, err := getCoverageFunc() + if err != nil { + return nil, 0, fmt.Errorf("while getting coverage: %w", err) + } + + tested := 0 + + for _, test := range coverage { + if test.TestsCount > 0 { + tested++ + } + } + + // keep coverage 0 if there's no tests? 

+	percent := 0
+	if len(coverage) > 0 {
+		percent = int(math.Round((float64(tested) / float64(len(coverage)) * 100)))
+	}
+
+	return coverage, percent, nil
+}
+
+func (cli *cliHubTest) coverage(showScenarioCov bool, showParserCov bool, showAppsecCov bool, showOnlyPercent bool) error {
+	cfg := cli.cfg()
+
+	// for this one we explicitly don't do it for appsec
+	if err := HubTest.LoadAllTests(); err != nil {
+		return fmt.Errorf("unable to load all tests: %+v", err)
+	}
+
+	var err error
+
+	// if all are false (the flag defaults), show them all
+	if !showParserCov && !showScenarioCov && !showAppsecCov {
+		showParserCov = true
+		showScenarioCov = true
+		showAppsecCov = true
+	}
+
+	parserCoverage, parserCoveragePercent, err := getCoverage(showParserCov, HubTest.GetParsersCoverage)
+	if err != nil {
+		return err
+	}
+
+	scenarioCoverage, scenarioCoveragePercent, err := getCoverage(showScenarioCov, HubTest.GetScenariosCoverage)
+	if err != nil {
+		return err
+	}
+
+	appsecRuleCoverage, appsecRuleCoveragePercent, err := getCoverage(showAppsecCov, HubTest.GetAppsecCoverage)
+	if err != nil {
+		return err
+	}
+
+	if showOnlyPercent {
+		switch {
+		case showParserCov:
+			fmt.Printf("parsers=%d%%", parserCoveragePercent)
+		case showScenarioCov:
+			fmt.Printf("scenarios=%d%%", scenarioCoveragePercent)
+		case showAppsecCov:
+			fmt.Printf("appsec_rules=%d%%", appsecRuleCoveragePercent)
+		}
+
+		return nil
+	}
+
+	switch cfg.Cscli.Output {
+	case "human":
+		if showParserCov {
+			hubTestCoverageTable(color.Output, cfg.Cscli.Color, []string{"Parser", "Status", "Number of tests"}, parserCoverage)
+		}
+
+		if showScenarioCov {
+			hubTestCoverageTable(color.Output, cfg.Cscli.Color, []string{"Scenario", "Status", "Number of tests"}, scenarioCoverage)
+		}
+
+		if showAppsecCov {
+			hubTestCoverageTable(color.Output, cfg.Cscli.Color, []string{"Appsec Rule", "Status", "Number of tests"}, appsecRuleCoverage)
+		}
+
+		fmt.Println()
+
+		if showParserCov {
+			fmt.Printf("PARSERS : %d%% of coverage\n", parserCoveragePercent)
+		}
+
+		if showScenarioCov {
+			fmt.Printf("SCENARIOS : %d%% of coverage\n", scenarioCoveragePercent)
+		}
+
+		if showAppsecCov {
+			fmt.Printf("APPSEC RULES : %d%% of coverage\n", appsecRuleCoveragePercent)
+		}
+	case "json":
+		dump, err := json.MarshalIndent(parserCoverage, "", " ")
+		if err != nil {
+			return err
+		}
+
+		fmt.Printf("%s", dump)
+
+		dump, err = json.MarshalIndent(scenarioCoverage, "", " ")
+		if err != nil {
+			return err
+		}
+
+		fmt.Printf("%s", dump)
+
+		dump, err = json.MarshalIndent(appsecRuleCoverage, "", " ")
+		if err != nil {
+			return err
+		}
+
+		fmt.Printf("%s", dump)
+	default:
+		return errors.New("only human/json output modes are supported")
+	}
+
+	return nil
+}
+
+func (cli *cliHubTest) newCoverageCmd() *cobra.Command {
+	var (
+		showParserCov   bool
+		showScenarioCov bool
+		showOnlyPercent bool
+		showAppsecCov   bool
+	)
+
+	cmd := &cobra.Command{
+		Use:               "coverage",
+		Short:             "coverage",
+		DisableAutoGenTag: true,
+		RunE: func(_ *cobra.Command, _ []string) error {
+			return cli.coverage(showScenarioCov, showParserCov, showAppsecCov, showOnlyPercent)
+		},
+	}
+
+	cmd.PersistentFlags().BoolVar(&showOnlyPercent, "percent", false, "Show only percentages of coverage")
+	cmd.PersistentFlags().BoolVar(&showParserCov, "parsers", false, "Show only parsers coverage")
+	cmd.PersistentFlags().BoolVar(&showScenarioCov, "scenarios", false, "Show only scenarios coverage")
+	cmd.PersistentFlags().BoolVar(&showAppsecCov, "appsec", false, "Show only appsec coverage")
+
+	return cmd
+}
diff --git 
a/cmd/crowdsec-cli/clihubtest/create.go b/cmd/crowdsec-cli/clihubtest/create.go
new file mode 100644
index 00000000000..3822bed8903
--- /dev/null
+++ b/cmd/crowdsec-cli/clihubtest/create.go
@@ -0,0 +1,158 @@
+package clihubtest
+
+import (
+	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
+	"text/template"
+
+	"github.com/spf13/cobra"
+	"gopkg.in/yaml.v3"
+
+	"github.com/crowdsecurity/crowdsec/pkg/hubtest"
+)
+
+func (cli *cliHubTest) newCreateCmd() *cobra.Command {
+	var (
+		ignoreParsers bool
+		labels        map[string]string
+		logType       string
+	)
+
+	parsers := []string{}
+	postoverflows := []string{}
+	scenarios := []string{}
+
+	cmd := &cobra.Command{
+		Use:   "create",
+		Short: "create [test_name]",
+		Example: `cscli hubtest create my-awesome-test --type syslog
+cscli hubtest create my-nginx-custom-test --type nginx
+cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios crowdsecurity/http-probing`,
+		Args:              cobra.ExactArgs(1),
+		DisableAutoGenTag: true,
+		RunE: func(_ *cobra.Command, args []string) error {
+			testName := args[0]
+			testPath := filepath.Join(hubPtr.HubTestPath, testName)
+			if _, err := os.Stat(testPath); err == nil {
+				return fmt.Errorf("test '%s' already exists in '%s', exiting", testName, testPath)
+			}
+
+			if isAppsecTest {
+				logType = "appsec"
+			}
+
+			if logType == "" {
+				return errors.New("please provide a type (--type) for the test")
+			}
+
+			if err := os.MkdirAll(testPath, os.ModePerm); err != nil {
+				return fmt.Errorf("unable to create folder '%s': %+v", testPath, err)
+			}
+
+			configFilePath := filepath.Join(testPath, "config.yaml")
+
+			configFileData := &hubtest.HubTestItemConfig{}
+			if logType == "appsec" {
+				// create empty nuclei template file
+				nucleiFileName := testName + ".yaml"
+				nucleiFilePath := filepath.Join(testPath, nucleiFileName)
+
+				nucleiFile, err := os.OpenFile(nucleiFilePath, os.O_RDWR|os.O_CREATE, 0o755)
+				if err != nil {
+					return err
+				}
+
+				ntpl := template.Must(template.New("nuclei").Parse(hubtest.TemplateNucleiFile))
+				if ntpl == nil {
+					return errors.New("unable to parse nuclei template")
+				}
+				if err := ntpl.ExecuteTemplate(nucleiFile, "nuclei", struct{ TestName string }{TestName: testName}); err != nil {
+					return fmt.Errorf("unable to generate nuclei template: %w", err)
+				}
+				nucleiFile.Close()
+				configFileData.AppsecRules = []string{"./appsec-rules//your_rule_here.yaml"}
+				configFileData.NucleiTemplate = nucleiFileName
+				fmt.Println()
+				fmt.Printf(" Test name : %s\n", testName)
+				fmt.Printf(" Test path : %s\n", testPath)
+				fmt.Printf(" Config File : %s\n", configFilePath)
+				fmt.Printf(" Nuclei Template : %s\n", nucleiFilePath)
+			} else {
+				// create empty log file
+				logFileName := testName + ".log"
+				logFilePath := filepath.Join(testPath, logFileName)
+				logFile, err := os.Create(logFilePath)
+				if err != nil {
+					return err
+				}
+				logFile.Close()
+
+				// create empty parser assertion file
+				parserAssertFilePath := filepath.Join(testPath, hubtest.ParserAssertFileName)
+				parserAssertFile, err := os.Create(parserAssertFilePath)
+				if err != nil {
+					return err
+				}
+				parserAssertFile.Close()
+				// create empty scenario assertion file
+				scenarioAssertFilePath := filepath.Join(testPath, hubtest.ScenarioAssertFileName)
+				scenarioAssertFile, err := os.Create(scenarioAssertFilePath)
+				if err != nil {
+					return err
+				}
+				scenarioAssertFile.Close()
+
+				parsers = append(parsers, "crowdsecurity/syslog-logs")
+				parsers = append(parsers, "crowdsecurity/dateparse-enrich")
+
+				if len(scenarios) == 0 {
+					scenarios = append(scenarios, "")
+				}
+
+				if len(postoverflows) == 0 {
+					postoverflows = append(postoverflows, "")
+				}
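+				// record the chosen items and log settings; they are marshaled
+				// into the test's config.yaml below
+				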
configFileData.Parsers = parsers + configFileData.Scenarios = scenarios + configFileData.PostOverflows = postoverflows + configFileData.LogFile = logFileName + configFileData.LogType = logType + configFileData.IgnoreParsers = ignoreParsers + configFileData.Labels = labels + fmt.Println() + fmt.Printf(" Test name : %s\n", testName) + fmt.Printf(" Test path : %s\n", testPath) + fmt.Printf(" Log file : %s (please fill it with logs)\n", logFilePath) + fmt.Printf(" Parser assertion file : %s (please fill it with assertion)\n", parserAssertFilePath) + fmt.Printf(" Scenario assertion file : %s (please fill it with assertion)\n", scenarioAssertFilePath) + fmt.Printf(" Configuration File : %s (please fill it with parsers, scenarios...)\n", configFilePath) + } + + fd, err := os.Create(configFilePath) + if err != nil { + return fmt.Errorf("open: %w", err) + } + data, err := yaml.Marshal(configFileData) + if err != nil { + return fmt.Errorf("serialize: %w", err) + } + _, err = fd.Write(data) + if err != nil { + return fmt.Errorf("write: %w", err) + } + if err := fd.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + + return nil + }, + } + + cmd.PersistentFlags().StringVarP(&logType, "type", "t", "", "Log type of the test") + cmd.Flags().StringSliceVarP(&parsers, "parsers", "p", parsers, "Parsers to add to test") + cmd.Flags().StringSliceVar(&postoverflows, "postoverflows", postoverflows, "Postoverflows to add to test") + cmd.Flags().StringSliceVarP(&scenarios, "scenarios", "s", scenarios, "Scenarios to add to test") + cmd.PersistentFlags().BoolVar(&ignoreParsers, "ignore-parsers", false, "Don't run test on parsers") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/eval.go b/cmd/crowdsec-cli/clihubtest/eval.go new file mode 100644 index 00000000000..83e9eae9c15 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/eval.go @@ -0,0 +1,44 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func (cli *cliHubTest) newEvalCmd() *cobra.Command { + var evalExpression string + + cmd := &cobra.Command{ + Use: "eval", + Short: "eval [test_name]", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + test, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("can't load test: %+v", err) + } + + err = test.ParserAssert.LoadTest(test.ParserResultFile) + if err != nil { + return fmt.Errorf("can't load test results from '%s': %+v", test.ParserResultFile, err) + } + + output, err := test.ParserAssert.EvalExpression(evalExpression) + if err != nil { + return err + } + + fmt.Print(output) + } + + return nil + }, + } + + cmd.PersistentFlags().StringVarP(&evalExpression, "expr", "e", "", "Expression to eval") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/explain.go b/cmd/crowdsec-cli/clihubtest/explain.go new file mode 100644 index 00000000000..dbe10fa7ec0 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/explain.go @@ -0,0 +1,76 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/dumps" +) + +func (cli *cliHubTest) explain(testName string, details bool, skipOk bool) error { + test, err := HubTest.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("can't load test: %+v", err) + } + + err = test.ParserAssert.LoadTest(test.ParserResultFile) + if err != nil { + if err = test.Run(); err != nil { + return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) + } + + if err = 
test.ParserAssert.LoadTest(test.ParserResultFile); err != nil { + return fmt.Errorf("unable to load parser result after run: %w", err) + } + } + + err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile) + if err != nil { + if err = test.Run(); err != nil { + return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) + } + + if err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile); err != nil { + return fmt.Errorf("unable to load scenario result after run: %w", err) + } + } + + opts := dumps.DumpOpts{ + Details: details, + SkipOk: skipOk, + } + + dumps.DumpTree(*test.ParserAssert.TestData, *test.ScenarioAssert.PourData, opts) + + return nil +} + +func (cli *cliHubTest) newExplainCmd() *cobra.Command { + var ( + details bool + skipOk bool + ) + + cmd := &cobra.Command{ + Use: "explain", + Short: "explain [test_name]", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + if err := cli.explain(testName, details, skipOk); err != nil { + return err + } + } + + return nil + }, + } + + flags := cmd.Flags() + flags.BoolVarP(&details, "verbose", "v", false, "Display individual changes") + flags.BoolVar(&skipOk, "failures", false, "Only show failed lines") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/hubtest.go b/cmd/crowdsec-cli/clihubtest/hubtest.go new file mode 100644 index 00000000000..3420e21e1e2 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/hubtest.go @@ -0,0 +1,81 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +type configGetter func() *csconfig.Config + +var ( + HubTest hubtest.HubTest + HubAppsecTests hubtest.HubTest + hubPtr *hubtest.HubTest + isAppsecTest bool +) + +type cliHubTest struct { + cfg configGetter +} + +func New(cfg configGetter) *cliHubTest { + return &cliHubTest{ + cfg: cfg, + } +} + +func (cli *cliHubTest) NewCommand() *cobra.Command { + var ( + hubPath string + crowdsecPath string + cscliPath string + ) + + cmd := &cobra.Command{ + Use: "hubtest", + Short: "Run functional tests on hub configurations", + Long: "Run functional tests on hub configurations (parsers, scenarios, collections...)", + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + var err error + HubTest, err = hubtest.NewHubTest(hubPath, crowdsecPath, cscliPath, false) + if err != nil { + return fmt.Errorf("unable to load hubtest: %+v", err) + } + + HubAppsecTests, err = hubtest.NewHubTest(hubPath, crowdsecPath, cscliPath, true) + if err != nil { + return fmt.Errorf("unable to load appsec specific hubtest: %+v", err) + } + + // commands will use the hubPtr, will point to the default hubTest object, or the one dedicated to appsec tests + hubPtr = &HubTest + if isAppsecTest { + hubPtr = &HubAppsecTests + } + + return nil + }, + } + + cmd.PersistentFlags().StringVar(&hubPath, "hub", ".", "Path to hub folder") + cmd.PersistentFlags().StringVar(&crowdsecPath, "crowdsec", "crowdsec", "Path to crowdsec") + cmd.PersistentFlags().StringVar(&cscliPath, "cscli", "cscli", "Path to cscli") + cmd.PersistentFlags().BoolVar(&isAppsecTest, "appsec", false, "Command relates to appsec tests") + + cmd.AddCommand(cli.newCreateCmd()) + cmd.AddCommand(cli.newRunCmd()) + cmd.AddCommand(cli.newCleanCmd()) + cmd.AddCommand(cli.newInfoCmd()) + 
cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newCoverageCmd()) + cmd.AddCommand(cli.newEvalCmd()) + cmd.AddCommand(cli.newExplainCmd()) + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/info.go b/cmd/crowdsec-cli/clihubtest/info.go new file mode 100644 index 00000000000..a5d760eea01 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/info.go @@ -0,0 +1,44 @@ +package clihubtest + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +func (cli *cliHubTest) newInfoCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "info", + Short: "info [test_name]", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + test, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("unable to load test '%s': %w", testName, err) + } + fmt.Println() + fmt.Printf(" Test name : %s\n", test.Name) + fmt.Printf(" Test path : %s\n", test.Path) + if isAppsecTest { + fmt.Printf(" Nuclei Template : %s\n", test.Config.NucleiTemplate) + fmt.Printf(" Appsec Rules : %s\n", strings.Join(test.Config.AppsecRules, ", ")) + } else { + fmt.Printf(" Log file : %s\n", filepath.Join(test.Path, test.Config.LogFile)) + fmt.Printf(" Parser assertion file : %s\n", filepath.Join(test.Path, hubtest.ParserAssertFileName)) + fmt.Printf(" Scenario assertion file : %s\n", filepath.Join(test.Path, hubtest.ScenarioAssertFileName)) + } + fmt.Printf(" Configuration File : %s\n", filepath.Join(test.Path, "config.yaml")) + } + + return nil + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/list.go b/cmd/crowdsec-cli/clihubtest/list.go new file mode 100644 index 00000000000..3e76824a18e --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/list.go @@ -0,0 +1,42 @@ +package clihubtest + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +func (cli *cliHubTest) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "list", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + + if err := hubPtr.LoadAllTests(); err != nil { + return fmt.Errorf("unable to load all tests: %w", err) + } + + switch cfg.Cscli.Output { + case "human": + hubTestListTable(color.Output, cfg.Cscli.Color, hubPtr.Tests) + case "json": + j, err := json.MarshalIndent(hubPtr.Tests, " ", " ") + if err != nil { + return err + } + fmt.Println(string(j)) + default: + return errors.New("only human/json output modes are supported") + } + + return nil + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/run.go b/cmd/crowdsec-cli/clihubtest/run.go new file mode 100644 index 00000000000..31cceb81884 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/run.go @@ -0,0 +1,213 @@ +package clihubtest + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "strings" + + "github.com/AlecAivazis/survey/v2" + "github.com/fatih/color" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/emoji" + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +func (cli *cliHubTest) run(runAll bool, nucleiTargetHost string, appSecHost string, args []string) error { + cfg := cli.cfg() + + if !runAll && len(args) == 0 { + return errors.New("please provide test to run or --all flag") + } + + hubPtr.NucleiTargetHost = nucleiTargetHost + hubPtr.AppSecHost = appSecHost + + if runAll { + if err := 
hubPtr.LoadAllTests(); err != nil { + return fmt.Errorf("unable to load all tests: %+v", err) + } + } else { + for _, testName := range args { + _, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("unable to load test '%s': %w", testName, err) + } + } + } + + // set timezone to avoid DST issues + os.Setenv("TZ", "UTC") + + for _, test := range hubPtr.Tests { + if cfg.Cscli.Output == "human" { + log.Infof("Running test '%s'", test.Name) + } + + err := test.Run() + if err != nil { + log.Errorf("running test '%s' failed: %+v", test.Name, err) + } + } + + return nil +} + +func printParserFailures(test *hubtest.HubTestItem) { + if len(test.ParserAssert.Fails) == 0 { + return + } + + fmt.Println() + log.Errorf("Parser test '%s' failed (%d errors)\n", test.Name, len(test.ParserAssert.Fails)) + + for _, fail := range test.ParserAssert.Fails { + fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) + fmt.Printf(" Actual expression values:\n") + + for key, value := range fail.Debug { + fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) + } + + fmt.Println() + } +} + +func printScenarioFailures(test *hubtest.HubTestItem) { + if len(test.ScenarioAssert.Fails) == 0 { + return + } + + fmt.Println() + log.Errorf("Scenario test '%s' failed (%d errors)\n", test.Name, len(test.ScenarioAssert.Fails)) + + for _, fail := range test.ScenarioAssert.Fails { + fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) + fmt.Printf(" Actual expression values:\n") + + for key, value := range fail.Debug { + fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) + } + + fmt.Println() + } +} + +func (cli *cliHubTest) newRunCmd() *cobra.Command { + var ( + noClean bool + runAll bool + forceClean bool + nucleiTargetHost string + appSecHost string + ) + + cmd := &cobra.Command{ + Use: "run", + Short: "run [test_name]", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + return cli.run(runAll, nucleiTargetHost, appSecHost, args) + }, + PersistentPostRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + + success := true + testResult := make(map[string]bool) + for _, test := range hubPtr.Tests { + if test.AutoGen && !isAppsecTest { + if test.ParserAssert.AutoGenAssert { + log.Warningf("Assert file '%s' is empty, generating assertion:", test.ParserAssert.File) + fmt.Println() + fmt.Println(test.ParserAssert.AutoGenAssertData) + } + if test.ScenarioAssert.AutoGenAssert { + log.Warningf("Assert file '%s' is empty, generating assertion:", test.ScenarioAssert.File) + fmt.Println() + fmt.Println(test.ScenarioAssert.AutoGenAssertData) + } + if !noClean { + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + + return fmt.Errorf("please fill your assert file(s) for test '%s', exiting", test.Name) + } + testResult[test.Name] = test.Success + if test.Success { + if cfg.Cscli.Output == "human" { + log.Infof("Test '%s' passed successfully (%d assertions)\n", test.Name, test.ParserAssert.NbAssert+test.ScenarioAssert.NbAssert) + } + if !noClean { + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + } else { + success = false + cleanTestEnv := false + if cfg.Cscli.Output == "human" { + printParserFailures(test) + printScenarioFailures(test) + if !forceClean && !noClean { + prompt := &survey.Confirm{ + Message: fmt.Sprintf("\nDo you want to remove runtime folder for test 
'%s'? (default: Yes)", test.Name), + Default: true, + } + if err := survey.AskOne(prompt, &cleanTestEnv); err != nil { + return fmt.Errorf("unable to ask to remove runtime folder: %w", err) + } + } + } + + if cleanTestEnv || forceClean { + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + } + } + + switch cfg.Cscli.Output { + case "human": + hubTestResultTable(color.Output, cfg.Cscli.Color, testResult) + case "json": + jsonResult := make(map[string][]string, 0) + jsonResult["success"] = make([]string, 0) + jsonResult["fail"] = make([]string, 0) + for testName, success := range testResult { + if success { + jsonResult["success"] = append(jsonResult["success"], testName) + } else { + jsonResult["fail"] = append(jsonResult["fail"], testName) + } + } + jsonStr, err := json.Marshal(jsonResult) + if err != nil { + return fmt.Errorf("unable to json test result: %w", err) + } + fmt.Println(string(jsonStr)) + default: + return errors.New("only human/json output modes are supported") + } + + if !success { + return errors.New("some tests failed") + } + + return nil + }, + } + + cmd.Flags().BoolVar(&noClean, "no-clean", false, "Don't clean runtime environment if test succeed") + cmd.Flags().BoolVar(&forceClean, "clean", false, "Clean runtime environment if test fail") + cmd.Flags().StringVar(&nucleiTargetHost, "target", hubtest.DefaultNucleiTarget, "Target for AppSec Test") + cmd.Flags().StringVar(&appSecHost, "host", hubtest.DefaultAppsecHost, "Address to expose AppSec for hubtest") + cmd.Flags().BoolVar(&runAll, "all", false, "Run all tests") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/table.go b/cmd/crowdsec-cli/clihubtest/table.go new file mode 100644 index 00000000000..2a105a1f5c1 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/table.go @@ -0,0 +1,64 @@ +package clihubtest + +import ( + "fmt" + "io" + + "github.com/jedib0t/go-pretty/v6/text" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/emoji" + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +func hubTestResultTable(out io.Writer, wantColor string, testResult map[string]bool) { + t := cstable.NewLight(out, wantColor) + t.SetHeaders("Test", "Result") + t.SetHeaderAlignment(text.AlignLeft) + t.SetAlignment(text.AlignLeft) + + for testName, success := range testResult { + status := emoji.CheckMarkButton + if !success { + status = emoji.CrossMark + } + + t.AddRow(testName, status) + } + + t.Render() +} + +func hubTestListTable(out io.Writer, wantColor string, tests []*hubtest.HubTestItem) { + t := cstable.NewLight(out, wantColor) + t.SetHeaders("Name", "Path") + t.SetHeaderAlignment(text.AlignLeft, text.AlignLeft) + t.SetAlignment(text.AlignLeft, text.AlignLeft) + + for _, test := range tests { + t.AddRow(test.Name, test.Path) + } + + t.Render() +} + +func hubTestCoverageTable(out io.Writer, wantColor string, headers []string, coverage []hubtest.Coverage) { + t := cstable.NewLight(out, wantColor) + t.SetHeaders(headers...) 
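+	// callers always pass exactly three headers (item kind, status, number of
+	// tests), so alignment is declared below for three columns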
+ t.SetHeaderAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) + t.SetAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) + + parserTested := 0 + + for _, test := range coverage { + status := emoji.RedCircle + if test.TestsCount > 0 { + status = emoji.GreenCircle + parserTested++ + } + + t.AddRow(test.Name, status, fmt.Sprintf("%d times (across %d tests)", test.TestsCount, len(test.PresentIn))) + } + + t.Render() +} diff --git a/cmd/crowdsec-cli/cliitem/appsec.go b/cmd/crowdsec-cli/cliitem/appsec.go new file mode 100644 index 00000000000..44afa2133bd --- /dev/null +++ b/cmd/crowdsec-cli/cliitem/appsec.go @@ -0,0 +1,123 @@ +package cliitem + +import ( + "fmt" + "os" + + "golang.org/x/text/cases" + "golang.org/x/text/language" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func NewAppsecConfig(cfg configGetter) *cliItem { + return &cliItem{ + cfg: cfg, + name: cwhub.APPSEC_CONFIGS, + singular: "appsec-config", + oneOrMore: "appsec-config(s)", + help: cliHelp{ + example: `cscli appsec-configs list -a +cscli appsec-configs install crowdsecurity/vpatch +cscli appsec-configs inspect crowdsecurity/vpatch +cscli appsec-configs upgrade crowdsecurity/vpatch +cscli appsec-configs remove crowdsecurity/vpatch +`, + }, + installHelp: cliHelp{ + example: `cscli appsec-configs install crowdsecurity/vpatch`, + }, + removeHelp: cliHelp{ + example: `cscli appsec-configs remove crowdsecurity/vpatch`, + }, + upgradeHelp: cliHelp{ + example: `cscli appsec-configs upgrade crowdsecurity/vpatch`, + }, + inspectHelp: cliHelp{ + example: `cscli appsec-configs inspect crowdsecurity/vpatch`, + }, + listHelp: cliHelp{ + example: `cscli appsec-configs list +cscli appsec-configs list -a +cscli appsec-configs list crowdsecurity/vpatch`, + }, + } +} + +func NewAppsecRule(cfg configGetter) *cliItem { + inspectDetail := func(item *cwhub.Item) error { + // Only show the converted rules in human mode + if cfg().Cscli.Output != "human" { + return nil + } + + appsecRule := appsec.AppsecCollectionConfig{} + + yamlContent, err := os.ReadFile(item.State.LocalPath) + if err != nil { + return fmt.Errorf("unable to read file %s: %w", item.State.LocalPath, err) + } + + if err := yaml.Unmarshal(yamlContent, &appsecRule); err != nil { + return fmt.Errorf("unable to parse yaml file %s: %w", item.State.LocalPath, err) + } + + for _, ruleType := range appsec_rule.SupportedTypes() { + fmt.Printf("\n%s format:\n", cases.Title(language.Und, cases.NoLower).String(ruleType)) + + for _, rule := range appsecRule.Rules { + convertedRule, _, err := rule.Convert(ruleType, appsecRule.Name) + if err != nil { + return fmt.Errorf("unable to convert rule %s: %w", rule.Name, err) + } + + fmt.Println(convertedRule) + } + + switch ruleType { //nolint:gocritic + case appsec_rule.ModsecurityRuleType: + for _, rule := range appsecRule.SecLangRules { + fmt.Println(rule) + } + } + } + + return nil + } + + return &cliItem{ + cfg: cfg, + name: "appsec-rules", + singular: "appsec-rule", + oneOrMore: "appsec-rule(s)", + help: cliHelp{ + example: `cscli appsec-rules list -a +cscli appsec-rules install crowdsecurity/crs +cscli appsec-rules inspect crowdsecurity/crs +cscli appsec-rules upgrade crowdsecurity/crs +cscli appsec-rules remove crowdsecurity/crs +`, + }, + installHelp: cliHelp{ + example: `cscli appsec-rules install crowdsecurity/crs`, + }, + removeHelp: cliHelp{ + example: `cscli appsec-rules remove 
crowdsecurity/crs`, + }, + upgradeHelp: cliHelp{ + example: `cscli appsec-rules upgrade crowdsecurity/crs`, + }, + inspectHelp: cliHelp{ + example: `cscli appsec-rules inspect crowdsecurity/crs`, + }, + inspectDetail: inspectDetail, + listHelp: cliHelp{ + example: `cscli appsec-rules list +cscli appsec-rules list -a +cscli appsec-rules list crowdsecurity/crs`, + }, + } +} diff --git a/cmd/crowdsec-cli/cliitem/collection.go b/cmd/crowdsec-cli/cliitem/collection.go new file mode 100644 index 00000000000..ea91c1e537a --- /dev/null +++ b/cmd/crowdsec-cli/cliitem/collection.go @@ -0,0 +1,41 @@ +package cliitem + +import ( + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func NewCollection(cfg configGetter) *cliItem { + return &cliItem{ + cfg: cfg, + name: cwhub.COLLECTIONS, + singular: "collection", + oneOrMore: "collection(s)", + help: cliHelp{ + example: `cscli collections list -a +cscli collections install crowdsecurity/http-cve crowdsecurity/iptables +cscli collections inspect crowdsecurity/http-cve crowdsecurity/iptables +cscli collections upgrade crowdsecurity/http-cve crowdsecurity/iptables +cscli collections remove crowdsecurity/http-cve crowdsecurity/iptables +`, + }, + installHelp: cliHelp{ + example: `cscli collections install crowdsecurity/http-cve crowdsecurity/iptables`, + }, + removeHelp: cliHelp{ + example: `cscli collections remove crowdsecurity/http-cve crowdsecurity/iptables`, + }, + upgradeHelp: cliHelp{ + example: `cscli collections upgrade crowdsecurity/http-cve crowdsecurity/iptables`, + }, + inspectHelp: cliHelp{ + example: `cscli collections inspect crowdsecurity/http-cve crowdsecurity/iptables`, + }, + listHelp: cliHelp{ + example: `cscli collections list +cscli collections list -a +cscli collections list crowdsecurity/http-cve crowdsecurity/iptables + +List only enabled collections unless "-a" or names are specified.`, + }, + } +} diff --git a/cmd/crowdsec-cli/cliitem/context.go b/cmd/crowdsec-cli/cliitem/context.go new file mode 100644 index 00000000000..7d110b8203d --- /dev/null +++ b/cmd/crowdsec-cli/cliitem/context.go @@ -0,0 +1,41 @@ +package cliitem + +import ( + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func NewContext(cfg configGetter) *cliItem { + return &cliItem{ + cfg: cfg, + name: cwhub.CONTEXTS, + singular: "context", + oneOrMore: "context(s)", + help: cliHelp{ + example: `cscli contexts list -a +cscli contexts install crowdsecurity/yyy crowdsecurity/zzz +cscli contexts inspect crowdsecurity/yyy crowdsecurity/zzz +cscli contexts upgrade crowdsecurity/yyy crowdsecurity/zzz +cscli contexts remove crowdsecurity/yyy crowdsecurity/zzz +`, + }, + installHelp: cliHelp{ + example: `cscli contexts install crowdsecurity/yyy crowdsecurity/zzz`, + }, + removeHelp: cliHelp{ + example: `cscli contexts remove crowdsecurity/yyy crowdsecurity/zzz`, + }, + upgradeHelp: cliHelp{ + example: `cscli contexts upgrade crowdsecurity/yyy crowdsecurity/zzz`, + }, + inspectHelp: cliHelp{ + example: `cscli contexts inspect crowdsecurity/yyy crowdsecurity/zzz`, + }, + listHelp: cliHelp{ + example: `cscli contexts list +cscli contexts list -a +cscli contexts list crowdsecurity/yyy crowdsecurity/zzz + +List only enabled contexts unless "-a" or names are specified.`, + }, + } +} diff --git a/cmd/crowdsec-cli/cliitem/hubscenario.go b/cmd/crowdsec-cli/cliitem/hubscenario.go new file mode 100644 index 00000000000..a5e854b3c82 --- /dev/null +++ b/cmd/crowdsec-cli/cliitem/hubscenario.go @@ -0,0 +1,41 @@ +package cliitem + +import ( + 
"github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func NewScenario(cfg configGetter) *cliItem { + return &cliItem{ + cfg: cfg, + name: cwhub.SCENARIOS, + singular: "scenario", + oneOrMore: "scenario(s)", + help: cliHelp{ + example: `cscli scenarios list -a +cscli scenarios install crowdsecurity/ssh-bf crowdsecurity/http-probing +cscli scenarios inspect crowdsecurity/ssh-bf crowdsecurity/http-probing +cscli scenarios upgrade crowdsecurity/ssh-bf crowdsecurity/http-probing +cscli scenarios remove crowdsecurity/ssh-bf crowdsecurity/http-probing +`, + }, + installHelp: cliHelp{ + example: `cscli scenarios install crowdsecurity/ssh-bf crowdsecurity/http-probing`, + }, + removeHelp: cliHelp{ + example: `cscli scenarios remove crowdsecurity/ssh-bf crowdsecurity/http-probing`, + }, + upgradeHelp: cliHelp{ + example: `cscli scenarios upgrade crowdsecurity/ssh-bf crowdsecurity/http-probing`, + }, + inspectHelp: cliHelp{ + example: `cscli scenarios inspect crowdsecurity/ssh-bf crowdsecurity/http-probing`, + }, + listHelp: cliHelp{ + example: `cscli scenarios list +cscli scenarios list -a +cscli scenarios list crowdsecurity/ssh-bf crowdsecurity/http-probing + +List only enabled scenarios unless "-a" or names are specified.`, + }, + } +} diff --git a/cmd/crowdsec-cli/cliitem/item.go b/cmd/crowdsec-cli/cliitem/item.go new file mode 100644 index 00000000000..28828eb9c95 --- /dev/null +++ b/cmd/crowdsec-cli/cliitem/item.go @@ -0,0 +1,550 @@ +package cliitem + +import ( + "cmp" + "context" + "errors" + "fmt" + "os" + "strings" + + "github.com/fatih/color" + "github.com/hexops/gotextdiff" + "github.com/hexops/gotextdiff/myers" + "github.com/hexops/gotextdiff/span" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihub" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +type cliHelp struct { + // Example is required, the others have a default value + // generated from the item type + use string + short string + long string + example string +} + +type configGetter func() *csconfig.Config + +type cliItem struct { + cfg configGetter + name string // plural, as used in the hub index + singular string + oneOrMore string // parenthetical pluralizaion: "parser(s)" + help cliHelp + installHelp cliHelp + removeHelp cliHelp + upgradeHelp cliHelp + inspectHelp cliHelp + inspectDetail func(item *cwhub.Item) error + listHelp cliHelp +} + +func (cli cliItem) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: cmp.Or(cli.help.use, cli.name+" [item]..."), + Short: cmp.Or(cli.help.short, "Manage hub "+cli.name), + Long: cli.help.long, + Example: cli.help.example, + Args: cobra.MinimumNArgs(1), + Aliases: []string{cli.singular}, + DisableAutoGenTag: true, + } + + cmd.AddCommand(cli.newInstallCmd()) + cmd.AddCommand(cli.newRemoveCmd()) + cmd.AddCommand(cli.newUpgradeCmd()) + cmd.AddCommand(cli.newInspectCmd()) + cmd.AddCommand(cli.newListCmd()) + + return cmd +} + +func (cli cliItem) install(ctx context.Context, args []string, downloadOnly bool, force bool, ignoreError bool) error { + cfg := cli.cfg() + + hub, err := require.Hub(cfg, require.RemoteHub(ctx, cfg), log.StandardLogger()) + if err != nil { + return err + } + + for _, name := range args { + item := hub.GetItem(cli.name, name) + if item == nil { + msg := suggestNearestMessage(hub, cli.name, name) + if 
!ignoreError { + return errors.New(msg) + } + + log.Error(msg) + + continue + } + + if err := item.Install(ctx, force, downloadOnly); err != nil { + if !ignoreError { + return fmt.Errorf("error while installing '%s': %w", item.Name, err) + } + + log.Errorf("Error while installing '%s': %s", item.Name, err) + } + } + + log.Info(reload.Message) + + return nil +} + +func (cli cliItem) newInstallCmd() *cobra.Command { + var ( + downloadOnly bool + force bool + ignoreError bool + ) + + cmd := &cobra.Command{ + Use: cmp.Or(cli.installHelp.use, "install [item]..."), + Short: cmp.Or(cli.installHelp.short, "Install given "+cli.oneOrMore), + Long: cmp.Or(cli.installHelp.long, fmt.Sprintf("Fetch and install one or more %s from the hub", cli.name)), + Example: cli.installHelp.example, + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + ValidArgsFunction: func(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return compAllItems(cli.name, args, toComplete, cli.cfg) + }, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.install(cmd.Context(), args, downloadOnly, force, ignoreError) + }, + } + + flags := cmd.Flags() + flags.BoolVarP(&downloadOnly, "download-only", "d", false, "Only download packages, don't enable") + flags.BoolVar(&force, "force", false, "Force install: overwrite tainted and outdated files") + flags.BoolVar(&ignoreError, "ignore", false, "Ignore errors when installing multiple "+cli.name) + + return cmd +} + +// return the names of the installed parents of an item, used to check if we can remove it +func istalledParentNames(item *cwhub.Item) []string { + ret := make([]string, 0) + + for _, parent := range item.Ancestors() { + if parent.State.Installed { + ret = append(ret, parent.Name) + } + } + + return ret +} + +func (cli cliItem) remove(args []string, purge bool, force bool, all bool) error { + hub, err := require.Hub(cli.cfg(), nil, log.StandardLogger()) + if err != nil { + return err + } + + if all { + itemGetter := hub.GetInstalledByType + if purge { + itemGetter = hub.GetItemsByType + } + + removed := 0 + + for _, item := range itemGetter(cli.name, true) { + didRemove, err := item.Remove(purge, force) + if err != nil { + return err + } + + if didRemove { + log.Infof("Removed %s", item.Name) + + removed++ + } + } + + log.Infof("Removed %d %s", removed, cli.name) + + if removed > 0 { + log.Info(reload.Message) + } + + return nil + } + + if len(args) == 0 { + return fmt.Errorf("specify at least one %s to remove or '--all'", cli.singular) + } + + removed := 0 + + for _, itemName := range args { + item := hub.GetItem(cli.name, itemName) + if item == nil { + return fmt.Errorf("can't find '%s' in %s", itemName, cli.name) + } + + parents := istalledParentNames(item) + + if !force && len(parents) > 0 { + log.Warningf("%s belongs to collections: %s", item.Name, parents) + log.Warningf("Run 'sudo cscli %s remove %s --force' if you want to force remove this %s", item.Type, item.Name, cli.singular) + + continue + } + + didRemove, err := item.Remove(purge, force) + if err != nil { + return err + } + + if didRemove { + log.Infof("Removed %s", item.Name) + + removed++ + } + } + + log.Infof("Removed %d %s", removed, cli.name) + + if removed > 0 { + log.Info(reload.Message) + } + + return nil +} + +func (cli cliItem) newRemoveCmd() *cobra.Command { + var ( + purge bool + force bool + all bool + ) + + cmd := &cobra.Command{ + Use: cmp.Or(cli.removeHelp.use, "remove [item]..."), + Short: cmp.Or(cli.removeHelp.short, "Remove given 
"+cli.oneOrMore), + Long: cmp.Or(cli.removeHelp.long, "Remove one or more "+cli.name), + Example: cli.removeHelp.example, + Aliases: []string{"delete"}, + DisableAutoGenTag: true, + ValidArgsFunction: func(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return compInstalledItems(cli.name, args, toComplete, cli.cfg) + }, + RunE: func(_ *cobra.Command, args []string) error { + return cli.remove(args, purge, force, all) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&purge, "purge", false, "Delete source file too") + flags.BoolVar(&force, "force", false, "Force remove: remove tainted and outdated files") + flags.BoolVar(&all, "all", false, "Remove all the "+cli.name) + + return cmd +} + +func (cli cliItem) upgrade(ctx context.Context, args []string, force bool, all bool) error { + cfg := cli.cfg() + + hub, err := require.Hub(cfg, require.RemoteHub(ctx, cfg), log.StandardLogger()) + if err != nil { + return err + } + + if all { + updated := 0 + + for _, item := range hub.GetInstalledByType(cli.name, true) { + didUpdate, err := item.Upgrade(ctx, force) + if err != nil { + return err + } + + if didUpdate { + updated++ + } + } + + log.Infof("Updated %d %s", updated, cli.name) + + if updated > 0 { + log.Info(reload.Message) + } + + return nil + } + + if len(args) == 0 { + return fmt.Errorf("specify at least one %s to upgrade or '--all'", cli.singular) + } + + updated := 0 + + for _, itemName := range args { + item := hub.GetItem(cli.name, itemName) + if item == nil { + return fmt.Errorf("can't find '%s' in %s", itemName, cli.name) + } + + didUpdate, err := item.Upgrade(ctx, force) + if err != nil { + return err + } + + if didUpdate { + log.Infof("Updated %s", item.Name) + + updated++ + } + } + + if updated > 0 { + log.Info(reload.Message) + } + + return nil +} + +func (cli cliItem) newUpgradeCmd() *cobra.Command { + var ( + all bool + force bool + ) + + cmd := &cobra.Command{ + Use: cmp.Or(cli.upgradeHelp.use, "upgrade [item]..."), + Short: cmp.Or(cli.upgradeHelp.short, "Upgrade given "+cli.oneOrMore), + Long: cmp.Or(cli.upgradeHelp.long, fmt.Sprintf("Fetch and upgrade one or more %s from the hub", cli.name)), + Example: cli.upgradeHelp.example, + DisableAutoGenTag: true, + ValidArgsFunction: func(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return compInstalledItems(cli.name, args, toComplete, cli.cfg) + }, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.upgrade(cmd.Context(), args, force, all) + }, + } + + flags := cmd.Flags() + flags.BoolVarP(&all, "all", "a", false, "Upgrade all the "+cli.name) + flags.BoolVar(&force, "force", false, "Force upgrade: overwrite tainted and outdated files") + + return cmd +} + +func (cli cliItem) inspect(ctx context.Context, args []string, url string, diff bool, rev bool, noMetrics bool) error { + cfg := cli.cfg() + + if rev && !diff { + return errors.New("--rev can only be used with --diff") + } + + if url != "" { + cfg.Cscli.PrometheusUrl = url + } + + remote := (*cwhub.RemoteHubCfg)(nil) + + if diff { + remote = require.RemoteHub(ctx, cfg) + } + + hub, err := require.Hub(cfg, remote, log.StandardLogger()) + if err != nil { + return err + } + + for _, name := range args { + item := hub.GetItem(cli.name, name) + if item == nil { + return fmt.Errorf("can't find '%s' in %s", name, cli.name) + } + + if diff { + fmt.Println(cli.whyTainted(ctx, hub, item, rev)) + + continue + } + + if err = clihub.InspectItem(item, !noMetrics, cfg.Cscli.Output, 
cfg.Cscli.PrometheusUrl, cfg.Cscli.Color); err != nil { + return err + } + + if cli.inspectDetail != nil { + if err = cli.inspectDetail(item); err != nil { + return err + } + } + } + + return nil +} + +func (cli cliItem) newInspectCmd() *cobra.Command { + var ( + url string + diff bool + rev bool + noMetrics bool + ) + + cmd := &cobra.Command{ + Use: cmp.Or(cli.inspectHelp.use, "inspect [item]..."), + Short: cmp.Or(cli.inspectHelp.short, "Inspect given "+cli.oneOrMore), + Long: cmp.Or(cli.inspectHelp.long, "Inspect the state of one or more "+cli.name), + Example: cli.inspectHelp.example, + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + ValidArgsFunction: func(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return compInstalledItems(cli.name, args, toComplete, cli.cfg) + }, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.inspect(cmd.Context(), args, url, diff, rev, noMetrics) + }, + } + + flags := cmd.Flags() + flags.StringVarP(&url, "url", "u", "", "Prometheus url") + flags.BoolVar(&diff, "diff", false, "Show diff with latest version (for tainted items)") + flags.BoolVar(&rev, "rev", false, "Reverse diff output") + flags.BoolVar(&noMetrics, "no-metrics", false, "Don't show metrics (when cscli.output=human)") + + return cmd +} + +func (cli cliItem) list(args []string, all bool) error { + cfg := cli.cfg() + + hub, err := require.Hub(cli.cfg(), nil, log.StandardLogger()) + if err != nil { + return err + } + + items := make(map[string][]*cwhub.Item) + + items[cli.name], err = clihub.SelectItems(hub, cli.name, args, !all) + if err != nil { + return err + } + + return clihub.ListItems(color.Output, cfg.Cscli.Color, []string{cli.name}, items, false, cfg.Cscli.Output) +} + +func (cli cliItem) newListCmd() *cobra.Command { + var all bool + + cmd := &cobra.Command{ + Use: cmp.Or(cli.listHelp.use, "list [item... 
| -a]"), + Short: cmp.Or(cli.listHelp.short, "List "+cli.oneOrMore), + Long: cmp.Or(cli.listHelp.long, "List of installed/available/specified "+cli.name), + Example: cli.listHelp.example, + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + return cli.list(args, all) + }, + } + + flags := cmd.Flags() + flags.BoolVarP(&all, "all", "a", false, "List disabled items as well") + + return cmd +} + +// return the diff between the installed version and the latest version +func (cli cliItem) itemDiff(ctx context.Context, item *cwhub.Item, reverse bool) (string, error) { + if !item.State.Installed { + return "", fmt.Errorf("'%s' is not installed", item.FQName()) + } + + dest, err := os.CreateTemp("", "cscli-diff-*") + if err != nil { + return "", fmt.Errorf("while creating temporary file: %w", err) + } + defer os.Remove(dest.Name()) + + _, remoteURL, err := item.FetchContentTo(ctx, dest.Name()) + if err != nil { + return "", err + } + + latestContent, err := os.ReadFile(dest.Name()) + if err != nil { + return "", fmt.Errorf("while reading %s: %w", dest.Name(), err) + } + + localContent, err := os.ReadFile(item.State.LocalPath) + if err != nil { + return "", fmt.Errorf("while reading %s: %w", item.State.LocalPath, err) + } + + file1 := item.State.LocalPath + file2 := remoteURL + content1 := string(localContent) + content2 := string(latestContent) + + if reverse { + file1, file2 = file2, file1 + content1, content2 = content2, content1 + } + + edits := myers.ComputeEdits(span.URIFromPath(file1), content1, content2) + diff := gotextdiff.ToUnified(file1, file2, content1, edits) + + return fmt.Sprintf("%s", diff), nil +} + +func (cli cliItem) whyTainted(ctx context.Context, hub *cwhub.Hub, item *cwhub.Item, reverse bool) string { + if !item.State.Installed { + return fmt.Sprintf("# %s is not installed", item.FQName()) + } + + if !item.State.Tainted { + return fmt.Sprintf("# %s is not tainted", item.FQName()) + } + + if len(item.State.TaintedBy) == 0 { + return fmt.Sprintf("# %s is tainted but we don't know why. 
please report this as a bug", item.FQName()) + } + + ret := []string{ + fmt.Sprintf("# Let's see why %s is tainted.", item.FQName()), + } + + for _, fqsub := range item.State.TaintedBy { + ret = append(ret, fmt.Sprintf("\n-> %s\n", fqsub)) + + sub, err := hub.GetItemFQ(fqsub) + if err != nil { + ret = append(ret, err.Error()) + } + + diff, err := cli.itemDiff(ctx, sub, reverse) + if err != nil { + ret = append(ret, err.Error()) + } + + if diff != "" { + ret = append(ret, diff) + } else if len(sub.State.TaintedBy) > 0 { + taintList := strings.Join(sub.State.TaintedBy, ", ") + if sub.FQName() == taintList { + // hack: avoid message "item is tainted by itself" + continue + } + + ret = append(ret, fmt.Sprintf("# %s is tainted by %s", sub.FQName(), taintList)) + } + } + + return strings.Join(ret, "\n") +} diff --git a/cmd/crowdsec-cli/cliitem/parser.go b/cmd/crowdsec-cli/cliitem/parser.go new file mode 100644 index 00000000000..bc1d96bdaf0 --- /dev/null +++ b/cmd/crowdsec-cli/cliitem/parser.go @@ -0,0 +1,41 @@ +package cliitem + +import ( + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func NewParser(cfg configGetter) *cliItem { + return &cliItem{ + cfg: cfg, + name: cwhub.PARSERS, + singular: "parser", + oneOrMore: "parser(s)", + help: cliHelp{ + example: `cscli parsers list -a +cscli parsers install crowdsecurity/caddy-logs crowdsecurity/sshd-logs +cscli parsers inspect crowdsecurity/caddy-logs crowdsecurity/sshd-logs +cscli parsers upgrade crowdsecurity/caddy-logs crowdsecurity/sshd-logs +cscli parsers remove crowdsecurity/caddy-logs crowdsecurity/sshd-logs +`, + }, + installHelp: cliHelp{ + example: `cscli parsers install crowdsecurity/caddy-logs crowdsecurity/sshd-logs`, + }, + removeHelp: cliHelp{ + example: `cscli parsers remove crowdsecurity/caddy-logs crowdsecurity/sshd-logs`, + }, + upgradeHelp: cliHelp{ + example: `cscli parsers upgrade crowdsecurity/caddy-logs crowdsecurity/sshd-logs`, + }, + inspectHelp: cliHelp{ + example: `cscli parsers inspect crowdsecurity/httpd-logs crowdsecurity/sshd-logs`, + }, + listHelp: cliHelp{ + example: `cscli parsers list +cscli parsers list -a +cscli parsers list crowdsecurity/caddy-logs crowdsecurity/sshd-logs + +List only enabled parsers unless "-a" or names are specified.`, + }, + } +} diff --git a/cmd/crowdsec-cli/cliitem/postoverflow.go b/cmd/crowdsec-cli/cliitem/postoverflow.go new file mode 100644 index 00000000000..ea53aef327d --- /dev/null +++ b/cmd/crowdsec-cli/cliitem/postoverflow.go @@ -0,0 +1,41 @@ +package cliitem + +import ( + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func NewPostOverflow(cfg configGetter) *cliItem { + return &cliItem{ + cfg: cfg, + name: cwhub.POSTOVERFLOWS, + singular: "postoverflow", + oneOrMore: "postoverflow(s)", + help: cliHelp{ + example: `cscli postoverflows list -a +cscli postoverflows install crowdsecurity/cdn-whitelist crowdsecurity/rdns +cscli postoverflows inspect crowdsecurity/cdn-whitelist crowdsecurity/rdns +cscli postoverflows upgrade crowdsecurity/cdn-whitelist crowdsecurity/rdns +cscli postoverflows remove crowdsecurity/cdn-whitelist crowdsecurity/rdns +`, + }, + installHelp: cliHelp{ + example: `cscli postoverflows install crowdsecurity/cdn-whitelist crowdsecurity/rdns`, + }, + removeHelp: cliHelp{ + example: `cscli postoverflows remove crowdsecurity/cdn-whitelist crowdsecurity/rdns`, + }, + upgradeHelp: cliHelp{ + example: `cscli postoverflows upgrade crowdsecurity/cdn-whitelist crowdsecurity/rdns`, + }, + inspectHelp: cliHelp{ + example: `cscli postoverflows inspect 
crowdsecurity/cdn-whitelist crowdsecurity/rdns`, + }, + listHelp: cliHelp{ + example: `cscli postoverflows list +cscli postoverflows list -a +cscli postoverflows list crowdsecurity/cdn-whitelist crowdsecurity/rdns + +List only enabled postoverflows unless "-a" or names are specified.`, + }, + } +} diff --git a/cmd/crowdsec-cli/cliitem/suggest.go b/cmd/crowdsec-cli/cliitem/suggest.go new file mode 100644 index 00000000000..5b080722af9 --- /dev/null +++ b/cmd/crowdsec-cli/cliitem/suggest.go @@ -0,0 +1,77 @@ +package cliitem + +import ( + "fmt" + "slices" + "strings" + + "github.com/agext/levenshtein" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +// suggestNearestMessage returns a message with the most similar item name, if one is found +func suggestNearestMessage(hub *cwhub.Hub, itemType string, itemName string) string { + const maxDistance = 7 + + score := 100 + nearest := "" + + for _, item := range hub.GetItemsByType(itemType, false) { + d := levenshtein.Distance(itemName, item.Name, nil) + if d < score { + score = d + nearest = item.Name + } + } + + msg := fmt.Sprintf("can't find '%s' in %s", itemName, itemType) + + if score < maxDistance { + msg += fmt.Sprintf(", did you mean '%s'?", nearest) + } + + return msg +} + +func compAllItems(itemType string, args []string, toComplete string, cfg configGetter) ([]string, cobra.ShellCompDirective) { + hub, err := require.Hub(cfg(), nil, nil) + if err != nil { + return nil, cobra.ShellCompDirectiveDefault + } + + comp := make([]string, 0) + + for _, item := range hub.GetItemsByType(itemType, false) { + if !slices.Contains(args, item.Name) && strings.Contains(item.Name, toComplete) { + comp = append(comp, item.Name) + } + } + + cobra.CompDebugln(fmt.Sprintf("%s: %+v", itemType, comp), true) + + return comp, cobra.ShellCompDirectiveNoFileComp +} + +func compInstalledItems(itemType string, args []string, toComplete string, cfg configGetter) ([]string, cobra.ShellCompDirective) { + hub, err := require.Hub(cfg(), nil, nil) + if err != nil { + return nil, cobra.ShellCompDirectiveDefault + } + + items := hub.GetInstalledByType(itemType, true) + + comp := make([]string, 0) + + for _, item := range items { + if strings.Contains(item.Name, toComplete) { + comp = append(comp, item.Name) + } + } + + cobra.CompDebugln(fmt.Sprintf("%s: %+v", itemType, comp), true) + + return comp, cobra.ShellCompDirectiveNoFileComp +} diff --git a/cmd/crowdsec-cli/clilapi/lapi.go b/cmd/crowdsec-cli/clilapi/lapi.go new file mode 100644 index 00000000000..bb721eefe03 --- /dev/null +++ b/cmd/crowdsec-cli/clilapi/lapi.go @@ -0,0 +1,636 @@ +package clilapi + +import ( + "context" + "errors" + "fmt" + "io" + "net/url" + "os" + "slices" + "sort" + "strings" + + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/alertcontext" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/parser" +) + +const LAPIURLPrefix = "v1" + +type 
configGetter = func() *csconfig.Config + +type cliLapi struct { + cfg configGetter +} + +func New(cfg configGetter) *cliLapi { + return &cliLapi{ + cfg: cfg, + } +} + +// queryLAPIStatus checks if the Local API is reachable, and if the credentials are correct. +func queryLAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login string, password string) (bool, error) { + apiURL, err := url.Parse(credURL) + if err != nil { + return false, err + } + + client, err := apiclient.NewDefaultClient(apiURL, + LAPIURLPrefix, + "", + nil) + if err != nil { + return false, err + } + + pw := strfmt.Password(password) + + itemsForAPI := hub.GetInstalledListForAPI() + + t := models.WatcherAuthRequest{ + MachineID: &login, + Password: &pw, + Scenarios: itemsForAPI, + } + + _, _, err = client.Auth.AuthenticateWatcher(ctx, t) + if err != nil { + return false, err + } + + return true, nil +} + +func (cli *cliLapi) Status(ctx context.Context, out io.Writer, hub *cwhub.Hub) error { + cfg := cli.cfg() + + cred := cfg.API.Client.Credentials + + fmt.Fprintf(out, "Loaded credentials from %s\n", cfg.API.Client.CredentialsFilePath) + fmt.Fprintf(out, "Trying to authenticate with username %s on %s\n", cred.Login, cred.URL) + + _, err := queryLAPIStatus(ctx, hub, cred.URL, cred.Login, cred.Password) + if err != nil { + return fmt.Errorf("failed to authenticate to Local API (LAPI): %w", err) + } + + fmt.Fprintf(out, "You can successfully interact with Local API (LAPI)\n") + + return nil +} + +func (cli *cliLapi) register(ctx context.Context, apiURL string, outputFile string, machine string, token string) error { + var err error + + lapiUser := machine + cfg := cli.cfg() + + if lapiUser == "" { + lapiUser, err = idgen.GenerateMachineID("") + if err != nil { + return fmt.Errorf("unable to generate machine id: %w", err) + } + } + + password := strfmt.Password(idgen.GeneratePassword(idgen.PasswordLength)) + + apiurl, err := prepareAPIURL(cfg.API.Client, apiURL) + if err != nil { + return fmt.Errorf("parsing api url: %w", err) + } + + _, err = apiclient.RegisterClient(ctx, &apiclient.Config{ + MachineID: lapiUser, + Password: password, + RegistrationToken: token, + URL: apiurl, + VersionPrefix: LAPIURLPrefix, + }, nil) + if err != nil { + return fmt.Errorf("api client register: %w", err) + } + + log.Printf("Successfully registered to Local API (LAPI)") + + var dumpFile string + + if outputFile != "" { + dumpFile = outputFile + } else if cfg.API.Client.CredentialsFilePath != "" { + dumpFile = cfg.API.Client.CredentialsFilePath + } else { + dumpFile = "" + } + + apiCfg := cfg.API.Client.Credentials + apiCfg.Login = lapiUser + apiCfg.Password = password.String() + + if apiURL != "" { + apiCfg.URL = apiURL + } + + apiConfigDump, err := yaml.Marshal(apiCfg) + if err != nil { + return fmt.Errorf("unable to serialize api credentials: %w", err) + } + + if dumpFile != "" { + err = os.WriteFile(dumpFile, apiConfigDump, 0o600) + if err != nil { + return fmt.Errorf("write api credentials to '%s' failed: %w", dumpFile, err) + } + + log.Printf("Local API credentials written to '%s'", dumpFile) + } else { + fmt.Printf("%s\n", string(apiConfigDump)) + } + + log.Warning(reload.Message) + + return nil +} + +// prepareAPIURL checks/fixes a LAPI connection url (http, https or socket) and returns an URL struct +func prepareAPIURL(clientCfg *csconfig.LocalApiClientCfg, apiURL string) (*url.URL, error) { + if apiURL == "" { + if clientCfg == nil || clientCfg.Credentials == nil || clientCfg.Credentials.URL == "" { + return nil, 
errors.New("no Local API URL. Please provide it in your configuration or with the -u parameter") + } + + apiURL = clientCfg.Credentials.URL + } + + // URL needs to end with /, but user doesn't care + if !strings.HasSuffix(apiURL, "/") { + apiURL += "/" + } + + // URL needs to start with http://, but user doesn't care + if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") && !strings.HasPrefix(apiURL, "/") { + apiURL = "http://" + apiURL + } + + return url.Parse(apiURL) +} + +func (cli *cliLapi) newStatusCmd() *cobra.Command { + cmdLapiStatus := &cobra.Command{ + Use: "status", + Short: "Check authentication to Local API (LAPI)", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) + if err != nil { + return err + } + + return cli.Status(cmd.Context(), color.Output, hub) + }, + } + + return cmdLapiStatus +} + +func (cli *cliLapi) newRegisterCmd() *cobra.Command { + var ( + apiURL string + outputFile string + machine string + token string + ) + + cmd := &cobra.Command{ + Use: "register", + Short: "Register a machine to Local API (LAPI)", + Long: `Register your machine to the Local API (LAPI). +Keep in mind the machine needs to be validated by an administrator on LAPI side to be effective.`, + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.register(cmd.Context(), apiURL, outputFile, machine, token) + }, + } + + flags := cmd.Flags() + flags.StringVarP(&apiURL, "url", "u", "", "URL of the API (ie. http://127.0.0.1)") + flags.StringVarP(&outputFile, "file", "f", "", "output file destination") + flags.StringVar(&machine, "machine", "", "Name of the machine to register with") + flags.StringVar(&token, "token", "", "Auto registration token to use") + + return cmd +} + +func (cli *cliLapi) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "lapi [action]", + Short: "Manage interaction with Local API (LAPI)", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + if err := cli.cfg().LoadAPIClient(); err != nil { + return fmt.Errorf("loading api client: %w", err) + } + return nil + }, + } + + cmd.AddCommand(cli.newRegisterCmd()) + cmd.AddCommand(cli.newStatusCmd()) + cmd.AddCommand(cli.newContextCmd()) + + return cmd +} + +func (cli *cliLapi) addContext(key string, values []string) error { + cfg := cli.cfg() + + if err := alertcontext.ValidateContextExpr(key, values); err != nil { + return fmt.Errorf("invalid context configuration: %w", err) + } + + if _, ok := cfg.Crowdsec.ContextToSend[key]; !ok { + cfg.Crowdsec.ContextToSend[key] = make([]string, 0) + + log.Infof("key '%s' added", key) + } + + data := cfg.Crowdsec.ContextToSend[key] + + for _, val := range values { + if !slices.Contains(data, val) { + log.Infof("value '%s' added to key '%s'", val, key) + data = append(data, val) + } + + cfg.Crowdsec.ContextToSend[key] = data + } + + return cfg.Crowdsec.DumpContextConfigFile() +} + +func (cli *cliLapi) newContextAddCmd() *cobra.Command { + var ( + keyToAdd string + valuesToAdd []string + ) + + cmd := &cobra.Command{ + Use: "add", + Short: "Add context to send with alerts. 
You must specify the output key with the expr value you want", + Example: `cscli lapi context add --key source_ip --value evt.Meta.source_ip +cscli lapi context add --key file_source --value evt.Line.Src +cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user + `, + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) + if err != nil { + return err + } + + if err = alertcontext.LoadConsoleContext(cli.cfg(), hub); err != nil { + return fmt.Errorf("while loading context: %w", err) + } + + if keyToAdd != "" { + return cli.addContext(keyToAdd, valuesToAdd) + } + + for _, v := range valuesToAdd { + keySlice := strings.Split(v, ".") + key := keySlice[len(keySlice)-1] + value := []string{v} + if err := cli.addContext(key, value); err != nil { + return err + } + } + + return nil + }, + } + + flags := cmd.Flags() + flags.StringVarP(&keyToAdd, "key", "k", "", "The key of the different values to send") + flags.StringSliceVar(&valuesToAdd, "value", []string{}, "The expr fields to associate with the key") + + _ = cmd.MarkFlagRequired("value") + + return cmd +} + +func (cli *cliLapi) newContextStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "List context to send with alerts", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + hub, err := require.Hub(cfg, nil, nil) + if err != nil { + return err + } + + if err = alertcontext.LoadConsoleContext(cfg, hub); err != nil { + return fmt.Errorf("while loading context: %w", err) + } + + if len(cfg.Crowdsec.ContextToSend) == 0 { + fmt.Println("No context found on this agent. You can use 'cscli lapi context add' to add context to your alerts.") + return nil + } + + dump, err := yaml.Marshal(cfg.Crowdsec.ContextToSend) + if err != nil { + return fmt.Errorf("unable to show context status: %w", err) + } + + fmt.Print(string(dump)) + + return nil + }, + } + + return cmd +} + +func (cli *cliLapi) newContextDetectCmd() *cobra.Command { + var detectAll bool + + cmd := &cobra.Command{ + Use: "detect", + Short: "Detect available fields from the installed parsers", + Example: `cscli lapi context detect --all +cscli lapi context detect crowdsecurity/sshd-logs + `, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + cfg := cli.cfg() + if !detectAll && len(args) == 0 { + _ = cmd.Help() + return errors.New("please provide parsers to detect or --all flag") + } + + // to avoid all the log.Info from the loaders functions + log.SetLevel(log.WarnLevel) + + if err := exprhelpers.Init(nil); err != nil { + return fmt.Errorf("failed to init expr helpers: %w", err) + } + + hub, err := require.Hub(cfg, nil, nil) + if err != nil { + return err + } + + csParsers := parser.NewParsers(hub) + if csParsers, err = parser.LoadParsers(cfg, csParsers); err != nil { + return fmt.Errorf("unable to load parsers: %w", err) + } + + fieldByParsers := make(map[string][]string) + for _, node := range csParsers.Nodes { + if !detectAll && !slices.Contains(args, node.Name) { + continue + } + if !detectAll { + args = removeFromSlice(node.Name, args) + } + fieldByParsers[node.Name] = make([]string, 0) + fieldByParsers[node.Name] = detectNode(node, *csParsers.Ctx) + + subNodeFields := detectSubNode(node, *csParsers.Ctx) + for _, field := range subNodeFields { + if !slices.Contains(fieldByParsers[node.Name], field) { + fieldByParsers[node.Name] = append(fieldByParsers[node.Name], field) + } + } + } + + 
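+
+			// these three fields come from the Line struct populated at acquisition
+			// time, so they are available to every parser regardless of grok patterns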
fmt.Printf("Acquisition :\n\n") + fmt.Printf(" - evt.Line.Module\n") + fmt.Printf(" - evt.Line.Raw\n") + fmt.Printf(" - evt.Line.Src\n") + fmt.Println() + + parsersKey := make([]string, 0) + for k := range fieldByParsers { + parsersKey = append(parsersKey, k) + } + sort.Strings(parsersKey) + + for _, k := range parsersKey { + if len(fieldByParsers[k]) == 0 { + continue + } + fmt.Printf("%s :\n\n", k) + values := fieldByParsers[k] + sort.Strings(values) + for _, value := range values { + fmt.Printf(" - %s\n", value) + } + fmt.Println() + } + + if len(args) > 0 { + for _, parserNotFound := range args { + log.Errorf("parser '%s' not found, can't detect fields", parserNotFound) + } + } + + return nil + }, + } + cmd.Flags().BoolVarP(&detectAll, "all", "a", false, "Detect evt field for all installed parser") + + return cmd +} + +func (cli *cliLapi) newContextDeleteCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "delete", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + filePath := cli.cfg().Crowdsec.ConsoleContextPath + if filePath == "" { + filePath = "the context file" + } + + return fmt.Errorf("command 'delete' has been removed, please manually edit %s", filePath) + }, + } + + return cmd +} + +func (cli *cliLapi) newContextCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "context [command]", + Short: "Manage context to send with alerts", + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := cfg.LoadCrowdsec(); err != nil { + fileNotFoundMessage := fmt.Sprintf("failed to open context file: open %s: no such file or directory", cfg.Crowdsec.ConsoleContextPath) + if err.Error() != fileNotFoundMessage { + return fmt.Errorf("unable to load CrowdSec agent configuration: %w", err) + } + } + if cfg.DisableAgent { + return errors.New("agent is disabled and lapi context can only be used on the agent") + } + + return nil + }, + } + + cmd.AddCommand(cli.newContextAddCmd()) + cmd.AddCommand(cli.newContextStatusCmd()) + cmd.AddCommand(cli.newContextDetectCmd()) + cmd.AddCommand(cli.newContextDeleteCmd()) + + return cmd +} + +func detectStaticField(grokStatics []parser.ExtraField) []string { + ret := make([]string, 0) + + for _, static := range grokStatics { + if static.Parsed != "" { + fieldName := "evt.Parsed." + static.Parsed + if !slices.Contains(ret, fieldName) { + ret = append(ret, fieldName) + } + } + + if static.Meta != "" { + fieldName := "evt.Meta." + static.Meta + if !slices.Contains(ret, fieldName) { + ret = append(ret, fieldName) + } + } + + if static.TargetByName != "" { + fieldName := static.TargetByName + if !strings.HasPrefix(fieldName, "evt.") { + fieldName = "evt." + fieldName + } + + if !slices.Contains(ret, fieldName) { + ret = append(ret, fieldName) + } + } + } + + return ret +} + +func detectNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { + ret := make([]string, 0) + + if node.Grok.RunTimeRegexp != nil { + for _, capturedField := range node.Grok.RunTimeRegexp.Names() { + fieldName := "evt.Parsed." + capturedField + if !slices.Contains(ret, fieldName) { + ret = append(ret, fieldName) + } + } + } + + if node.Grok.RegexpName != "" { + grokCompiled, err := parserCTX.Grok.Get(node.Grok.RegexpName) + // ignore error (parser does not exist?) + if err == nil { + for _, capturedField := range grokCompiled.Names() { + fieldName := "evt.Parsed." 
+ capturedField + if !slices.Contains(ret, fieldName) { + ret = append(ret, fieldName) + } + } + } + } + + if len(node.Grok.Statics) > 0 { + staticsField := detectStaticField(node.Grok.Statics) + for _, staticField := range staticsField { + if !slices.Contains(ret, staticField) { + ret = append(ret, staticField) + } + } + } + + if len(node.Statics) > 0 { + staticsField := detectStaticField(node.Statics) + for _, staticField := range staticsField { + if !slices.Contains(ret, staticField) { + ret = append(ret, staticField) + } + } + } + + return ret +} + +func detectSubNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { + ret := make([]string, 0) + + for _, subnode := range node.LeavesNodes { + if subnode.Grok.RunTimeRegexp != nil { + for _, capturedField := range subnode.Grok.RunTimeRegexp.Names() { + fieldName := "evt.Parsed." + capturedField + if !slices.Contains(ret, fieldName) { + ret = append(ret, fieldName) + } + } + } + + if subnode.Grok.RegexpName != "" { + grokCompiled, err := parserCTX.Grok.Get(subnode.Grok.RegexpName) + if err == nil { + // ignore error (parser does not exist?) + for _, capturedField := range grokCompiled.Names() { + fieldName := "evt.Parsed." + capturedField + if !slices.Contains(ret, fieldName) { + ret = append(ret, fieldName) + } + } + } + } + + if len(subnode.Grok.Statics) > 0 { + staticsField := detectStaticField(subnode.Grok.Statics) + for _, staticField := range staticsField { + if !slices.Contains(ret, staticField) { + ret = append(ret, staticField) + } + } + } + + if len(subnode.Statics) > 0 { + staticsField := detectStaticField(subnode.Statics) + for _, staticField := range staticsField { + if !slices.Contains(ret, staticField) { + ret = append(ret, staticField) + } + } + } + } + + return ret +} diff --git a/cmd/crowdsec-cli/clilapi/lapi_test.go b/cmd/crowdsec-cli/clilapi/lapi_test.go new file mode 100644 index 00000000000..caf986d847a --- /dev/null +++ b/cmd/crowdsec-cli/clilapi/lapi_test.go @@ -0,0 +1,49 @@ +package clilapi + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" +) + +func TestPrepareAPIURL_NoProtocol(t *testing.T) { + url, err := prepareAPIURL(nil, "localhost:81") + require.NoError(t, err) + assert.Equal(t, "http://localhost:81/", url.String()) +} + +func TestPrepareAPIURL_Http(t *testing.T) { + url, err := prepareAPIURL(nil, "http://localhost:81") + require.NoError(t, err) + assert.Equal(t, "http://localhost:81/", url.String()) +} + +func TestPrepareAPIURL_Https(t *testing.T) { + url, err := prepareAPIURL(nil, "https://localhost:81") + require.NoError(t, err) + assert.Equal(t, "https://localhost:81/", url.String()) +} + +func TestPrepareAPIURL_UnixSocket(t *testing.T) { + url, err := prepareAPIURL(nil, "/path/socket") + require.NoError(t, err) + assert.Equal(t, "/path/socket/", url.String()) +} + +func TestPrepareAPIURL_Empty(t *testing.T) { + _, err := prepareAPIURL(nil, "") + require.Error(t, err) +} + +func TestPrepareAPIURL_Empty_ConfigOverride(t *testing.T) { + url, err := prepareAPIURL(&csconfig.LocalApiClientCfg{ + Credentials: &csconfig.ApiCredentialsCfg{ + URL: "localhost:80", + }, + }, "") + require.NoError(t, err) + assert.Equal(t, "http://localhost:80/", url.String()) +} diff --git a/cmd/crowdsec-cli/clilapi/utils.go b/cmd/crowdsec-cli/clilapi/utils.go new file mode 100644 index 00000000000..e3ec65f2145 --- /dev/null +++ b/cmd/crowdsec-cli/clilapi/utils.go @@ -0,0 +1,24 @@ +package clilapi + +func 
removeFromSlice(val string, slice []string) []string { + var i int + var value string + + valueFound := false + + // get the index + for i, value = range slice { + if value == val { + valueFound = true + break + } + } + + if valueFound { + slice[i] = slice[len(slice)-1] + slice[len(slice)-1] = "" + slice = slice[:len(slice)-1] + } + + return slice +} diff --git a/cmd/crowdsec-cli/climachine/flag.go b/cmd/crowdsec-cli/climachine/flag.go new file mode 100644 index 00000000000..c3fefd896e1 --- /dev/null +++ b/cmd/crowdsec-cli/climachine/flag.go @@ -0,0 +1,29 @@ +package climachine + +// Custom types for flag validation and conversion. + +import ( + "errors" +) + +type MachinePassword string + +func (p *MachinePassword) String() string { + return string(*p) +} + +func (p *MachinePassword) Set(v string) error { + // a password can't be more than 72 characters + // due to bcrypt limitations + if len(v) > 72 { + return errors.New("password too long (max 72 characters)") + } + + *p = MachinePassword(v) + + return nil +} + +func (p *MachinePassword) Type() string { + return "string" +} diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go new file mode 100644 index 00000000000..1fbedcf57fd --- /dev/null +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -0,0 +1,717 @@ +package climachine + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "slices" + "strings" + "time" + + "github.com/AlecAivazis/survey/v2" + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/ask" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clientinfo" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/emoji" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +// getLastHeartbeat returns the last heartbeat timestamp of a machine +// and a boolean indicating if the machine is considered active or not. +func getLastHeartbeat(m *ent.Machine) (string, bool) { + if m.LastHeartbeat == nil { + return "-", false + } + + elapsed := time.Now().UTC().Sub(*m.LastHeartbeat) + + hb := elapsed.Truncate(time.Second).String() + if elapsed > 2*time.Minute { + return hb, false + } + + return hb, true +} + +type configGetter = func() *csconfig.Config + +type cliMachines struct { + db *database.Client + cfg configGetter +} + +func New(cfg configGetter) *cliMachines { + return &cliMachines{ + cfg: cfg, + } +} + +func (cli *cliMachines) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "machines [action]", + Short: "Manage local API machines [requires local API]", + Long: `To list/add/delete/validate/prune machines. +Note: This command requires database direct access, so is intended to be run on the local API machine. 
+`, + Example: `cscli machines [action]`, + DisableAutoGenTag: true, + Aliases: []string{"machine"}, + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + var err error + if err = require.LAPI(cli.cfg()); err != nil { + return err + } + cli.db, err = require.DBClient(cmd.Context(), cli.cfg().DbConfig) + if err != nil { + return err + } + + return nil + }, + } + + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newAddCmd()) + cmd.AddCommand(cli.newDeleteCmd()) + cmd.AddCommand(cli.newValidateCmd()) + cmd.AddCommand(cli.newPruneCmd()) + cmd.AddCommand(cli.newInspectCmd()) + + return cmd +} + +func (cli *cliMachines) inspectHubHuman(out io.Writer, machine *ent.Machine) { + state := machine.Hubstate + + if len(state) == 0 { + fmt.Println("No hub items found for this machine") + return + } + + // group state rows by type for multiple tables + rowsByType := make(map[string][]table.Row) + + for itemType, items := range state { + for _, item := range items { + if _, ok := rowsByType[itemType]; !ok { + rowsByType[itemType] = make([]table.Row, 0) + } + + row := table.Row{item.Name, item.Status, item.Version} + rowsByType[itemType] = append(rowsByType[itemType], row) + } + } + + for itemType, rows := range rowsByType { + t := cstable.New(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Name", "Status", "Version"}) + t.SetTitle(itemType) + t.AppendRows(rows) + io.WriteString(out, t.Render()+"\n") + } +} + +func (cli *cliMachines) listHuman(out io.Writer, machines ent.Machines) { + t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Name", "IP Address", "Last Update", "Status", "Version", "OS", "Auth Type", "Last Heartbeat"}) + + for _, m := range machines { + validated := emoji.Prohibited + if m.IsValidated { + validated = emoji.CheckMark + } + + hb, active := getLastHeartbeat(m) + if !active { + hb = emoji.Warning + " " + hb + } + + t.AppendRow(table.Row{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, clientinfo.GetOSNameAndVersion(m), m.AuthType, hb}) + } + + io.WriteString(out, t.Render()+"\n") +} + +// machineInfo contains only the data we want for inspect/list: no hub status, scenarios, edges, etc. 
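+// For instance, "cscli machines list -o json" serializes each entry roughly
+// as below (values are illustrative, keys come from the struct tags):
+//
+//	{
+//	  "machineId": "mymachine",
+//	  "ipAddress": "127.0.0.1",
+//	  "isValidated": true,
+//	  "auth_type": "password",
+//	  "os": "Ubuntu/22.04"
+//	}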
+type machineInfo struct { + CreatedAt time.Time `json:"created_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + LastPush *time.Time `json:"last_push,omitempty"` + LastHeartbeat *time.Time `json:"last_heartbeat,omitempty"` + MachineId string `json:"machineId,omitempty"` + IpAddress string `json:"ipAddress,omitempty"` + Version string `json:"version,omitempty"` + IsValidated bool `json:"isValidated,omitempty"` + AuthType string `json:"auth_type"` + OS string `json:"os,omitempty"` + Featureflags []string `json:"featureflags,omitempty"` + Datasources map[string]int64 `json:"datasources,omitempty"` +} + +func newMachineInfo(m *ent.Machine) machineInfo { + return machineInfo{ + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + LastPush: m.LastPush, + LastHeartbeat: m.LastHeartbeat, + MachineId: m.MachineId, + IpAddress: m.IpAddress, + Version: m.Version, + IsValidated: m.IsValidated, + AuthType: m.AuthType, + OS: clientinfo.GetOSNameAndVersion(m), + Featureflags: clientinfo.GetFeatureFlagList(m), + Datasources: m.Datasources, + } +} + +func (cli *cliMachines) listCSV(out io.Writer, machines ent.Machines) error { + csvwriter := csv.NewWriter(out) + + err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "auth_type", "last_heartbeat", "os"}) + if err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + + for _, m := range machines { + validated := "false" + if m.IsValidated { + validated = "true" + } + + hb := "-" + if m.LastHeartbeat != nil { + hb = m.LastHeartbeat.Format(time.RFC3339) + } + + if err := csvwriter.Write([]string{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, m.AuthType, hb, fmt.Sprintf("%s/%s", m.Osname, m.Osversion)}); err != nil { + return fmt.Errorf("failed to write raw output: %w", err) + } + } + + csvwriter.Flush() + + return nil +} + +func (cli *cliMachines) List(ctx context.Context, out io.Writer, db *database.Client) error { + // XXX: must use the provided db object, the one in the struct might be nil + // (calling List directly skips the PersistentPreRunE) + + machines, err := db.ListMachines(ctx) + if err != nil { + return fmt.Errorf("unable to list machines: %w", err) + } + + switch cli.cfg().Cscli.Output { + case "human": + cli.listHuman(out, machines) + case "json": + info := make([]machineInfo, 0, len(machines)) + for _, m := range machines { + info = append(info, newMachineInfo(m)) + } + + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(info); err != nil { + return errors.New("failed to serialize") + } + + return nil + case "raw": + return cli.listCSV(out, machines) + } + + return nil +} + +func (cli *cliMachines) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "list all machines in the database", + Long: `list all machines in the database with their status and last heartbeat`, + Example: `cscli machines list`, + Args: cobra.NoArgs, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.List(cmd.Context(), color.Output, cli.db) + }, + } + + return cmd +} + +func (cli *cliMachines) newAddCmd() *cobra.Command { + var ( + password MachinePassword + dumpFile string + apiURL string + interactive bool + autoAdd bool + force bool + ) + + cmd := &cobra.Command{ + Use: "add", + Short: "add a single machine to the database", + DisableAutoGenTag: true, + Long: `Register a new machine in the database. 
cscli should be on the same machine as LAPI.`, + Example: `cscli machines add --auto +cscli machines add MyTestMachine --auto +cscli machines add MyTestMachine --password MyPassword +cscli machines add -f- --auto > /tmp/mycreds.yaml`, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args, string(password), dumpFile, apiURL, interactive, autoAdd, force) + }, + } + + flags := cmd.Flags() + flags.VarP(&password, "password", "p", "machine password to login to the API") + flags.StringVarP(&dumpFile, "file", "f", "", "output file destination (defaults to "+csconfig.DefaultConfigPath("local_api_credentials.yaml")+")") + flags.StringVarP(&apiURL, "url", "u", "", "URL of the local API") + flags.BoolVarP(&interactive, "interactive", "i", false, "interactive mode to enter the password") + flags.BoolVarP(&autoAdd, "auto", "a", false, "automatically generate password (and username if not provided)") + flags.BoolVar(&force, "force", false, "will force add the machine if it already exists") + + return cmd +} + +func (cli *cliMachines) add(ctx context.Context, args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error { + var ( + err error + machineID string + ) + + // create machineID if not specified by user + if len(args) == 0 { + if !autoAdd { + return errors.New("please specify a machine name to add, or use --auto") + } + + machineID, err = idgen.GenerateMachineID("") + if err != nil { + return fmt.Errorf("unable to generate machine id: %w", err) + } + } else { + machineID = args[0] + } + + clientCfg := cli.cfg().API.Client + serverCfg := cli.cfg().API.Server + + /*check if file already exists*/ + if dumpFile == "" && clientCfg != nil && clientCfg.CredentialsFilePath != "" { + credFile := clientCfg.CredentialsFilePath + // use the default only if the file does not exist + _, err = os.Stat(credFile) + + switch { + case os.IsNotExist(err) || force: + dumpFile = credFile + case err != nil: + return fmt.Errorf("unable to stat '%s': %w", credFile, err) + default: + return fmt.Errorf(`credentials file '%s' already exists: please remove it, use "--force" or specify a different file with "-f" ("-f -" for standard output)`, credFile) + } + } + + if dumpFile == "" { + return errors.New(`please specify a file to dump credentials to, with -f ("-f -" for standard output)`) + } + + // create a password if it's not specified by user + if machinePassword == "" && !interactive { + if !autoAdd { + return errors.New("please specify a password with --password or use --auto") + } + + machinePassword = idgen.GeneratePassword(idgen.PasswordLength) + } else if machinePassword == "" && interactive { + qs := &survey.Password{ + Message: "Please provide a password for the machine:", + } + survey.AskOne(qs, &machinePassword) + } + + password := strfmt.Password(machinePassword) + + _, err = cli.db.CreateMachine(ctx, &machineID, &password, "", true, force, types.PasswordAuthType) + if err != nil { + return fmt.Errorf("unable to create machine: %w", err) + } + + fmt.Fprintf(os.Stderr, "Machine '%s' successfully added to the local API.\n", machineID) + + if apiURL == "" { + if clientCfg != nil && clientCfg.Credentials != nil && clientCfg.Credentials.URL != "" { + apiURL = clientCfg.Credentials.URL + } else if serverCfg.ClientURL() != "" { + apiURL = serverCfg.ClientURL() + } else { + return errors.New("unable to dump an api URL. 
Please provide it in your configuration or with the -u parameter") + } + } + + apiCfg := csconfig.ApiCredentialsCfg{ + Login: machineID, + Password: password.String(), + URL: apiURL, + } + + apiConfigDump, err := yaml.Marshal(apiCfg) + if err != nil { + return fmt.Errorf("unable to serialize api credentials: %w", err) + } + + if dumpFile != "" && dumpFile != "-" { + if err = os.WriteFile(dumpFile, apiConfigDump, 0o600); err != nil { + return fmt.Errorf("write api credentials in '%s' failed: %w", dumpFile, err) + } + + fmt.Fprintf(os.Stderr, "API credentials written to '%s'.\n", dumpFile) + } else { + fmt.Print(string(apiConfigDump)) + } + + return nil +} + +// validMachineID returns a list of machine IDs for command completion +func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + var err error + + cfg := cli.cfg() + ctx := cmd.Context() + + // need to load config and db because PersistentPreRunE is not called for completions + + if err = require.LAPI(cfg); err != nil { + cobra.CompError("unable to list machines " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + cli.db, err = require.DBClient(ctx, cfg.DbConfig) + if err != nil { + cobra.CompError("unable to list machines " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + machines, err := cli.db.ListMachines(ctx) + if err != nil { + cobra.CompError("unable to list machines " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + ret := []string{} + + for _, machine := range machines { + if strings.Contains(machine.MachineId, toComplete) && !slices.Contains(args, machine.MachineId) { + ret = append(ret, machine.MachineId) + } + } + + return ret, cobra.ShellCompDirectiveNoFileComp +} + +func (cli *cliMachines) delete(ctx context.Context, machines []string, ignoreMissing bool) error { + for _, machineID := range machines { + if err := cli.db.DeleteWatcher(ctx, machineID); err != nil { + var notFoundErr *database.MachineNotFoundError + if ignoreMissing && errors.As(err, ¬FoundErr) { + return nil + } + + log.Errorf("unable to delete machine: %s", err) + + return nil + } + + log.Infof("machine '%s' deleted successfully", machineID) + } + + return nil +} + +func (cli *cliMachines) newDeleteCmd() *cobra.Command { + var ignoreMissing bool + + cmd := &cobra.Command{ + Use: "delete [machine_name]...", + Short: "delete machine(s) by name", + Example: `cscli machines delete "machine1" "machine2"`, + Args: cobra.MinimumNArgs(1), + Aliases: []string{"remove"}, + DisableAutoGenTag: true, + ValidArgsFunction: cli.validMachineID, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args, ignoreMissing) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&ignoreMissing, "ignore-missing", false, "don't print errors if one or more machines don't exist") + + return cmd +} + +func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notValidOnly bool, force bool) error { + if duration < 2*time.Minute && !notValidOnly { + if yes, err := ask.YesNo( + "The duration you provided is less than 2 minutes. "+ + "This can break installations if the machines are only temporarily disconnected. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + machines := []*ent.Machine{} + if pending, err := cli.db.QueryPendingMachine(ctx); err == nil { + machines = append(machines, pending...) 
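+ // Sketch of the candidate selection (flags are defined in newPruneCmd
+ // below): never-validated machines are always candidates, e.g.
+ //
+ //   cscli machines prune --not-validated-only
+ //
+ // and unless --not-validated-only is set, validated machines whose last
+ // heartbeat is older than --duration are appended as well.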
+ } + + if !notValidOnly { + if pending, err := cli.db.QueryMachinesInactiveSince(ctx, time.Now().UTC().Add(-duration)); err == nil { + machines = append(machines, pending...) + } + } + + if len(machines) == 0 { + fmt.Println("No machines to prune.") + return nil + } + + cli.listHuman(color.Output, machines) + + if !force { + if yes, err := ask.YesNo( + "You are about to PERMANENTLY remove the above machines from the database. "+ + "These will NOT be recoverable. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + deleted, err := cli.db.BulkDeleteWatchers(ctx, machines) + if err != nil { + return fmt.Errorf("unable to prune machines: %w", err) + } + + fmt.Fprintf(os.Stderr, "successfully deleted %d machines\n", deleted) + + return nil +} + +func (cli *cliMachines) newPruneCmd() *cobra.Command { + var ( + duration time.Duration + notValidOnly bool + force bool + ) + + const defaultDuration = 10 * time.Minute + + cmd := &cobra.Command{ + Use: "prune", + Short: "prune multiple machines from the database", + Long: `prune multiple machines that are not validated or have not connected to the local API in a given duration.`, + Example: `cscli machines prune +cscli machines prune --duration 1h +cscli machines prune --not-validated-only --force`, + Args: cobra.NoArgs, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, notValidOnly, force) + }, + } + + flags := cmd.Flags() + flags.DurationVarP(&duration, "duration", "d", defaultDuration, "duration of time since validated machine last heartbeat") + flags.BoolVar(¬ValidOnly, "not-validated-only", false, "only prune machines that are not validated") + flags.BoolVar(&force, "force", false, "force prune without asking for confirmation") + + return cmd +} + +func (cli *cliMachines) validate(ctx context.Context, machineID string) error { + if err := cli.db.ValidateMachine(ctx, machineID); err != nil { + return fmt.Errorf("unable to validate machine '%s': %w", machineID, err) + } + + log.Infof("machine '%s' validated successfully", machineID) + + return nil +} + +func (cli *cliMachines) newValidateCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "validate", + Short: "validate a machine to access the local API", + Long: `validate a machine to access the local API.`, + Example: `cscli machines validate "machine_name"`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.validate(cmd.Context(), args[0]) + }, + } + + return cmd +} + +func (cli *cliMachines) inspectHuman(out io.Writer, machine *ent.Machine) { + t := cstable.New(out, cli.cfg().Cscli.Color).Writer + + t.SetTitle("Machine: " + machine.MachineId) + + t.SetColumnConfigs([]table.ColumnConfig{ + {Number: 1, AutoMerge: true}, + }) + + t.AppendRows([]table.Row{ + {"IP Address", machine.IpAddress}, + {"Created At", machine.CreatedAt}, + {"Last Update", machine.UpdatedAt}, + {"Last Heartbeat", machine.LastHeartbeat}, + {"Validated?", machine.IsValidated}, + {"CrowdSec version", machine.Version}, + {"OS", clientinfo.GetOSNameAndVersion(machine)}, + {"Auth type", machine.AuthType}, + }) + + for dsName, dsCount := range machine.Datasources { + t.AppendRow(table.Row{"Datasources", fmt.Sprintf("%s: %d", dsName, dsCount)}) + } + + for _, ff := range clientinfo.GetFeatureFlagList(machine) { + t.AppendRow(table.Row{"Feature Flags", ff}) + } + + for _, coll := range 
machine.Hubstate[cwhub.COLLECTIONS] { + t.AppendRow(table.Row{"Collections", coll.Name}) + } + + io.WriteString(out, t.Render()+"\n") +} + +func (cli *cliMachines) inspect(machine *ent.Machine) error { + out := color.Output + outputFormat := cli.cfg().Cscli.Output + + switch outputFormat { + case "human": + cli.inspectHuman(out, machine) + case "json": + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(newMachineInfo(machine)); err != nil { + return errors.New("failed to serialize") + } + + return nil + default: + return fmt.Errorf("output format '%s' not supported for this command", outputFormat) + } + + return nil +} + +func (cli *cliMachines) inspectHub(machine *ent.Machine) error { + out := color.Output + + switch cli.cfg().Cscli.Output { + case "human": + cli.inspectHubHuman(out, machine) + case "json": + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(machine.Hubstate); err != nil { + return errors.New("failed to serialize") + } + + return nil + case "raw": + csvwriter := csv.NewWriter(out) + + err := csvwriter.Write([]string{"type", "name", "status", "version"}) + if err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + + rows := make([][]string, 0) + + for itemType, items := range machine.Hubstate { + for _, item := range items { + rows = append(rows, []string{itemType, item.Name, item.Status, item.Version}) + } + } + + for _, row := range rows { + if err := csvwriter.Write(row); err != nil { + return fmt.Errorf("failed to write raw output: %w", err) + } + } + + csvwriter.Flush() + } + + return nil +} + +func (cli *cliMachines) newInspectCmd() *cobra.Command { + var showHub bool + + cmd := &cobra.Command{ + Use: "inspect [machine_name]", + Short: "inspect a machine by name", + Example: `cscli machines inspect "machine1"`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + ValidArgsFunction: cli.validMachineID, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + machineID := args[0] + + machine, err := cli.db.QueryMachineByID(ctx, machineID) + if err != nil { + return fmt.Errorf("unable to read machine data '%s': %w", machineID, err) + } + + if showHub { + return cli.inspectHub(machine) + } + + return cli.inspect(machine) + }, + } + + flags := cmd.Flags() + + flags.BoolVarP(&showHub, "hub", "H", false, "show hub state") + + return cmd +} diff --git a/cmd/crowdsec-cli/climetrics/list.go b/cmd/crowdsec-cli/climetrics/list.go new file mode 100644 index 00000000000..ddb2baac14d --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/list.go @@ -0,0 +1,95 @@ +package climetrics + +import ( + "encoding/json" + "fmt" + "io" + + "github.com/fatih/color" + "github.com/jedib0t/go-pretty/v6/table" + "github.com/jedib0t/go-pretty/v6/text" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +func (cli *cliMetrics) list() error { + type metricType struct { + Type string `json:"type" yaml:"type"` + Title string `json:"title" yaml:"title"` + Description string `json:"description" yaml:"description"` + } + + var allMetrics []metricType + + ms := NewMetricStore() + for _, section := range maptools.SortedKeys(ms) { + title, description := ms[section].Description() + allMetrics = append(allMetrics, metricType{ + Type: section, + Title: title, + Description: description, + }) + } + + outputFormat := cli.cfg().Cscli.Output + + switch outputFormat { + case "human": + out := color.Output + t := 
cstable.New(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Type", "Title", "Description"}) + t.SetColumnConfigs([]table.ColumnConfig{ + { + Name: "Type", + AlignHeader: text.AlignCenter, + }, + { + Name: "Title", + AlignHeader: text.AlignCenter, + }, + { + Name: "Description", + AlignHeader: text.AlignCenter, + WidthMax: 60, + WidthMaxEnforcer: text.WrapSoft, + }, + }) + + t.Style().Options.SeparateRows = true + + for _, metric := range allMetrics { + t.AppendRow(table.Row{metric.Type, metric.Title, metric.Description}) + } + + io.WriteString(out, t.Render()+"\n") + case "json": + x, err := json.MarshalIndent(allMetrics, "", " ") + if err != nil { + return fmt.Errorf("failed to serialize metric types: %w", err) + } + + fmt.Println(string(x)) + default: + return fmt.Errorf("output format '%s' not supported for this command", outputFormat) + } + + return nil +} + +func (cli *cliMetrics) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List available types of metrics.", + Long: `List available types of metrics.`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.list() + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/climetrics/metrics.go b/cmd/crowdsec-cli/climetrics/metrics.go new file mode 100644 index 00000000000..f3bc4874460 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/metrics.go @@ -0,0 +1,54 @@ +package climetrics + +import ( + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" +) + +type configGetter func() *csconfig.Config + +type cliMetrics struct { + cfg configGetter +} + +func New(cfg configGetter) *cliMetrics { + return &cliMetrics{ + cfg: cfg, + } +} + +func (cli *cliMetrics) NewCommand() *cobra.Command { + var ( + url string + noUnit bool + ) + + cmd := &cobra.Command{ + Use: "metrics", + Short: "Display crowdsec prometheus metrics.", + Long: `Fetch metrics from a Local API server and display them`, + Example: `# Show all Metrics, skip empty tables (same as "cscli metrics show") +cscli metrics + +# Show only some metrics, connect to a different url +cscli metrics --url http://lapi.local:6060/metrics show acquisition parsers + +# List available metric types +cscli metrics list`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.show(cmd.Context(), nil, url, noUnit) + }, + } + + flags := cmd.Flags() + flags.StringVarP(&url, "url", "u", "", "Prometheus url (http://<ip>:<port>/metrics)") + flags.BoolVar(&noUnit, "no-unit", false, "Show the real number instead of formatted with units") + + cmd.AddCommand(cli.newShowCmd()) + cmd.AddCommand(cli.newListCmd()) + + return cmd +} diff --git a/cmd/crowdsec-cli/climetrics/number.go b/cmd/crowdsec-cli/climetrics/number.go new file mode 100644 index 00000000000..709b7cf853a --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/number.go @@ -0,0 +1,45 @@ +package climetrics + +import ( + "fmt" + "math" + "strconv" +) + +type unit struct { + value int64 + symbol string +} + +var ranges = []unit{ + {value: 1e18, symbol: "E"}, + {value: 1e15, symbol: "P"}, + {value: 1e12, symbol: "T"}, + {value: 1e9, symbol: "G"}, + {value: 1e6, symbol: "M"}, + {value: 1e3, symbol: "k"}, + {value: 1, symbol: ""}, +} + +func formatNumber(num int64, withUnit bool) string { + if !withUnit { + return strconv.FormatInt(num, 10) + } + + goodUnit := ranges[len(ranges)-1] + + for _, u := range ranges { + if num >= u.value { + goodUnit = u + break + } + } + +
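+ // Hedged, hand-checked examples of the two branches below:
+ //
+ //   formatNumber(999, true)  -> "999"   (goodUnit.value == 1, no scaling)
+ //   formatNumber(1234, true) -> "1.23k" (scaled, rounded to two decimals)
+ //
+ if 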
goodUnit.value == 1 { + return fmt.Sprintf("%d%s", num, goodUnit.symbol) + } + + res := math.Round(float64(num)/float64(goodUnit.value)*100) / 100 + + return fmt.Sprintf("%.2f%s", res, goodUnit.symbol) +} diff --git a/cmd/crowdsec-cli/climetrics/show.go b/cmd/crowdsec-cli/climetrics/show.go new file mode 100644 index 00000000000..045959048f6 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/show.go @@ -0,0 +1,113 @@ +package climetrics + +import ( + "context" + "errors" + "fmt" + + "github.com/fatih/color" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" +) + +var ( + ErrMissingConfig = errors.New("prometheus section missing, can't show metrics") + ErrMetricsDisabled = errors.New("prometheus is not enabled, can't show metrics") +) + +func (cli *cliMetrics) show(ctx context.Context, sections []string, url string, noUnit bool) error { + cfg := cli.cfg() + + if url != "" { + cfg.Cscli.PrometheusUrl = url + } + + if cfg.Prometheus == nil { + return ErrMissingConfig + } + + if !cfg.Prometheus.Enabled { + return ErrMetricsDisabled + } + + ms := NewMetricStore() + + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + log.Warnf("unable to open database: %s", err) + } + + if err := ms.Fetch(ctx, cfg.Cscli.PrometheusUrl, db); err != nil { + log.Warn(err) + } + + // any section that we don't have in the store is an error + for _, section := range sections { + if _, ok := ms[section]; !ok { + return fmt.Errorf("unknown metrics type: %s", section) + } + } + + return ms.Format(color.Output, cfg.Cscli.Color, sections, cfg.Cscli.Output, noUnit) +} + +// expandAlias returns a list of sections. The input can be a list of sections or an alias. +func expandAlias(args []string) []string { + ret := []string{} + + for _, section := range args { + switch section { + case "engine": + ret = append(ret, "acquisition", "parsers", "scenarios", "stash", "whitelists") + case "lapi": + ret = append(ret, "alerts", "decisions", "lapi", "lapi-bouncer", "lapi-decisions", "lapi-machine") + case "appsec": + ret = append(ret, "appsec-engine", "appsec-rule") + default: + ret = append(ret, section) + } + } + + return ret +} + +func (cli *cliMetrics) newShowCmd() *cobra.Command { + var ( + url string + noUnit bool + ) + + cmd := &cobra.Command{ + Use: "show [type]...", + Short: "Display all or part of the available metrics.", + Long: `Fetch metrics from a Local API server and display them, optionally filtering on specific types.`, + Example: `# Show all Metrics, skip empty tables +cscli metrics show + +# Use an alias: "engine", "lapi" or "appsec" to show a group of metrics +cscli metrics show engine + +# Show some specific metrics, show empty tables, connect to a different url +cscli metrics show acquisition parsers scenarios stash --url http://lapi.local:6060/metrics + +# To list available metric types, use "cscli metrics list" +cscli metrics list; cscli metrics list -o json + +# Show metrics in json format +cscli metrics show acquisition parsers scenarios stash -o json`, + // Positional args are optional + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + args = expandAlias(args) + return cli.show(cmd.Context(), args, url, noUnit) + }, + } + + flags := cmd.Flags() + flags.StringVarP(&url, "url", "u", "", "Metrics url (http://<ip>:<port>/metrics)") + flags.BoolVar(&noUnit, "no-unit", false, "Show the real number instead of formatted with units") + + return cmd +} diff --git a/cmd/crowdsec-cli/climetrics/statacquis.go 
b/cmd/crowdsec-cli/climetrics/statacquis.go new file mode 100644 index 00000000000..0af2e796f40 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statacquis.go @@ -0,0 +1,44 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statAcquis map[string]map[string]int + +func (s statAcquis) Description() (string, string) { + return "Acquisition Metrics", + `Measures the lines read, parsed, and unparsed per datasource. ` + + `Zero read lines indicate a misconfigured or inactive datasource. ` + + `Zero parsed lines means the parser(s) failed. ` + + `Non-zero parsed lines are fine as crowdsec selects relevant lines.` +} + +func (s statAcquis) Process(source, metric string, val int) { + if _, ok := s[source]; !ok { + s[source] = make(map[string]int) + } + + s[source][metric] += val +} + +func (s statAcquis) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Source", "Lines read", "Lines parsed", "Lines unparsed", "Lines poured to bucket", "Lines whitelisted"}) + + keys := []string{"reads", "parsed", "unparsed", "pour", "whitelisted"} + + if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { + log.Warningf("while collecting acquis stats: %s", err) + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statalert.go b/cmd/crowdsec-cli/climetrics/statalert.go new file mode 100644 index 00000000000..942eceaa75c --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statalert.go @@ -0,0 +1,45 @@ +package climetrics + +import ( + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statAlert map[string]int + +func (s statAlert) Description() (string, string) { + return "Local API Alerts", + `Tracks the total number of past and present alerts for the installed scenarios.` +} + +func (s statAlert) Process(reason string, val int) { + s[reason] += val +} + +func (s statAlert) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Reason", "Count"}) + + numRows := 0 + + // TODO: sort keys + for scenario, hits := range s { + t.AppendRow(table.Row{ + scenario, + strconv.Itoa(hits), + }) + + numRows++ + } + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statappsecengine.go b/cmd/crowdsec-cli/climetrics/statappsecengine.go new file mode 100644 index 00000000000..d924375247f --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statappsecengine.go @@ -0,0 +1,41 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statAppsecEngine map[string]map[string]int + +func (s statAppsecEngine) Description() (string, string) { + return "Appsec Metrics", + `Measures the number of parsed and blocked requests by the AppSec Component.` +} + +func (s statAppsecEngine) Process(appsecEngine, metric string, val int) { + if _, ok := s[appsecEngine]; !ok { + s[appsecEngine] = 
make(map[string]int) + } + + s[appsecEngine][metric] += val +} + +func (s statAppsecEngine) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Appsec Engine", "Processed", "Blocked"}) + + keys := []string{"processed", "blocked"} + + if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { + log.Warningf("while collecting appsec stats: %s", err) + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statappsecrule.go b/cmd/crowdsec-cli/climetrics/statappsecrule.go new file mode 100644 index 00000000000..e06a7c2e2b3 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statappsecrule.go @@ -0,0 +1,48 @@ +package climetrics + +import ( + "fmt" + "io" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statAppsecRule map[string]map[string]map[string]int + +func (s statAppsecRule) Description() (string, string) { + return "Appsec Rule Metrics", + `Provides “per AppSec Component” information about the number of matches for loaded AppSec Rules.` +} + +func (s statAppsecRule) Process(appsecEngine, appsecRule string, metric string, val int) { + if _, ok := s[appsecEngine]; !ok { + s[appsecEngine] = make(map[string]map[string]int) + } + + if _, ok := s[appsecEngine][appsecRule]; !ok { + s[appsecEngine][appsecRule] = make(map[string]int) + } + + s[appsecEngine][appsecRule][metric] += val +} + +func (s statAppsecRule) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + // TODO: sort keys + for appsecEngine, appsecEngineRulesStats := range s { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Rule ID", "Triggered"}) + + keys := []string{"triggered"} + + if numRows, err := metricsToTable(t, appsecEngineRulesStats, keys, noUnit); err != nil { + log.Warningf("while collecting appsec rules stats: %s", err) + } else if numRows > 0 || showEmpty { + io.WriteString(out, fmt.Sprintf("Appsec '%s' Rules Metrics:\n", appsecEngine)) + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } + } +} diff --git a/cmd/crowdsec-cli/climetrics/statbouncer.go b/cmd/crowdsec-cli/climetrics/statbouncer.go new file mode 100644 index 00000000000..bc0da152d6d --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statbouncer.go @@ -0,0 +1,461 @@ +package climetrics + +import ( + "context" + "encoding/json" + "fmt" + "io" + "sort" + "strings" + "time" + + "github.com/jedib0t/go-pretty/v6/table" + "github.com/jedib0t/go-pretty/v6/text" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +// bouncerMetricItem represents unaggregated, denormalized metric data. +// Possibly not unique if a bouncer sent the same data multiple times. +type bouncerMetricItem struct { + collectedAt time.Time + bouncerName string + ipType string + origin string + name string + unit string + value float64 +} + +// aggregationOverTime is the first level of aggregation: we aggregate +// over time, then over ip type, then over origin. 
we only sum values +// for non-gauge metrics, and take the last value for gauge metrics. +type aggregationOverTime map[string]map[string]map[string]map[string]map[string]int64 + +func (a aggregationOverTime) add(bouncerName, origin, name, unit, ipType string, value float64, isGauge bool) { + if _, ok := a[bouncerName]; !ok { + a[bouncerName] = make(map[string]map[string]map[string]map[string]int64) + } + + if _, ok := a[bouncerName][origin]; !ok { + a[bouncerName][origin] = make(map[string]map[string]map[string]int64) + } + + if _, ok := a[bouncerName][origin][name]; !ok { + a[bouncerName][origin][name] = make(map[string]map[string]int64) + } + + if _, ok := a[bouncerName][origin][name][unit]; !ok { + a[bouncerName][origin][name][unit] = make(map[string]int64) + } + + if isGauge { + a[bouncerName][origin][name][unit][ipType] = int64(value) + } else { + a[bouncerName][origin][name][unit][ipType] += int64(value) + } +} + +// aggregationOverIPType is the second level of aggregation: data is summed +// regardless of the metrics type (gauge or not). This is used to display +// table rows, they won't differentiate ipv4 and ipv6 +type aggregationOverIPType map[string]map[string]map[string]map[string]int64 + +func (a aggregationOverIPType) add(bouncerName, origin, name, unit string, value int64) { + if _, ok := a[bouncerName]; !ok { + a[bouncerName] = make(map[string]map[string]map[string]int64) + } + + if _, ok := a[bouncerName][origin]; !ok { + a[bouncerName][origin] = make(map[string]map[string]int64) + } + + if _, ok := a[bouncerName][origin][name]; !ok { + a[bouncerName][origin][name] = make(map[string]int64) + } + + a[bouncerName][origin][name][unit] += value +} + +// aggregationOverOrigin is the third level of aggregation: these are +// the totals at the end of the table. Metrics without an origin will +// be added to the totals but not displayed in the rows, only in the footer. 
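+// As a worked example (bouncer name and values are illustrative): a bouncer
+// "fw" reporting dropped=10 packets over ipv4 and dropped=5 over ipv6, both
+// with origin CAPI, rolls up to
+//
+//	aggOverIPType["fw"]["CAPI"]["dropped"]["packet"] = 15
+//
+// and, summed across origins, to
+//
+//	aggOverOrigin["fw"]["dropped"]["packet"] = 15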
+type aggregationOverOrigin map[string]map[string]map[string]int64 + +func (a aggregationOverOrigin) add(bouncerName, name, unit string, value int64) { + if _, ok := a[bouncerName]; !ok { + a[bouncerName] = make(map[string]map[string]int64) + } + + if _, ok := a[bouncerName][name]; !ok { + a[bouncerName][name] = make(map[string]int64) + } + + a[bouncerName][name][unit] += value +} + +type statBouncer struct { + // oldest collection timestamp for each bouncer + oldestTS map[string]time.Time + // aggregate over ip type: always sum + // [bouncer][origin][name][unit]value + aggOverIPType aggregationOverIPType + // aggregate over origin: always sum + // [bouncer][name][unit]value + aggOverOrigin aggregationOverOrigin +} + +var knownPlurals = map[string]string{ + "byte": "bytes", + "packet": "packets", + "ip": "IPs", +} + +func (s *statBouncer) MarshalJSON() ([]byte, error) { + return json.Marshal(s.aggOverIPType) +} + +func (*statBouncer) Description() (string, string) { + return "Bouncer Metrics", + `Network traffic blocked by bouncers.` +} + +func logWarningOnce(warningsLogged map[string]bool, msg string) { + if _, ok := warningsLogged[msg]; !ok { + log.Warning(msg) + + warningsLogged[msg] = true + } +} + +// extractRawMetrics converts metrics from the database to a de-normalized, de-duplicated slice +// it returns the slice and the oldest timestamp for each bouncer +func (*statBouncer) extractRawMetrics(metrics []*ent.Metric) ([]bouncerMetricItem, map[string]time.Time) { + oldestTS := make(map[string]time.Time) + + // don't spam the user with the same warnings + warningsLogged := make(map[string]bool) + + // store raw metrics, de-duplicated in case some were sent multiple times + uniqueRaw := make(map[bouncerMetricItem]struct{}) + + for _, met := range metrics { + bouncerName := met.GeneratedBy + + var payload struct { + Metrics []models.DetailedMetrics `json:"metrics"` + } + + if err := json.Unmarshal([]byte(met.Payload), &payload); err != nil { + log.Warningf("while parsing metrics for %s: %s", bouncerName, err) + continue + } + + for _, m := range payload.Metrics { + // fields like timestamp, name, etc. 
are mandatory but we got pointers, so we check anyway + if m.Meta.UtcNowTimestamp == nil { + logWarningOnce(warningsLogged, "missing 'utc_now_timestamp' field in metrics reported by "+bouncerName) + continue + } + + collectedAt := time.Unix(*m.Meta.UtcNowTimestamp, 0).UTC() + + if oldestTS[bouncerName].IsZero() || collectedAt.Before(oldestTS[bouncerName]) { + oldestTS[bouncerName] = collectedAt + } + + for _, item := range m.Items { + valid := true + + if item.Name == nil { + logWarningOnce(warningsLogged, "missing 'name' field in metrics reported by "+bouncerName) + // no continue - keep checking the rest + valid = false + } + + if item.Unit == nil { + logWarningOnce(warningsLogged, "missing 'unit' field in metrics reported by "+bouncerName) + valid = false + } + + if item.Value == nil { + logWarningOnce(warningsLogged, "missing 'value' field in metrics reported by "+bouncerName) + valid = false + } + + if !valid { + continue + } + + rawMetric := bouncerMetricItem{ + collectedAt: collectedAt, + bouncerName: bouncerName, + ipType: item.Labels["ip_type"], + origin: item.Labels["origin"], + name: *item.Name, + unit: *item.Unit, + value: *item.Value, + } + + uniqueRaw[rawMetric] = struct{}{} + } + } + } + + // extract raw metric structs + keys := make([]bouncerMetricItem, 0, len(uniqueRaw)) + for key := range uniqueRaw { + keys = append(keys, key) + } + + // order them by timestamp + sort.Slice(keys, func(i, j int) bool { + return keys[i].collectedAt.Before(keys[j].collectedAt) + }) + + return keys, oldestTS +} + +func (s *statBouncer) Fetch(ctx context.Context, db *database.Client) error { + if db == nil { + return nil + } + + // query all bouncer metrics that have not been flushed + + metrics, err := db.Ent.Metric.Query(). + Where(metric.GeneratedTypeEQ(metric.GeneratedTypeRC)). 
+ All(ctx) + if err != nil { + return fmt.Errorf("unable to fetch metrics: %w", err) + } + + // de-normalize, de-duplicate metrics and keep the oldest timestamp for each bouncer + + rawMetrics, oldestTS := s.extractRawMetrics(metrics) + + s.oldestTS = oldestTS + aggOverTime := s.newAggregationOverTime(rawMetrics) + s.aggOverIPType = s.newAggregationOverIPType(aggOverTime) + s.aggOverOrigin = s.newAggregationOverOrigin(s.aggOverIPType) + + return nil +} + +// return true if the metric is a gauge and should not be aggregated +func (*statBouncer) isGauge(name string) bool { + return name == "active_decisions" || strings.HasSuffix(name, "_gauge") +} + +// formatMetricName returns the metric name to display in the table header +func (*statBouncer) formatMetricName(name string) string { + return strings.TrimSuffix(name, "_gauge") +} + +// formatMetricOrigin returns the origin to display in the table rows +// (for example, some users don't know what capi is) +func (*statBouncer) formatMetricOrigin(origin string) string { + switch origin { + case "CAPI": + return origin + " (community blocklist)" + case "cscli": + return origin + " (manual decisions)" + case "crowdsec": + return origin + " (security engine)" + default: + return origin + } +} + +func (s *statBouncer) newAggregationOverTime(rawMetrics []bouncerMetricItem) aggregationOverTime { + ret := aggregationOverTime{} + + for _, raw := range rawMetrics { + ret.add(raw.bouncerName, raw.origin, raw.name, raw.unit, raw.ipType, raw.value, s.isGauge(raw.name)) + } + + return ret +} + +func (*statBouncer) newAggregationOverIPType(aggMetrics aggregationOverTime) aggregationOverIPType { + ret := aggregationOverIPType{} + + for bouncerName := range aggMetrics { + for origin := range aggMetrics[bouncerName] { + for name := range aggMetrics[bouncerName][origin] { + for unit := range aggMetrics[bouncerName][origin][name] { + for ipType := range aggMetrics[bouncerName][origin][name][unit] { + value := aggMetrics[bouncerName][origin][name][unit][ipType] + ret.add(bouncerName, origin, name, unit, value) + } + } + } + } + } + + return ret +} + +func (*statBouncer) newAggregationOverOrigin(aggMetrics aggregationOverIPType) aggregationOverOrigin { + ret := aggregationOverOrigin{} + + for bouncerName := range aggMetrics { + for origin := range aggMetrics[bouncerName] { + for name := range aggMetrics[bouncerName][origin] { + for unit := range aggMetrics[bouncerName][origin][name] { + val := aggMetrics[bouncerName][origin][name][unit] + ret.add(bouncerName, name, unit, val) + } + } + } + } + + return ret +} + +// bouncerTable displays a table of metrics for a single bouncer +func (s *statBouncer) bouncerTable(out io.Writer, bouncerName string, wantColor string, noUnit bool) { + columns := make(map[string]map[string]struct{}) + + bouncerData, ok := s.aggOverOrigin[bouncerName] + if !ok { + // no metrics for this bouncer, skip. how did we get here ? 
+ // anyway we can't honor the "showEmpty" flag in this case, + // we don't even have the table headers + return + } + + for metricName, units := range bouncerData { + // build a map of the metric names and units, to display dynamic columns + columns[metricName] = make(map[string]struct{}) + for unit := range units { + columns[metricName][unit] = struct{}{} + } + } + + if len(columns) == 0 { + return + } + + t := cstable.New(out, wantColor).Writer + header1 := table.Row{"Origin"} + header2 := table.Row{""} + colNum := 1 + + colCfg := []table.ColumnConfig{{ + Number: colNum, + AlignHeader: text.AlignLeft, + Align: text.AlignLeft, + AlignFooter: text.AlignRight, + }} + + for _, name := range maptools.SortedKeys(columns) { + for _, unit := range maptools.SortedKeys(columns[name]) { + colNum += 1 + + header1 = append(header1, s.formatMetricName(name)) + + // we don't add "s" to random words + if plural, ok := knownPlurals[unit]; ok { + unit = plural + } + + header2 = append(header2, unit) + colCfg = append(colCfg, table.ColumnConfig{ + Number: colNum, + AlignHeader: text.AlignCenter, + Align: text.AlignRight, + AlignFooter: text.AlignRight, + }) + } + } + + t.AppendHeader(header1, table.RowConfig{AutoMerge: true}) + t.AppendHeader(header2) + + t.SetColumnConfigs(colCfg) + + numRows := 0 + + // sort all the ranges for stable output + + for _, origin := range maptools.SortedKeys(s.aggOverIPType[bouncerName]) { + if origin == "" { + // if the metric has no origin (i.e. processed bytes/packets) + // we don't display it in the table body but it still gets aggregated + // in the footer's totals + continue + } + + metrics := s.aggOverIPType[bouncerName][origin] + + row := table.Row{s.formatMetricOrigin(origin)} + + for _, name := range maptools.SortedKeys(columns) { + for _, unit := range maptools.SortedKeys(columns[name]) { + valStr := "-" + + if val, ok := metrics[name][unit]; ok { + valStr = formatNumber(val, !noUnit) + } + + row = append(row, valStr) + } + } + + t.AppendRow(row) + + numRows += 1 + } + + totals := s.aggOverOrigin[bouncerName] + + if numRows == 0 { + t.Style().Options.SeparateFooter = false + } + + footer := table.Row{"Total"} + + for _, name := range maptools.SortedKeys(columns) { + for _, unit := range maptools.SortedKeys(columns[name]) { + footer = append(footer, formatNumber(totals[name][unit], !noUnit)) + } + } + + t.AppendFooter(footer) + + title, _ := s.Description() + title = fmt.Sprintf("%s (%s)", title, bouncerName) + + if s.oldestTS != nil { + // if you change this to .Local() beware of tests + title = fmt.Sprintf("%s since %s", title, s.oldestTS[bouncerName].String()) + } + + // don't use SetTitle() because it draws the title inside table box + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + // empty line between tables + io.WriteString(out, "\n") +} + +// Table displays a table of metrics for each bouncer +func (s *statBouncer) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + found := false + + for _, bouncerName := range maptools.SortedKeys(s.aggOverOrigin) { + s.bouncerTable(out, bouncerName, wantColor, noUnit) + found = true + } + + if !found && showEmpty { + io.WriteString(out, "No bouncer metrics found.\n\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statbucket.go b/cmd/crowdsec-cli/climetrics/statbucket.go new file mode 100644 index 00000000000..1882fe21df1 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statbucket.go @@ -0,0 +1,42 @@ +package climetrics + +import ( + "io" +
"github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statBucket map[string]map[string]int + +func (s statBucket) Description() (string, string) { + return "Scenario Metrics", + `Measure events in different scenarios. Current count is the number of buckets during metrics collection. ` + + `Overflows are past event-producing buckets, while Expired are the ones that didn’t receive enough events to Overflow.` +} + +func (s statBucket) Process(bucket, metric string, val int) { + if _, ok := s[bucket]; !ok { + s[bucket] = make(map[string]int) + } + + s[bucket][metric] += val +} + +func (s statBucket) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Scenario", "Current Count", "Overflows", "Instantiated", "Poured", "Expired"}) + + keys := []string{"curr_count", "overflow", "instantiation", "pour", "underflow"} + + if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { + log.Warningf("while collecting scenario stats: %s", err) + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statdecision.go b/cmd/crowdsec-cli/climetrics/statdecision.go new file mode 100644 index 00000000000..b862f49ff12 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statdecision.go @@ -0,0 +1,60 @@ +package climetrics + +import ( + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statDecision map[string]map[string]map[string]int + +func (s statDecision) Description() (string, string) { + return "Local API Decisions", + `Provides information about all currently active decisions. 
` + + `Includes both local (crowdsec) and global decisions (CAPI), and lists subscriptions (lists).` +} + +func (s statDecision) Process(reason, origin, action string, val int) { + if _, ok := s[reason]; !ok { + s[reason] = make(map[string]map[string]int) + } + + if _, ok := s[reason][origin]; !ok { + s[reason][origin] = make(map[string]int) + } + + s[reason][origin][action] += val +} + +func (s statDecision) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Reason", "Origin", "Action", "Count"}) + + numRows := 0 + + // TODO: sort by reason, origin, action + for reason, origins := range s { + for origin, actions := range origins { + for action, hits := range actions { + t.AppendRow(table.Row{ + reason, + origin, + action, + strconv.Itoa(hits), + }) + + numRows++ + } + } + } + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statlapi.go b/cmd/crowdsec-cli/climetrics/statlapi.go new file mode 100644 index 00000000000..9559eacf0f4 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statlapi.go @@ -0,0 +1,56 @@ +package climetrics + +import ( + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statLapi map[string]map[string]int + +func (s statLapi) Description() (string, string) { + return "Local API Metrics", + `Monitors the requests made to local API routes.` +} + +func (s statLapi) Process(route, method string, val int) { + if _, ok := s[route]; !ok { + s[route] = make(map[string]int) + } + + s[route][method] += val +} + +func (s statLapi) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Route", "Method", "Hits"}) + + // unfortunately, we can't reuse metricsToTable as the structure is too different :/ + numRows := 0 + + for _, alabel := range maptools.SortedKeys(s) { + astats := s[alabel] + + for _, sl := range maptools.SortedKeys(astats) { + t.AppendRow(table.Row{ + alabel, + sl, + strconv.Itoa(astats[sl]), + }) + + numRows++ + } + } + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statlapibouncer.go b/cmd/crowdsec-cli/climetrics/statlapibouncer.go new file mode 100644 index 00000000000..5e5f63a79d3 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statlapibouncer.go @@ -0,0 +1,42 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statLapiBouncer map[string]map[string]map[string]int + +func (s statLapiBouncer) Description() (string, string) { + return "Local API Bouncers Metrics", + `Tracks total hits to remediation component related API routes.` +} + +func (s statLapiBouncer) Process(bouncer, route, method string, val int) { + if _, ok := s[bouncer]; !ok { + s[bouncer] = make(map[string]map[string]int) + } + + if _, ok := s[bouncer][route]; !ok { + s[bouncer][route] = make(map[string]int) + } + + s[bouncer][route][method] += val +} + +func (s statLapiBouncer) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, 
wantColor).Writer + t.AppendHeader(table.Row{"Bouncer", "Route", "Method", "Hits"}) + + numRows := lapiMetricsToTable(t, s) + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statlapidecision.go b/cmd/crowdsec-cli/climetrics/statlapidecision.go new file mode 100644 index 00000000000..44f0e8f4b87 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statlapidecision.go @@ -0,0 +1,64 @@ +package climetrics + +import ( + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statLapiDecision map[string]struct { + NonEmpty int + Empty int +} + +func (s statLapiDecision) Description() (string, string) { + return "Local API Bouncers Decisions", + `Tracks the number of empty/non-empty answers from LAPI to bouncers that are working in "live" mode.` +} + +func (s statLapiDecision) Process(bouncer, fam string, val int) { + if _, ok := s[bouncer]; !ok { + s[bouncer] = struct { + NonEmpty int + Empty int + }{} + } + + x := s[bouncer] + + switch fam { + case "cs_lapi_decisions_ko_total": + x.Empty += val + case "cs_lapi_decisions_ok_total": + x.NonEmpty += val + } + + s[bouncer] = x +} + +func (s statLapiDecision) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Bouncer", "Empty answers", "Non-empty answers"}) + + numRows := 0 + + for bouncer, hits := range s { + t.AppendRow(table.Row{ + bouncer, + strconv.Itoa(hits.Empty), + strconv.Itoa(hits.NonEmpty), + }) + + numRows++ + } + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statlapimachine.go b/cmd/crowdsec-cli/climetrics/statlapimachine.go new file mode 100644 index 00000000000..0e6693bea82 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statlapimachine.go @@ -0,0 +1,42 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statLapiMachine map[string]map[string]map[string]int + +func (s statLapiMachine) Description() (string, string) { + return "Local API Machines Metrics", + `Tracks the number of calls to the local API from each registered machine.` +} + +func (s statLapiMachine) Process(machine, route, method string, val int) { + if _, ok := s[machine]; !ok { + s[machine] = make(map[string]map[string]int) + } + + if _, ok := s[machine][route]; !ok { + s[machine][route] = make(map[string]int) + } + + s[machine][route][method] += val +} + +func (s statLapiMachine) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Machine", "Route", "Method", "Hits"}) + + numRows := lapiMetricsToTable(t, s) + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statparser.go b/cmd/crowdsec-cli/climetrics/statparser.go new file mode 100644 index 00000000000..520e68f9adf --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statparser.go @@ -0,0 +1,43 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + log 
"github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statParser map[string]map[string]int + +func (s statParser) Description() (string, string) { + return "Parser Metrics", + `Tracks the number of events processed by each parser and indicates success of failure. ` + + `Zero parsed lines means the parser(s) failed. ` + + `Non-zero unparsed lines are fine as crowdsec select relevant lines.` +} + +func (s statParser) Process(parser, metric string, val int) { + if _, ok := s[parser]; !ok { + s[parser] = make(map[string]int) + } + + s[parser][metric] += val +} + +func (s statParser) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Parsers", "Hits", "Parsed", "Unparsed"}) + + keys := []string{"hits", "parsed", "unparsed"} + + if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { + log.Warningf("while collecting parsers stats: %s", err) + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statstash.go b/cmd/crowdsec-cli/climetrics/statstash.go new file mode 100644 index 00000000000..2729de931a1 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statstash.go @@ -0,0 +1,59 @@ +package climetrics + +import ( + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statStash map[string]struct { + Type string + Count int +} + +func (s statStash) Description() (string, string) { + return "Parser Stash Metrics", + `Tracks the status of stashes that might be created by various parsers and scenarios.` +} + +func (s statStash) Process(name, mtype string, val int) { + s[name] = struct { + Type string + Count int + }{ + Type: mtype, + Count: val, + } +} + +func (s statStash) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Name", "Type", "Items"}) + + // unfortunately, we can't reuse metricsToTable as the structure is too different :/ + numRows := 0 + + for _, alabel := range maptools.SortedKeys(s) { + astats := s[alabel] + + t.AppendRow(table.Row{ + alabel, + astats.Type, + strconv.Itoa(astats.Count), + }) + + numRows++ + } + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statwhitelist.go b/cmd/crowdsec-cli/climetrics/statwhitelist.go new file mode 100644 index 00000000000..7f533b45b4b --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statwhitelist.go @@ -0,0 +1,43 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statWhitelist map[string]map[string]map[string]int + +func (s statWhitelist) Description() (string, string) { + return "Whitelist Metrics", + `Tracks the number of events processed and possibly whitelisted by each parser whitelist.` +} + +func (s statWhitelist) Process(whitelist, reason, metric string, val int) { + if _, ok := s[whitelist]; !ok { + s[whitelist] = make(map[string]map[string]int) + } + + if _, ok := s[whitelist][reason]; !ok { + 
s[whitelist][reason] = make(map[string]int) + } + + s[whitelist][reason][metric] += val +} + +func (s statWhitelist) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Whitelist", "Reason", "Hits", "Whitelisted"}) + + if numRows, err := wlMetricsToTable(t, s, noUnit); err != nil { + log.Warningf("while collecting whitelist stats: %s", err) + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/store.go b/cmd/crowdsec-cli/climetrics/store.go new file mode 100644 index 00000000000..55fab5dbd7f --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/store.go @@ -0,0 +1,271 @@ +package climetrics + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/prom2json" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/maptools" + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/pkg/database" +) + +type metricSection interface { + Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) + Description() (string, string) +} + +type metricStore map[string]metricSection + +func NewMetricStore() metricStore { + return metricStore{ + "acquisition": statAcquis{}, + "alerts": statAlert{}, + "bouncers": &statBouncer{}, + "appsec-engine": statAppsecEngine{}, + "appsec-rule": statAppsecRule{}, + "decisions": statDecision{}, + "lapi": statLapi{}, + "lapi-bouncer": statLapiBouncer{}, + "lapi-decisions": statLapiDecision{}, + "lapi-machine": statLapiMachine{}, + "parsers": statParser{}, + "scenarios": statBucket{}, + "stash": statStash{}, + "whitelists": statWhitelist{}, + } +} + +func (ms metricStore) Fetch(ctx context.Context, url string, db *database.Client) error { + if err := ms["bouncers"].(*statBouncer).Fetch(ctx, db); err != nil { + return err + } + + return ms.fetchPrometheusMetrics(url) +} + +func (ms metricStore) fetchPrometheusMetrics(url string) error { + mfChan := make(chan *dto.MetricFamily, 1024) + errChan := make(chan error, 1) + + // Start with the DefaultTransport for sane defaults. + transport := http.DefaultTransport.(*http.Transport).Clone() + // Conservatively disable HTTP keep-alives as this program will only + // ever need a single HTTP request. + transport.DisableKeepAlives = true + // Timeout early if the server doesn't even return the headers.
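+ // Note: prom2json.FetchMetricFamilies closes mfChan when it is done, so the + // drain loop below always terminates; the error, if any, is read from errChan + // only after the channel has been drained, and the goroutine never blocks on send.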
+ transport.ResponseHeaderTimeout = time.Minute + go func() { + defer trace.CatchPanic("crowdsec/ShowPrometheus") + + err := prom2json.FetchMetricFamilies(url, mfChan, transport) + if err != nil { + errChan <- fmt.Errorf("while fetching metrics: %w", err) + return + } + errChan <- nil + }() + + result := []*prom2json.Family{} + for mf := range mfChan { + result = append(result, prom2json.NewFamily(mf)) + } + + if err := <-errChan; err != nil { + return err + } + + log.Debugf("Finished reading metrics output, %d entries", len(result)) + ms.processPrometheusMetrics(result) + + return nil +} + +func (ms metricStore) processPrometheusMetrics(result []*prom2json.Family) { + mAcquis := ms["acquisition"].(statAcquis) + mAlert := ms["alerts"].(statAlert) + mAppsecEngine := ms["appsec-engine"].(statAppsecEngine) + mAppsecRule := ms["appsec-rule"].(statAppsecRule) + mDecision := ms["decisions"].(statDecision) + mLapi := ms["lapi"].(statLapi) + mLapiBouncer := ms["lapi-bouncer"].(statLapiBouncer) + mLapiDecision := ms["lapi-decisions"].(statLapiDecision) + mLapiMachine := ms["lapi-machine"].(statLapiMachine) + mParser := ms["parsers"].(statParser) + mBucket := ms["scenarios"].(statBucket) + mStash := ms["stash"].(statStash) + mWhitelist := ms["whitelists"].(statWhitelist) + + for idx, fam := range result { + if !strings.HasPrefix(fam.Name, "cs_") { + continue + } + + log.Tracef("round %d", idx) + + for _, m := range fam.Metrics { + metric, ok := m.(prom2json.Metric) + if !ok { + log.Debugf("failed to convert metric to prom2json.Metric") + continue + } + + name, ok := metric.Labels["name"] + if !ok { + log.Debugf("no name in Metric %v", metric.Labels) + } + + source, ok := metric.Labels["source"] + if !ok { + log.Debugf("no source in Metric %v for %s", metric.Labels, fam.Name) + } else { + if srctype, ok := metric.Labels["type"]; ok { + source = srctype + ":" + source + } + } + + value := metric.Value + machine := metric.Labels["machine"] + bouncer := metric.Labels["bouncer"] + + route := metric.Labels["route"] + method := metric.Labels["method"] + + reason := metric.Labels["reason"] + origin := metric.Labels["origin"] + action := metric.Labels["action"] + + appsecEngine := metric.Labels["appsec_engine"] + appsecRule := metric.Labels["rule_name"] + + mtype := metric.Labels["type"] + + fval, err := strconv.ParseFloat(value, 32) + if err != nil { + log.Errorf("Unexpected metric value %s: %s", value, err) + } + + ival := int(fval) + + switch fam.Name { + // + // buckets + // + case "cs_bucket_created_total": + mBucket.Process(name, "instantiation", ival) + case "cs_buckets": + mBucket.Process(name, "curr_count", ival) + case "cs_bucket_overflowed_total": + mBucket.Process(name, "overflow", ival) + case "cs_bucket_poured_total": + mBucket.Process(name, "pour", ival) + mAcquis.Process(source, "pour", ival) + case "cs_bucket_underflowed_total": + mBucket.Process(name, "underflow", ival) + // + // parsers + // + case "cs_parser_hits_total": + mAcquis.Process(source, "reads", ival) + case "cs_parser_hits_ok_total": + mAcquis.Process(source, "parsed", ival) + case "cs_parser_hits_ko_total": + mAcquis.Process(source, "unparsed", ival) + case "cs_node_hits_total": + mParser.Process(name, "hits", ival) + case "cs_node_hits_ok_total": + mParser.Process(name, "parsed", ival) + case "cs_node_hits_ko_total": + mParser.Process(name, "unparsed", ival) + // + // whitelists + // + case "cs_node_wl_hits_total": + mWhitelist.Process(name, reason, "hits", ival) + case "cs_node_wl_hits_ok_total": +
mWhitelist.Process(name, reason, "whitelisted", ival) + // track as well whitelisted lines at acquis level + mAcquis.Process(source, "whitelisted", ival) + // + // lapi + // + case "cs_lapi_route_requests_total": + mLapi.Process(route, method, ival) + case "cs_lapi_machine_requests_total": + mLapiMachine.Process(machine, route, method, ival) + case "cs_lapi_bouncer_requests_total": + mLapiBouncer.Process(bouncer, route, method, ival) + case "cs_lapi_decisions_ko_total", "cs_lapi_decisions_ok_total": + mLapiDecision.Process(bouncer, fam.Name, ival) + // + // decisions + // + case "cs_active_decisions": + mDecision.Process(reason, origin, action, ival) + case "cs_alerts": + mAlert.Process(reason, ival) + // + // stash + // + case "cs_cache_size": + mStash.Process(name, mtype, ival) + // + // appsec + // + case "cs_appsec_reqs_total": + mAppsecEngine.Process(appsecEngine, "processed", ival) + case "cs_appsec_block_total": + mAppsecEngine.Process(appsecEngine, "blocked", ival) + case "cs_appsec_rule_hits": + mAppsecRule.Process(appsecEngine, appsecRule, "triggered", ival) + default: + log.Debugf("unknown: %+v", fam.Name) + continue + } + } + } +} + +func (ms metricStore) Format(out io.Writer, wantColor string, sections []string, outputFormat string, noUnit bool) error { + // copy only the sections we want + want := map[string]metricSection{} + + // if explicitly asking for sections, we want to show empty tables + showEmpty := len(sections) > 0 + + // if no sections are specified, we want all of them + if len(sections) == 0 { + sections = maptools.SortedKeys(ms) + } + + for _, section := range sections { + want[section] = ms[section] + } + + switch outputFormat { + case "human": + for _, section := range maptools.SortedKeys(want) { + want[section].Table(out, wantColor, noUnit, showEmpty) + } + case "json": + x, err := json.MarshalIndent(want, "", " ") + if err != nil { + return fmt.Errorf("failed to serialize metrics: %w", err) + } + out.Write(x) + default: + return fmt.Errorf("output format '%s' not supported for this command", outputFormat) + } + + return nil +} diff --git a/cmd/crowdsec-cli/climetrics/table.go b/cmd/crowdsec-cli/climetrics/table.go new file mode 100644 index 00000000000..af13edce2f5 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/table.go @@ -0,0 +1,122 @@ +package climetrics + +import ( + "errors" + "sort" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/maptools" +) + +// ErrNilTable means a nil pointer was passed instead of a table instance. This is a programming error. 
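+// Callers can detect it with errors.Is, for example: +// +// if _, err := metricsToTable(t, stats, keys, noUnit); errors.Is(err, ErrNilTable) { +// // the table writer was not initialized +// }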
+var ErrNilTable = errors.New("nil table") + +func lapiMetricsToTable(t table.Writer, stats map[string]map[string]map[string]int) int { + // stats: machine -> route -> method -> count + // sort keys to keep consistent order when printing + machineKeys := []string{} + for k := range stats { + machineKeys = append(machineKeys, k) + } + + sort.Strings(machineKeys) + + numRows := 0 + + for _, machine := range machineKeys { + // oneRow: route -> method -> count + machineRow := stats[machine] + for routeName, route := range machineRow { + for methodName, count := range route { + row := table.Row{ + machine, + routeName, + methodName, + } + if count != 0 { + row = append(row, strconv.Itoa(count)) + } else { + row = append(row, "-") + } + + t.AppendRow(row) + + numRows++ + } + } + } + + return numRows +} + +func wlMetricsToTable(t table.Writer, stats map[string]map[string]map[string]int, noUnit bool) (int, error) { + if t == nil { + return 0, ErrNilTable + } + + numRows := 0 + + for _, name := range maptools.SortedKeys(stats) { + for _, reason := range maptools.SortedKeys(stats[name]) { + row := table.Row{ + name, + reason, + "-", + "-", + } + + for _, action := range maptools.SortedKeys(stats[name][reason]) { + value := stats[name][reason][action] + + switch action { + case "whitelisted": + row[3] = strconv.Itoa(value) + case "hits": + row[2] = strconv.Itoa(value) + default: + log.Debugf("unexpected counter '%s' for whitelists = %d", action, value) + } + } + + t.AppendRow(row) + + numRows++ + } + } + + return numRows, nil +} + +func metricsToTable(t table.Writer, stats map[string]map[string]int, keys []string, noUnit bool) (int, error) { + if t == nil { + return 0, ErrNilTable + } + + numRows := 0 + + for _, alabel := range maptools.SortedKeys(stats) { + astats, ok := stats[alabel] + if !ok { + continue + } + + row := table.Row{alabel} + + for _, sl := range keys { + if v, ok := astats[sl]; ok && v != 0 { + row = append(row, formatNumber(int64(v), !noUnit)) + } else { + row = append(row, "-") + } + } + + t.AppendRow(row) + + numRows++ + } + + return numRows, nil +} diff --git a/cmd/crowdsec-cli/clinotifications/notifications.go b/cmd/crowdsec-cli/clinotifications/notifications.go new file mode 100644 index 00000000000..5489faa37c8 --- /dev/null +++ b/cmd/crowdsec-cli/clinotifications/notifications.go @@ -0,0 +1,481 @@ +package clinotifications + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "io/fs" + "net/url" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + "time" + + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/tomb.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/csprofiles" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type NotificationsCfg struct { + Config csplugin.PluginConfig `json:"plugin_config"` + Profiles []*csconfig.ProfileCfg `json:"associated_profiles"` + ids []uint +} + +type configGetter func() *csconfig.Config + +type cliNotifications struct { + cfg configGetter +} + +func New(cfg configGetter) *cliNotifications { + return &cliNotifications{ + cfg: cfg, + } +} + +func (cli *cliNotifications) NewCommand() 
*cobra.Command { + cmd := &cobra.Command{ + Use: "notifications [action]", + Short: "Helper for notification plugin configuration", + Long: "List, inspect and test notification templates", + Args: cobra.MinimumNArgs(1), + Aliases: []string{"notifications", "notification"}, + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { + return err + } + if err := cfg.LoadAPIClient(); err != nil { + return fmt.Errorf("loading api client: %w", err) + } + + return require.Notifications(cfg) + }, + } + + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newInspectCmd()) + cmd.AddCommand(cli.newReinjectCmd()) + cmd.AddCommand(cli.newTestCmd()) + + return cmd +} + +func (cli *cliNotifications) getPluginConfigs() (map[string]csplugin.PluginConfig, error) { + cfg := cli.cfg() + pcfgs := map[string]csplugin.PluginConfig{} + wf := func(path string, info fs.FileInfo, err error) error { + if info == nil { + return fmt.Errorf("error while traversing directory %s: %w", path, err) + } + + name := filepath.Join(cfg.ConfigPaths.NotificationDir, info.Name()) // Avoid calling info.Name() twice + if (strings.HasSuffix(name, "yaml") || strings.HasSuffix(name, "yml")) && !(info.IsDir()) { + ts, err := csplugin.ParsePluginConfigFile(name) + if err != nil { + return fmt.Errorf("loading notification plugin configuration with %s: %w", name, err) + } + + for _, t := range ts { + csplugin.SetRequiredFields(&t) + pcfgs[t.Name] = t + } + } + + return nil + } + + if err := filepath.Walk(cfg.ConfigPaths.NotificationDir, wf); err != nil { + return nil, fmt.Errorf("while loading notification plugin configuration: %w", err) + } + + return pcfgs, nil +} + +func (cli *cliNotifications) getProfilesConfigs() (map[string]NotificationsCfg, error) { + cfg := cli.cfg() + // Now the tricky part: reconcile profiles and notification plugins + pcfgs, err := cli.getPluginConfigs() + if err != nil { + return nil, err + } + + ncfgs := map[string]NotificationsCfg{} + for _, pc := range pcfgs { + ncfgs[pc.Name] = NotificationsCfg{ + Config: pc, + } + } + + profiles, err := csprofiles.NewProfile(cfg.API.Server.Profiles) + if err != nil { + return nil, fmt.Errorf("while extracting profiles from configuration: %w", err) + } + + for profileID, profile := range profiles { + for _, notif := range profile.Cfg.Notifications { + pc, ok := pcfgs[notif] + if !ok { + return nil, fmt.Errorf("notification plugin '%s' does not exist", notif) + } + + tmp, ok := ncfgs[pc.Name] + if !ok { + return nil, fmt.Errorf("notification plugin '%s' does not exist", pc.Name) + } + + tmp.Profiles = append(tmp.Profiles, profile.Cfg) + tmp.ids = append(tmp.ids, uint(profileID)) + ncfgs[pc.Name] = tmp + } + } + + return ncfgs, nil +} + +func (cli *cliNotifications) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "list notification plugins", + Long: `list notification plugins and their status (active or not)`, + Example: `cscli notifications list`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + ncfgs, err := cli.getProfilesConfigs() + if err != nil { + return fmt.Errorf("can't build profiles configuration: %w", err) + } + + if cfg.Cscli.Output == "human" { + notificationListTable(color.Output, cfg.Cscli.Color, ncfgs) + } else if cfg.Cscli.Output == "json" { + x, err := json.MarshalIndent(ncfgs, "", " ") + if err != nil { + return fmt.Errorf("failed to serialize 
notification configuration: %w", err) + } + fmt.Printf("%s", string(x)) + } else if cfg.Cscli.Output == "raw" { + csvwriter := csv.NewWriter(os.Stdout) + err := csvwriter.Write([]string{"Name", "Type", "Profile name"}) + if err != nil { + return fmt.Errorf("failed to write raw header: %w", err) + } + for _, b := range ncfgs { + profilesList := []string{} + for _, p := range b.Profiles { + profilesList = append(profilesList, p.Name) + } + err := csvwriter.Write([]string{b.Config.Name, b.Config.Type, strings.Join(profilesList, ", ")}) + if err != nil { + return fmt.Errorf("failed to write raw content: %w", err) + } + } + csvwriter.Flush() + } + + return nil + }, + } + + return cmd +} + +func (cli *cliNotifications) newInspectCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "inspect", + Short: "Inspect a notification plugin", + Long: `Inspect a notification plugin and show its configuration`, + Example: `cscli notifications inspect <plugin_name>`, + Args: cobra.ExactArgs(1), + ValidArgsFunction: cli.notificationConfigFilter, + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + cfg := cli.cfg() + ncfgs, err := cli.getProfilesConfigs() + if err != nil { + return fmt.Errorf("can't build profiles configuration: %w", err) + } + ncfg, ok := ncfgs[args[0]] + if !ok { + return fmt.Errorf("plugin '%s' does not exist or is not active", args[0]) + } + if cfg.Cscli.Output == "human" || cfg.Cscli.Output == "raw" { + fmt.Printf(" - %15s: %15s\n", "Type", ncfg.Config.Type) + fmt.Printf(" - %15s: %15s\n", "Name", ncfg.Config.Name) + fmt.Printf(" - %15s: %15s\n", "Timeout", ncfg.Config.TimeOut) + fmt.Printf(" - %15s: %15s\n", "Format", ncfg.Config.Format) + for k, v := range ncfg.Config.Config { + fmt.Printf(" - %15s: %15v\n", k, v) + } + } else if cfg.Cscli.Output == "json" { + x, err := json.MarshalIndent(ncfg, "", " ") + if err != nil { + return fmt.Errorf("failed to serialize notification configuration: %w", err) + } + fmt.Printf("%s", string(x)) + } + + return nil + }, + } + + return cmd +} + +func (cli *cliNotifications) notificationConfigFilter(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + ncfgs, err := cli.getProfilesConfigs() + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + var ret []string + + for k := range ncfgs { + if strings.Contains(k, toComplete) && !slices.Contains(args, k) { + ret = append(ret, k) + } + } + + return ret, cobra.ShellCompDirectiveNoFileComp +} + +func (cli cliNotifications) newTestCmd() *cobra.Command { + var ( + pluginBroker csplugin.PluginBroker + pluginTomb tomb.Tomb + alertOverride string + ) + + cmd := &cobra.Command{ + Use: "test [plugin name]", + Short: "send a generic test alert to notification plugin", + Long: `send a generic test alert to a notification plugin even if it is not active in profiles`, + Example: `cscli notifications test [plugin_name]`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + ValidArgsFunction: cli.notificationConfigFilter, + PreRunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + cfg := cli.cfg() + pconfigs, err := cli.getPluginConfigs() + if err != nil { + return fmt.Errorf("can't build profiles configuration: %w", err) + } + pcfg, ok := pconfigs[args[0]] + if !ok { + return fmt.Errorf("plugin name: '%s' does not exist", args[0]) + } + // Create a single profile with plugin name as notification name + return pluginBroker.Init(ctx, cfg.PluginConfig, []*csconfig.ProfileCfg{ + { + Notifications: []string{ + pcfg.Name, + }, + },
+ }, cfg.ConfigPaths) + }, + RunE: func(_ *cobra.Command, _ []string) error { + pluginTomb.Go(func() error { + pluginBroker.Run(&pluginTomb) + return nil + }) + alert := &models.Alert{ + Capacity: ptr.Of(int32(0)), + Decisions: []*models.Decision{{ + Duration: ptr.Of("4h"), + Scope: ptr.Of("Ip"), + Value: ptr.Of("10.10.10.10"), + Type: ptr.Of("ban"), + Scenario: ptr.Of("test alert"), + Origin: ptr.Of(types.CscliOrigin), + }}, + Events: []*models.Event{}, + EventsCount: ptr.Of(int32(1)), + Leakspeed: ptr.Of("0"), + Message: ptr.Of("test alert"), + ScenarioHash: ptr.Of(""), + Scenario: ptr.Of("test alert"), + ScenarioVersion: ptr.Of(""), + Simulated: ptr.Of(false), + Source: &models.Source{ + AsName: "", + AsNumber: "", + Cn: "", + IP: "10.10.10.10", + Range: "", + Scope: ptr.Of("Ip"), + Value: ptr.Of("10.10.10.10"), + }, + StartAt: ptr.Of(time.Now().UTC().Format(time.RFC3339)), + StopAt: ptr.Of(time.Now().UTC().Format(time.RFC3339)), + CreatedAt: time.Now().UTC().Format(time.RFC3339), + } + if err := yaml.Unmarshal([]byte(alertOverride), alert); err != nil { + return fmt.Errorf("failed to parse alert override: %w", err) + } + + pluginBroker.PluginChannel <- csplugin.ProfileAlert{ + ProfileID: uint(0), + Alert: alert, + } + + // time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent + pluginTomb.Kill(errors.New("terminating")) + pluginTomb.Wait() + + return nil + }, + } + cmd.Flags().StringVarP(&alertOverride, "alert", "a", "", "JSON string used to override alert fields in the generic alert (see crowdsec/pkg/models/alert.go in the source tree for the full definition of the object)") + + return cmd +} + +func (cli *cliNotifications) newReinjectCmd() *cobra.Command { + var ( + alertOverride string + alert *models.Alert + ) + + cmd := &cobra.Command{ + Use: "reinject", + Short: "reinject an alert into profiles to trigger notifications", + Long: `reinject an alert into profiles to be evaluated by the filter and sent to matched notifications plugins`, + Example: ` +cscli notifications reinject +cscli notifications reinject -a '{"remediation": false,"scenario":"notification/test"}' +cscli notifications reinject -a '{"remediation": true,"scenario":"notification/test"}' +`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + PreRunE: func(cmd *cobra.Command, args []string) error { + var err error + alert, err = cli.fetchAlertFromArgString(cmd.Context(), args[0]) + if err != nil { + return err + } + + return nil + }, + RunE: func(cmd *cobra.Command, _ []string) error { + var ( + pluginBroker csplugin.PluginBroker + pluginTomb tomb.Tomb + ) + + ctx := cmd.Context() + cfg := cli.cfg() + + if alertOverride != "" { + if err := json.Unmarshal([]byte(alertOverride), alert); err != nil { + return fmt.Errorf("can't parse data in the alert flag: %w", err) + } + } + + err := pluginBroker.Init(ctx, cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths) + if err != nil { + return fmt.Errorf("can't initialize plugins: %w", err) + } + + pluginTomb.Go(func() error { + pluginBroker.Run(&pluginTomb) + return nil + }) + + profiles, err := csprofiles.NewProfile(cfg.API.Server.Profiles) + if err != nil { + return fmt.Errorf("cannot extract profiles from configuration: %w", err) + } + + for id, profile := range profiles { + _, matched, err := profile.EvaluateProfile(alert) + if err != nil { + return fmt.Errorf("can't evaluate profile %s: %w", profile.Cfg.Name, err) + } + if !matched { + log.Infof("The profile %s didn't match", profile.Cfg.Name) + continue + } + log.Infof("The 
profile %s matched, sending to its configured notification plugins", profile.Cfg.Name) + loop: + for { + select { + case pluginBroker.PluginChannel <- csplugin.ProfileAlert{ + ProfileID: uint(id), + Alert: alert, + }: + break loop + default: + time.Sleep(50 * time.Millisecond) + log.Info("sleeping\n") + } + } + + if profile.Cfg.OnSuccess == "break" { + log.Infof("The profile %s contains a 'on_success: break' so bailing out", profile.Cfg.Name) + break + } + } + // time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent + pluginTomb.Kill(errors.New("terminating")) + pluginTomb.Wait() + + return nil + }, + } + cmd.Flags().StringVarP(&alertOverride, "alert", "a", "", "JSON string used to override alert fields in the reinjected alert (see crowdsec/pkg/models/alert.go in the source tree for the full definition of the object)") + + return cmd +} + +func (cli *cliNotifications) fetchAlertFromArgString(ctx context.Context, toParse string) (*models.Alert, error) { + cfg := cli.cfg() + + id, err := strconv.Atoi(toParse) + if err != nil { + return nil, fmt.Errorf("bad alert id %s", toParse) + } + + apiURL, err := url.Parse(cfg.API.Client.Credentials.URL) + if err != nil { + return nil, fmt.Errorf("error parsing the URL of the API: %w", err) + } + + client, err := apiclient.NewClient(&apiclient.Config{ + MachineID: cfg.API.Client.Credentials.Login, + Password: strfmt.Password(cfg.API.Client.Credentials.Password), + URL: apiURL, + VersionPrefix: "v1", + }) + if err != nil { + return nil, fmt.Errorf("error creating the client for the API: %w", err) + } + + alert, _, err := client.Alerts.GetByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("can't find alert with id %d: %w", id, err) + } + + return alert, nil +} diff --git a/cmd/crowdsec-cli/clinotifications/notifications_table.go b/cmd/crowdsec-cli/clinotifications/notifications_table.go new file mode 100644 index 00000000000..0b6a3f58efc --- /dev/null +++ b/cmd/crowdsec-cli/clinotifications/notifications_table.go @@ -0,0 +1,46 @@ +package clinotifications + +import ( + "io" + "sort" + "strings" + + "github.com/jedib0t/go-pretty/v6/text" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/emoji" +) + +func notificationListTable(out io.Writer, wantColor string, ncfgs map[string]NotificationsCfg) { + t := cstable.NewLight(out, wantColor) + t.SetHeaders("Active", "Name", "Type", "Profile name") + t.SetHeaderAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft, text.AlignLeft) + t.SetAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft, text.AlignLeft) + + keys := make([]string, 0, len(ncfgs)) + for k := range ncfgs { + keys = append(keys, k) + } + + sort.Slice(keys, func(i, j int) bool { + return len(ncfgs[keys[i]].Profiles) > len(ncfgs[keys[j]].Profiles) + }) + + for _, k := range keys { + b := ncfgs[k] + profilesList := []string{} + + for _, p := range b.Profiles { + profilesList = append(profilesList, p.Name) + } + + active := emoji.CheckMark + if len(profilesList) == 0 { + active = emoji.Prohibited + } + + t.AddRow(active, b.Config.Name, b.Config.Type, strings.Join(profilesList, ", ")) + } + + t.Render() +} diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go new file mode 100644 index 00000000000..461215c3a39 --- /dev/null +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -0,0 +1,174 @@ +package clipapi + +import ( + "context" + "fmt" + "io" + "time" + + "github.com/fatih/color" + log "github.com/sirupsen/logrus" + 
"github.com/spf13/cobra" + "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiserver" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/database" +) + +type configGetter = func() *csconfig.Config + +type cliPapi struct { + cfg configGetter +} + +func New(cfg configGetter) *cliPapi { + return &cliPapi{ + cfg: cfg, + } +} + +func (cli *cliPapi) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "papi [action]", + Short: "Manage interaction with Polling API (PAPI)", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { + return err + } + if err := require.CAPI(cfg); err != nil { + return err + } + + return require.PAPI(cfg) + }, + } + + cmd.AddCommand(cli.newStatusCmd()) + cmd.AddCommand(cli.newSyncCmd()) + + return cmd +} + +func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Client) error { + cfg := cli.cfg() + + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + if err != nil { + return fmt.Errorf("unable to initialize API client: %w", err) + } + + papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel()) + if err != nil { + return fmt.Errorf("unable to initialize PAPI client: %w", err) + } + + perms, err := papi.GetPermissions(ctx) + if err != nil { + return fmt.Errorf("unable to get PAPI permissions: %w", err) + } + + lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey) + if err != nil { + lastTimestampStr = ptr.Of("never") + } + + // both can and did happen + if lastTimestampStr == nil || *lastTimestampStr == "0001-01-01T00:00:00Z" { + lastTimestampStr = ptr.Of("never") + } + + fmt.Fprint(out, "You can successfully interact with Polling API (PAPI)\n") + fmt.Fprintf(out, "Console plan: %s\n", perms.Plan) + fmt.Fprintf(out, "Last order received: %s\n", *lastTimestampStr) + fmt.Fprint(out, "PAPI subscriptions:\n") + + for _, sub := range perms.Categories { + fmt.Fprintf(out, " - %s\n", sub) + } + + return nil +} + +func (cli *cliPapi) newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "Get status of the Polling API", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + cfg := cli.cfg() + ctx := cmd.Context() + + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + return err + } + + return cli.Status(ctx, color.Output, db) + }, + } + + return cmd +} + +func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client) error { + cfg := cli.cfg() + t := tomb.Tomb{} + + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + if err != nil { + return fmt.Errorf("unable to initialize API client: %w", err) + } + + t.Go(func() error { return apic.Push(ctx) }) + + papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel()) + if err != nil { + return fmt.Errorf("unable to initialize PAPI client: %w", err) + } + + t.Go(papi.SyncDecisions) + + err = papi.PullOnce(time.Time{}, true) + if err != nil { + return fmt.Errorf("unable to sync decisions: %w", err) + } + + log.Infof("Sending acknowledgements to CAPI") + + apic.Shutdown() + 
papi.Shutdown() + t.Wait() + time.Sleep(5 * time.Second) // FIXME: the push done by apic.Push is run inside a sub goroutine, sleep to make sure it's done + + return nil +} + +func (cli *cliPapi) newSyncCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "sync", + Short: "Sync with the Polling API, pulling all non-expired orders for the instance", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + cfg := cli.cfg() + ctx := cmd.Context() + + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + return err + } + + return cli.sync(ctx, color.Output, db) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clisetup/setup.go b/cmd/crowdsec-cli/clisetup/setup.go new file mode 100644 index 00000000000..269cdfb78e9 --- /dev/null +++ b/cmd/crowdsec-cli/clisetup/setup.go @@ -0,0 +1,307 @@ +package clisetup + +import ( + "bytes" + "context" + "errors" + "fmt" + "os" + "os/exec" + + goccyyaml "github.com/goccy/go-yaml" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/setup" +) + +type configGetter func() *csconfig.Config + +type cliSetup struct { + cfg configGetter +} + +func New(cfg configGetter) *cliSetup { + return &cliSetup{ + cfg: cfg, + } +} + +func (cli *cliSetup) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "setup", + Short: "Tools to configure crowdsec", + Long: "Manage hub configuration and service detection", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + } + + cmd.AddCommand(cli.newDetectCmd()) + cmd.AddCommand(cli.newInstallHubCmd()) + cmd.AddCommand(cli.newDataSourcesCmd()) + cmd.AddCommand(cli.newValidateCmd()) + + return cmd +} + +type detectFlags struct { + detectConfigFile string + listSupportedServices bool + forcedUnits []string + forcedProcesses []string + forcedOSFamily string + forcedOSID string + forcedOSVersion string + skipServices []string + snubSystemd bool + outYaml bool +} + +func (f *detectFlags) bind(cmd *cobra.Command) { + defaultServiceDetect := csconfig.DefaultConfigPath("hub", "detect.yaml") + + flags := cmd.Flags() + flags.StringVar(&f.detectConfigFile, "detect-config", defaultServiceDetect, "path to service detection configuration") + flags.BoolVar(&f.listSupportedServices, "list-supported-services", false, "do not detect; only print supported services") + flags.StringSliceVar(&f.forcedUnits, "force-unit", nil, "force detection of a systemd unit (can be repeated)") + flags.StringSliceVar(&f.forcedProcesses, "force-process", nil, "force detection of a running process (can be repeated)") + flags.StringSliceVar(&f.skipServices, "skip-service", nil, "ignore a service, don't recommend hub/datasources (can be repeated)") + flags.StringVar(&f.forcedOSFamily, "force-os-family", "", "override OS.Family: one of linux, freebsd, windows or darwin") + flags.StringVar(&f.forcedOSID, "force-os-id", "", "override OS.ID=[debian | ubuntu | redhat...]") + flags.StringVar(&f.forcedOSVersion, "force-os-version", "", "override OS.RawVersion (of OS or Linux distribution)") + flags.BoolVar(&f.snubSystemd, "snub-systemd", false, "don't use systemd, even if available") + flags.BoolVar(&f.outYaml, "yaml", false, "output yaml, not json") +} + +func (cli *cliSetup) newDetectCmd() *cobra.Command { + f := detectFlags{} + + cmd := &cobra.Command{ + Use: "detect", + Short: "detect running services, 
generate a setup file", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + return cli.detect(f) + }, + } + + f.bind(cmd) + + return cmd +} + +func (cli *cliSetup) newInstallHubCmd() *cobra.Command { + var dryRun bool + + cmd := &cobra.Command{ + Use: "install-hub [setup_file] [flags]", + Short: "install items from a setup file", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.install(cmd.Context(), dryRun, args[0]) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&dryRun, "dry-run", false, "don't install anything; print out what would have been") + + return cmd +} + +func (cli *cliSetup) newDataSourcesCmd() *cobra.Command { + var toDir string + + cmd := &cobra.Command{ + Use: "datasources [setup_file] [flags]", + Short: "generate datasource (acquisition) configuration from a setup file", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.dataSources(args[0], toDir) + }, + } + + flags := cmd.Flags() + flags.StringVar(&toDir, "to-dir", "", "write the configuration to a directory, in multiple files") + + return cmd +} + +func (cli *cliSetup) newValidateCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "validate [setup_file]", + Short: "validate a setup file", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.validate(args[0]) + }, + } + + return cmd +} + +func (cli *cliSetup) detect(f detectFlags) error { + var ( + detectReader *os.File + err error + ) + + switch f.detectConfigFile { + case "-": + log.Tracef("Reading detection rules from stdin") + + detectReader = os.Stdin + default: + log.Tracef("Reading detection rules: %s", f.detectConfigFile) + + detectReader, err = os.Open(f.detectConfigFile) + if err != nil { + return err + } + } + + if !f.snubSystemd { + _, err = exec.LookPath("systemctl") + if err != nil { + log.Debug("systemctl not available: snubbing systemd") + + f.snubSystemd = true + } + } + + if f.forcedOSFamily == "" && f.forcedOSID != "" { + log.Debug("force-os-id is set: force-os-family defaults to 'linux'") + + f.forcedOSFamily = "linux" + } + + if f.listSupportedServices { + supported, err := setup.ListSupported(detectReader) + if err != nil { + return err + } + + for _, svc := range supported { + fmt.Println(svc) + } + + return nil + } + + opts := setup.DetectOptions{ + ForcedUnits: f.forcedUnits, + ForcedProcesses: f.forcedProcesses, + ForcedOS: setup.ExprOS{ + Family: f.forcedOSFamily, + ID: f.forcedOSID, + RawVersion: f.forcedOSVersion, + }, + SkipServices: f.skipServices, + SnubSystemd: f.snubSystemd, + } + + hubSetup, err := setup.Detect(detectReader, opts) + if err != nil { + return fmt.Errorf("detecting services: %w", err) + } + + setup, err := setupAsString(hubSetup, f.outYaml) + if err != nil { + return err + } + + fmt.Println(setup) + + return nil +} + +func setupAsString(cs setup.Setup, outYaml bool) (string, error) { + var ( + ret []byte + err error + ) + + wrap := func(err error) error { + return fmt.Errorf("while serializing setup: %w", err) + } + + indentLevel := 2 + buf := &bytes.Buffer{} + enc := yaml.NewEncoder(buf) + enc.SetIndent(indentLevel) + + if err = enc.Encode(cs); err != nil { + return "", wrap(err) + } + + if err = enc.Close(); err != nil { + return "", wrap(err) + } + + ret = buf.Bytes() + + if !outYaml { + // take a general approach to output json, so we avoid the + // double tags in the 
structures and can use go-yaml features + // missing from the json package + ret, err = goccyyaml.YAMLToJSON(ret) + if err != nil { + return "", wrap(err) + } + } + + return string(ret), nil +} + +func (cli *cliSetup) dataSources(fromFile string, toDir string) error { + input, err := os.ReadFile(fromFile) + if err != nil { + return fmt.Errorf("while reading setup file: %w", err) + } + + output, err := setup.DataSources(input, toDir) + if err != nil { + return err + } + + if toDir == "" { + fmt.Println(output) + } + + return nil +} + +func (cli *cliSetup) install(ctx context.Context, dryRun bool, fromFile string) error { + input, err := os.ReadFile(fromFile) + if err != nil { + return fmt.Errorf("while reading file %s: %w", fromFile, err) + } + + cfg := cli.cfg() + + hub, err := require.Hub(cfg, require.RemoteHub(ctx, cfg), log.StandardLogger()) + if err != nil { + return err + } + + return setup.InstallHubItems(ctx, hub, input, dryRun) +} + +func (cli *cliSetup) validate(fromFile string) error { + input, err := os.ReadFile(fromFile) + if err != nil { + return fmt.Errorf("while reading setup file: %w", err) + } + + if err = setup.Validate(input); err != nil { + fmt.Printf("%v\n", err) + return errors.New("invalid setup file") + } + + return nil +} diff --git a/cmd/crowdsec-cli/clisimulation/simulation.go b/cmd/crowdsec-cli/clisimulation/simulation.go new file mode 100644 index 00000000000..8136aa213c3 --- /dev/null +++ b/cmd/crowdsec-cli/clisimulation/simulation.go @@ -0,0 +1,286 @@ +package clisimulation + +import ( + "errors" + "fmt" + "os" + "slices" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +type configGetter func() *csconfig.Config + +type cliSimulation struct { + cfg configGetter +} + +func New(cfg configGetter) *cliSimulation { + return &cliSimulation{ + cfg: cfg, + } +} + +func (cli *cliSimulation) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "simulation [command]", + Short: "Manage simulation status of scenarios", + Example: `cscli simulation status +cscli simulation enable crowdsecurity/ssh-bf +cscli simulation disable crowdsecurity/ssh-bf`, + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + if err := cli.cfg().LoadSimulation(); err != nil { + return err + } + if cli.cfg().Cscli.SimulationConfig == nil { + return errors.New("no simulation configured") + } + + return nil + }, + PersistentPostRun: func(cmd *cobra.Command, _ []string) { + if cmd.Name() != "status" { + log.Info(reload.Message) + } + }, + } + cmd.Flags().SortFlags = false + cmd.PersistentFlags().SortFlags = false + + cmd.AddCommand(cli.newEnableCmd()) + cmd.AddCommand(cli.newDisableCmd()) + cmd.AddCommand(cli.newStatusCmd()) + + return cmd +} + +func (cli *cliSimulation) newEnableCmd() *cobra.Command { + var forceGlobalSimulation bool + + cmd := &cobra.Command{ + Use: "enable [scenario] [-global]", + Short: "Enable the simulation, globally or on specified scenarios", + Example: `cscli simulation enable`, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) + if err != nil { + return err + } + + if len(args) > 0 { + for _, scenario := range args { + item := hub.GetItem(cwhub.SCENARIOS, scenario) + if item == nil { +
log.Errorf("'%s' doesn't exist or is not a scenario", scenario) + continue + } + if !item.State.Installed { + log.Warningf("'%s' isn't enabled", scenario) + } + isExcluded := slices.Contains(cli.cfg().Cscli.SimulationConfig.Exclusions, scenario) + if *cli.cfg().Cscli.SimulationConfig.Simulation && !isExcluded { + log.Warning("global simulation is already enabled") + continue + } + if !*cli.cfg().Cscli.SimulationConfig.Simulation && isExcluded { + log.Warningf("simulation for '%s' already enabled", scenario) + continue + } + if *cli.cfg().Cscli.SimulationConfig.Simulation && isExcluded { + cli.removeFromExclusion(scenario) + log.Printf("simulation enabled for '%s'", scenario) + continue + } + cli.addToExclusion(scenario) + log.Printf("simulation mode for '%s' enabled", scenario) + } + if err := cli.dumpSimulationFile(); err != nil { + return fmt.Errorf("simulation enable: %w", err) + } + } else if forceGlobalSimulation { + if err := cli.enableGlobalSimulation(); err != nil { + return fmt.Errorf("unable to enable global simulation mode: %w", err) + } + } else { + _ = cmd.Help() + } + + return nil + }, + } + cmd.Flags().BoolVarP(&forceGlobalSimulation, "global", "g", false, "Enable global simulation (reverse mode)") + + return cmd +} + +func (cli *cliSimulation) newDisableCmd() *cobra.Command { + var forceGlobalSimulation bool + + cmd := &cobra.Command{ + Use: "disable [scenario]", + Short: "Disable the simulation mode. Disable only specified scenarios", + Example: `cscli simulation disable`, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + for _, scenario := range args { + isExcluded := slices.Contains(cli.cfg().Cscli.SimulationConfig.Exclusions, scenario) + if !*cli.cfg().Cscli.SimulationConfig.Simulation && !isExcluded { + log.Warningf("%s isn't in simulation mode", scenario) + continue + } + if !*cli.cfg().Cscli.SimulationConfig.Simulation && isExcluded { + cli.removeFromExclusion(scenario) + log.Printf("simulation mode for '%s' disabled", scenario) + continue + } + if isExcluded { + log.Warningf("simulation mode is enabled but is already disable for '%s'", scenario) + continue + } + cli.addToExclusion(scenario) + log.Printf("simulation mode for '%s' disabled", scenario) + } + if err := cli.dumpSimulationFile(); err != nil { + return fmt.Errorf("simulation disable: %w", err) + } + } else if forceGlobalSimulation { + if err := cli.disableGlobalSimulation(); err != nil { + return fmt.Errorf("unable to disable global simulation mode: %w", err) + } + } else { + _ = cmd.Help() + } + + return nil + }, + } + cmd.Flags().BoolVarP(&forceGlobalSimulation, "global", "g", false, "Disable global simulation (reverse mode)") + + return cmd +} + +func (cli *cliSimulation) newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "Show simulation mode status", + Example: `cscli simulation status`, + DisableAutoGenTag: true, + Run: func(_ *cobra.Command, _ []string) { + cli.status() + }, + PersistentPostRun: func(cmd *cobra.Command, args []string) { + }, + } + + return cmd +} + +func (cli *cliSimulation) addToExclusion(name string) { + cfg := cli.cfg() + cfg.Cscli.SimulationConfig.Exclusions = append(cfg.Cscli.SimulationConfig.Exclusions, name) +} + +func (cli *cliSimulation) removeFromExclusion(name string) { + cfg := cli.cfg() + index := slices.Index(cfg.Cscli.SimulationConfig.Exclusions, name) + + // Remove element from the slice + cfg.Cscli.SimulationConfig.Exclusions[index] = 
cfg.Cscli.SimulationConfig.Exclusions[len(cfg.Cscli.SimulationConfig.Exclusions)-1] + cfg.Cscli.SimulationConfig.Exclusions[len(cfg.Cscli.SimulationConfig.Exclusions)-1] = "" + cfg.Cscli.SimulationConfig.Exclusions = cfg.Cscli.SimulationConfig.Exclusions[:len(cfg.Cscli.SimulationConfig.Exclusions)-1] +} + +func (cli *cliSimulation) enableGlobalSimulation() error { + cfg := cli.cfg() + cfg.Cscli.SimulationConfig.Simulation = new(bool) + *cfg.Cscli.SimulationConfig.Simulation = true + cfg.Cscli.SimulationConfig.Exclusions = []string{} + + if err := cli.dumpSimulationFile(); err != nil { + return fmt.Errorf("unable to dump simulation file: %w", err) + } + + log.Printf("global simulation: enabled") + + return nil +} + +func (cli *cliSimulation) dumpSimulationFile() error { + cfg := cli.cfg() + + newConfigSim, err := yaml.Marshal(cfg.Cscli.SimulationConfig) + if err != nil { + return fmt.Errorf("unable to serialize simulation configuration: %w", err) + } + + err = os.WriteFile(cfg.ConfigPaths.SimulationFilePath, newConfigSim, 0o644) + if err != nil { + return fmt.Errorf("write simulation config in '%s' failed: %w", cfg.ConfigPaths.SimulationFilePath, err) + } + + log.Debugf("updated simulation file %s", cfg.ConfigPaths.SimulationFilePath) + + return nil +} + +func (cli *cliSimulation) disableGlobalSimulation() error { + cfg := cli.cfg() + cfg.Cscli.SimulationConfig.Simulation = new(bool) + *cfg.Cscli.SimulationConfig.Simulation = false + + cfg.Cscli.SimulationConfig.Exclusions = []string{} + + newConfigSim, err := yaml.Marshal(cfg.Cscli.SimulationConfig) + if err != nil { + return fmt.Errorf("unable to serialize new simulation configuration: %w", err) + } + + err = os.WriteFile(cfg.ConfigPaths.SimulationFilePath, newConfigSim, 0o644) + if err != nil { + return fmt.Errorf("unable to write new simulation config in '%s': %w", cfg.ConfigPaths.SimulationFilePath, err) + } + + log.Printf("global simulation: disabled") + + return nil +} + +func (cli *cliSimulation) status() { + cfg := cli.cfg() + if cfg.Cscli.SimulationConfig == nil { + log.Printf("global simulation: disabled (configuration file is missing)") + return + } + + if *cfg.Cscli.SimulationConfig.Simulation { + log.Println("global simulation: enabled") + + if len(cfg.Cscli.SimulationConfig.Exclusions) > 0 { + log.Println("Scenarios not in simulation mode:") + + for _, scenario := range cfg.Cscli.SimulationConfig.Exclusions { + log.Printf(" - %s", scenario) + } + } + } else { + log.Println("global simulation: disabled") + + if len(cfg.Cscli.SimulationConfig.Exclusions) > 0 { + log.Println("Scenarios in simulation mode:") + + for _, scenario := range cfg.Cscli.SimulationConfig.Exclusions { + log.Printf(" - %s", scenario) + } + } + } +} diff --git a/cmd/crowdsec-cli/clisupport/support.go b/cmd/crowdsec-cli/clisupport/support.go new file mode 100644 index 00000000000..4474f5c8f11 --- /dev/null +++ b/cmd/crowdsec-cli/clisupport/support.go @@ -0,0 +1,642 @@ +package clisupport + +import ( + "archive/zip" + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/blackfireio/osinfo" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clibouncer" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clicapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihub" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clilapi" + 
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climachine" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climetrics" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clipapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/fflag" +) + +const ( + SUPPORT_METRICS_DIR = "metrics/" + SUPPORT_VERSION_PATH = "version.txt" + SUPPORT_FEATURES_PATH = "features.txt" + SUPPORT_OS_INFO_PATH = "osinfo.txt" + SUPPORT_HUB = "hub.txt" + SUPPORT_BOUNCERS_PATH = "lapi/bouncers.txt" + SUPPORT_AGENTS_PATH = "lapi/agents.txt" + SUPPORT_CROWDSEC_CONFIG_PATH = "config/crowdsec.yaml" + SUPPORT_LAPI_STATUS_PATH = "lapi_status.txt" + SUPPORT_CAPI_STATUS_PATH = "capi_status.txt" + SUPPORT_PAPI_STATUS_PATH = "papi_status.txt" + SUPPORT_ACQUISITION_DIR = "config/acquis/" + SUPPORT_CROWDSEC_PROFILE_PATH = "config/profiles.yaml" + SUPPORT_CRASH_DIR = "crash/" + SUPPORT_LOG_DIR = "log/" + SUPPORT_PPROF_DIR = "pprof/" +) + +// StringHook collects log entries in a string +type StringHook struct { + LogBuilder strings.Builder + LogLevels []log.Level +} + +func (hook *StringHook) Levels() []log.Level { + return hook.LogLevels +} + +func (hook *StringHook) Fire(entry *log.Entry) error { + logEntry, err := entry.String() + if err != nil { + return err + } + + hook.LogBuilder.WriteString(logEntry) + + return nil +} + +// from https://github.com/acarl005/stripansi +var reStripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") + +func stripAnsiString(str string) string { + // the byte version doesn't strip correctly + return reStripAnsi.ReplaceAllString(str, "") +} + +func (cli *cliSupport) dumpMetrics(ctx context.Context, db *database.Client, zw *zip.Writer) error { + log.Info("Collecting prometheus metrics") + + cfg := cli.cfg() + + if cfg.Cscli.PrometheusUrl == "" { + log.Warn("can't collect metrics: prometheus_uri is not set") + } + + humanMetrics := new(bytes.Buffer) + + ms := climetrics.NewMetricStore() + + if err := ms.Fetch(ctx, cfg.Cscli.PrometheusUrl, db); err != nil { + return err + } + + if err := ms.Format(humanMetrics, cfg.Cscli.Color, nil, "human", false); err != nil { + return fmt.Errorf("could not format prometheus metrics: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, cfg.Cscli.PrometheusUrl, nil) + if err != nil { + return fmt.Errorf("could not create request to prometheus endpoint: %w", err) + } + + client := &http.Client{} + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("could not get metrics from prometheus endpoint: %w", err) + } + + defer resp.Body.Close() + + cli.writeToZip(zw, SUPPORT_METRICS_DIR+"metrics.prometheus", time.Now(), resp.Body) + + stripped := stripAnsiString(humanMetrics.String()) + + cli.writeToZip(zw, SUPPORT_METRICS_DIR+"metrics.human", time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpVersion(zw *zip.Writer) { + log.Info("Collecting version") + + cli.writeToZip(zw, SUPPORT_VERSION_PATH, time.Now(), strings.NewReader(cwversion.FullString())) +} + +func (cli *cliSupport) dumpFeatures(zw *zip.Writer) { + log.Info("Collecting feature flags") + + w := new(bytes.Buffer) + for _, k := range fflag.Crowdsec.GetEnabledFeatures() 
{ + fmt.Fprintln(w, k) + } + + cli.writeToZip(zw, SUPPORT_FEATURES_PATH, time.Now(), w) +} + +func (cli *cliSupport) dumpOSInfo(zw *zip.Writer) error { + log.Info("Collecting OS info") + + info, err := osinfo.GetOSInfo() + if err != nil { + return err + } + + w := new(bytes.Buffer) + fmt.Fprintf(w, "Architecture: %s\n", info.Architecture) + fmt.Fprintf(w, "Family: %s\n", info.Family) + fmt.Fprintf(w, "ID: %s\n", info.ID) + fmt.Fprintf(w, "Name: %s\n", info.Name) + fmt.Fprintf(w, "Codename: %s\n", info.Codename) + fmt.Fprintf(w, "Version: %s\n", info.Version) + fmt.Fprintf(w, "Build: %s\n", info.Build) + + cli.writeToZip(zw, SUPPORT_OS_INFO_PATH, time.Now(), w) + + return nil +} + +func (cli *cliSupport) dumpHubItems(zw *zip.Writer, hub *cwhub.Hub) error { + log.Infof("Collecting hub") + + if hub == nil { + return errors.New("no hub connection") + } + + out := new(bytes.Buffer) + ch := clihub.New(cli.cfg) + + if err := ch.List(out, hub, false); err != nil { + return err + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_HUB, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *database.Client) error { + log.Info("Collecting bouncers") + + if db == nil { + return errors.New("no database connection") + } + + out := new(bytes.Buffer) + cb := clibouncer.New(cli.cfg) + + if err := cb.List(ctx, out, db); err != nil { + return err + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_BOUNCERS_PATH, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpAgents(ctx context.Context, zw *zip.Writer, db *database.Client) error { + log.Info("Collecting agents") + + if db == nil { + return errors.New("no database connection") + } + + out := new(bytes.Buffer) + cm := climachine.New(cli.cfg) + + if err := cm.List(ctx, out, db); err != nil { + return err + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_AGENTS_PATH, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpLAPIStatus(ctx context.Context, zw *zip.Writer, hub *cwhub.Hub) error { + log.Info("Collecting LAPI status") + + out := new(bytes.Buffer) + cl := clilapi.New(cli.cfg) + + err := cl.Status(ctx, out, hub) + if err != nil { + fmt.Fprintf(out, "%s\n", err) + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_LAPI_STATUS_PATH, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpCAPIStatus(ctx context.Context, zw *zip.Writer, hub *cwhub.Hub) error { + log.Info("Collecting CAPI status") + + out := new(bytes.Buffer) + cc := clicapi.New(cli.cfg) + + err := cc.Status(ctx, out, hub) + if err != nil { + fmt.Fprintf(out, "%s\n", err) + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_CAPI_STATUS_PATH, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpPAPIStatus(ctx context.Context, zw *zip.Writer, db *database.Client) error { + log.Info("Collecting PAPI status") + + out := new(bytes.Buffer) + cp := clipapi.New(cli.cfg) + + err := cp.Status(ctx, out, db) + if err != nil { + fmt.Fprintf(out, "%s\n", err) + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_PAPI_STATUS_PATH, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpConfigYAML(zw *zip.Writer) error { + log.Info("Collecting crowdsec config") + + cfg := cli.cfg() + + 
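// read the raw config file and redact the values of password:, user: and host: keys before archiving +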
config, err := os.ReadFile(*cfg.FilePath) + if err != nil { + return fmt.Errorf("could not read config file: %w", err) + } + + r := regexp.MustCompile(`(\s+password:|\s+user:|\s+host:)\s+.*`) + + redacted := r.ReplaceAll(config, []byte("$1 ****REDACTED****")) + + cli.writeToZip(zw, SUPPORT_CROWDSEC_CONFIG_PATH, time.Now(), bytes.NewReader(redacted)) + + return nil +} + +func (cli *cliSupport) dumpPprof(ctx context.Context, zw *zip.Writer, prometheusCfg csconfig.PrometheusCfg, endpoint string) error { + log.Infof("Collecting pprof/%s data", endpoint) + + ctx, cancel := context.WithTimeout(ctx, 120*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + fmt.Sprintf( + "http://%s/debug/pprof/%s?debug=1", + net.JoinHostPort( + prometheusCfg.ListenAddr, + strconv.Itoa(prometheusCfg.ListenPort), + ), + endpoint, + ), + nil, + ) + if err != nil { + return fmt.Errorf("could not create request to pprof endpoint: %w", err) + } + + client := &http.Client{} + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("could not get pprof data from endpoint: %w", err) + } + + defer resp.Body.Close() + + cli.writeToZip(zw, SUPPORT_PPROF_DIR+endpoint+".pprof", time.Now(), resp.Body) + + return nil +} + +func (cli *cliSupport) dumpProfiles(zw *zip.Writer) { + log.Info("Collecting crowdsec profile") + + cfg := cli.cfg() + cli.writeFileToZip(zw, SUPPORT_CROWDSEC_PROFILE_PATH, cfg.API.Server.ProfilesPath) +} + +func (cli *cliSupport) dumpAcquisitionConfig(zw *zip.Writer) { + log.Info("Collecting acquisition config") + + cfg := cli.cfg() + + for _, filename := range cfg.Crowdsec.AcquisitionFiles { + fname := strings.ReplaceAll(filename, string(filepath.Separator), "___") + cli.writeFileToZip(zw, SUPPORT_ACQUISITION_DIR+fname, filename) + } +} + +func (cli *cliSupport) dumpLogs(zw *zip.Writer) error { + log.Info("Collecting CrowdSec logs") + + cfg := cli.cfg() + + logDir := cfg.Common.LogDir + + logFiles, err := filepath.Glob(filepath.Join(logDir, "crowdsec*.log")) + if err != nil { + return fmt.Errorf("could not list log files: %w", err) + } + + for _, filename := range logFiles { + cli.writeFileToZip(zw, SUPPORT_LOG_DIR+filepath.Base(filename), filename) + } + + return nil +} + +func (cli *cliSupport) dumpCrash(zw *zip.Writer) error { + log.Info("Collecting crash dumps") + + traceFiles, err := trace.List() + if err != nil { + return fmt.Errorf("could not list crash dumps: %w", err) + } + + for _, filename := range traceFiles { + cli.writeFileToZip(zw, SUPPORT_CRASH_DIR+filepath.Base(filename), filename) + } + + return nil +} + +type configGetter func() *csconfig.Config + +type cliSupport struct { + cfg configGetter +} + +func New(cfg configGetter) *cliSupport { + return &cliSupport{ + cfg: cfg, + } +} + +func (cli *cliSupport) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "support [action]", + Short: "Provide commands to help during support", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + } + + cmd.AddCommand(cli.NewDumpCmd()) + + return cmd +} + +// writeToZip adds a file to the zip archive, from a reader +func (cli *cliSupport) writeToZip(zipWriter *zip.Writer, filename string, mtime time.Time, reader io.Reader) { + header := &zip.FileHeader{ + Name: filename, + Method: zip.Deflate, + Modified: mtime, + } + + fw, err := zipWriter.CreateHeader(header) + if err != nil { + log.Errorf("could not add zip entry for %s: %s", filename, err) + return + } + + _, err = io.Copy(fw, reader) + if err != nil { + log.Errorf("could not write 
zip entry for %s": %s", filename, err) + } +} + +// writeFileToZip adds a file to the zip archive, from a file, and retains the mtime +func (cli *cliSupport) writeFileToZip(zw *zip.Writer, filename string, fromFile string) { + mtime := time.Now() + + fi, err := os.Stat(fromFile) + if err == nil { + mtime = fi.ModTime() + } + + fin, err := os.Open(fromFile) + if err != nil { + log.Errorf("could not open file %s: %s", fromFile, err) + return + } + defer fin.Close() + + cli.writeToZip(zw, filename, mtime, fin) +} + +func (cli *cliSupport) dump(ctx context.Context, outFile string) error { + var skipCAPI, skipLAPI, skipAgent bool + + collector := &StringHook{ + LogLevels: log.AllLevels, + } + log.AddHook(collector) + + cfg := cli.cfg() + + if outFile == "" { + outFile = filepath.Join(os.TempDir(), "crowdsec-support.zip") + } + + w := bytes.NewBuffer(nil) + zipWriter := zip.NewWriter(w) + + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + log.Warn(err) + } + + if err = cfg.LoadAPIServer(true); err != nil { + log.Warn("could not load LAPI, skipping CAPI check") + + skipCAPI = true + } + + if err = cfg.LoadCrowdsec(); err != nil { + log.Warn("could not load agent config, skipping crowdsec config check") + + skipAgent = true + } + + hub, err := require.Hub(cfg, nil, nil) + if err != nil { + log.Warn("Could not initialize the hub (LAPI-only machine?). Hub-related information will not be collected") + // XXX: lapi status check requires scenarios, will return an error + } + + if cfg.API.Client == nil || cfg.API.Client.Credentials == nil { + log.Warn("no agent credentials found, skipping LAPI connectivity check") + + skipLAPI = true + } + + if cfg.API.Server == nil || cfg.API.Server.OnlineClient == nil || cfg.API.Server.OnlineClient.Credentials == nil { + log.Warn("no CAPI credentials found, skipping CAPI connectivity check") + + skipCAPI = true + } + + if err = cli.dumpMetrics(ctx, db, zipWriter); err != nil { + log.Warn(err) + } + + if err = cli.dumpOSInfo(zipWriter); err != nil { + log.Warnf("could not collect OS information: %s", err) + } + + if err = cli.dumpConfigYAML(zipWriter); err != nil { + log.Warnf("could not collect main config file: %s", err) + } + + if err = cli.dumpHubItems(zipWriter, hub); err != nil { + log.Warnf("could not collect hub information: %s", err) + } + + if err = cli.dumpBouncers(ctx, zipWriter, db); err != nil { + log.Warnf("could not collect bouncers information: %s", err) + } + + if err = cli.dumpAgents(ctx, zipWriter, db); err != nil { + log.Warnf("could not collect agents information: %s", err) + } + + if !skipCAPI { + if err = cli.dumpCAPIStatus(ctx, zipWriter, hub); err != nil { + log.Warnf("could not collect CAPI status: %s", err) + } + + if err = cli.dumpPAPIStatus(ctx, zipWriter, db); err != nil { + log.Warnf("could not collect PAPI status: %s", err) + } + } + + if !skipLAPI { + if err = cli.dumpLAPIStatus(ctx, zipWriter, hub); err != nil { + log.Warnf("could not collect LAPI status: %s", err) + } + + // call the pprof endpoints separately, one of them might time out + + if err = cli.dumpPprof(ctx, zipWriter, *cfg.Prometheus, "goroutine"); err != nil { + log.Warnf("could not collect pprof goroutine data: %s", err) + } + + if err = cli.dumpPprof(ctx, zipWriter, *cfg.Prometheus, "heap"); err != nil { + log.Warnf("could not collect pprof heap data: %s", err) + } + + if err = cli.dumpPprof(ctx, zipWriter, *cfg.Prometheus, "profile"); err != nil { + log.Warnf("could not collect pprof cpu data: %s", err) + } + + cli.dumpProfiles(zipWriter) + } + + if !skipAgent { +
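// the acquisition file list is only available if the agent configuration was loaded +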
cli.dumpAcquisitionConfig(zipWriter) + } + + if err = cli.dumpCrash(zipWriter); err != nil { + log.Warnf("could not collect crash dumps: %s", err) + } + + if err = cli.dumpLogs(zipWriter); err != nil { + log.Warnf("could not collect log files: %s", err) + } + + cli.dumpVersion(zipWriter) + cli.dumpFeatures(zipWriter) + + // log of the dump process, without color codes + collectedOutput := stripAnsiString(collector.LogBuilder.String()) + + cli.writeToZip(zipWriter, "dump.log", time.Now(), strings.NewReader(collectedOutput)) + + err = zipWriter.Close() + if err != nil { + return fmt.Errorf("could not finalize zip file: %w", err) + } + + if outFile == "-" { + _, err = os.Stdout.Write(w.Bytes()) + return err + } + + err = os.WriteFile(outFile, w.Bytes(), 0o600) + if err != nil { + return fmt.Errorf("could not write zip file to %s: %w", outFile, err) + } + + log.Infof("Wrote zip file to %s", outFile) + + return nil +} + +func (cli *cliSupport) NewDumpCmd() *cobra.Command { + var outFile string + + cmd := &cobra.Command{ + Use: "dump", + Short: "Dump all your configuration to a zip file for easier support", + Long: `Dump the following information: +- Crowdsec version +- OS version +- Enabled feature flags +- Latest Crowdsec logs (log processor, LAPI, remediation components) +- Installed collections, parsers, scenarios... +- Bouncers and machines list +- CAPI/LAPI status +- Crowdsec config (sensitive information like username and password is redacted) +- Crowdsec metrics +- Stack trace in case of process crash`, + Example: `cscli support dump +cscli support dump -f /tmp/crowdsec-support.zip +`, + Args: cobra.NoArgs, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + output := cli.cfg().Cscli.Output + if output != "human" { + return fmt.Errorf("output format %s not supported for this command", output) + } + return cli.dump(cmd.Context(), outFile) + }, + } + + cmd.Flags().StringVarP(&outFile, "outFile", "f", "", "File to dump the information to") + + return cmd +} diff --git a/cmd/crowdsec-cli/collections.go b/cmd/crowdsec-cli/collections.go deleted file mode 100644 index 3e24a586034..00000000000 --- a/cmd/crowdsec-cli/collections.go +++ /dev/null @@ -1,183 +0,0 @@ -package main - -import ( - "fmt" - - "github.com/fatih/color" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/crowdsec/pkg/cwhub" -) - -func NewCollectionsCmd() *cobra.Command { - var cmdCollections = &cobra.Command{ - Use: "collections [action]", - Short: "Manage collections from hub", - Long: `Install/Remove/Upgrade/Inspect collections from the CrowdSec Hub.`, - /*TBD fix help*/ - Args: cobra.MinimumNArgs(1), - Aliases: []string{"collection"}, - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - if csConfig.Hub == nil { - return fmt.Errorf("you must configure cli before interacting with hub") - } - - if err := cwhub.SetHubBranch(); err != nil { - return fmt.Errorf("error while setting hub branch: %s", err) - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Info("Run 'sudo cscli hub update' to get the hub index") - log.Fatalf("Failed to get Hub index : %v", err) - } - - return nil - }, - PersistentPostRun: func(cmd *cobra.Command, args []string) { - if cmd.Name() == "inspect" || cmd.Name() == "list" { - return - } - log.Infof(ReloadMessage()) - }, - } - - var ignoreError bool - var cmdCollectionsInstall = &cobra.Command{ - Use:
"install collection", - Short: "Install given collection(s)", - Long: `Fetch and install given collection(s) from hub`, - Example: `cscli collections install crowdsec/xxx crowdsec/xyz`, - Args: cobra.MinimumNArgs(1), - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compAllItems(cwhub.COLLECTIONS, args, toComplete) - }, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - for _, name := range args { - t := cwhub.GetItem(cwhub.COLLECTIONS, name) - if t == nil { - nearestItem, score := GetDistance(cwhub.COLLECTIONS, name) - Suggest(cwhub.COLLECTIONS, name, nearestItem.Name, score, ignoreError) - continue - } - if err := cwhub.InstallItem(csConfig, name, cwhub.COLLECTIONS, forceAction, downloadOnly); err != nil { - if !ignoreError { - log.Fatalf("Error while installing '%s': %s", name, err) - } - log.Errorf("Error while installing '%s': %s", name, err) - } - } - }, - } - cmdCollectionsInstall.PersistentFlags().BoolVarP(&downloadOnly, "download-only", "d", false, "Only download packages, don't enable") - cmdCollectionsInstall.PersistentFlags().BoolVar(&forceAction, "force", false, "Force install : Overwrite tainted and outdated files") - cmdCollectionsInstall.PersistentFlags().BoolVar(&ignoreError, "ignore", false, "Ignore errors when installing multiple collections") - cmdCollections.AddCommand(cmdCollectionsInstall) - - var cmdCollectionsRemove = &cobra.Command{ - Use: "remove collection", - Short: "Remove given collection(s)", - Long: `Remove given collection(s) from hub`, - Example: `cscli collections remove crowdsec/xxx crowdsec/xyz`, - Aliases: []string{"delete"}, - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.COLLECTIONS, args, toComplete) - }, - Run: func(cmd *cobra.Command, args []string) { - if all { - cwhub.RemoveMany(csConfig, cwhub.COLLECTIONS, "", all, purge, forceAction) - return - } - - if len(args) == 0 { - log.Fatal("Specify at least one collection to remove or '--all' flag.") - } - - for _, name := range args { - if !forceAction { - item := cwhub.GetItem(cwhub.COLLECTIONS, name) - if item == nil { - log.Fatalf("unable to retrieve: %s\n", name) - } - if len(item.BelongsToCollections) > 0 { - log.Warningf("%s belongs to other collections :\n%s\n", name, item.BelongsToCollections) - log.Printf("Run 'sudo cscli collections remove %s --force' if you want to force remove this sub collection\n", name) - continue - } - } - cwhub.RemoveMany(csConfig, cwhub.COLLECTIONS, name, all, purge, forceAction) - } - }, - } - cmdCollectionsRemove.PersistentFlags().BoolVar(&purge, "purge", false, "Delete source file too") - cmdCollectionsRemove.PersistentFlags().BoolVar(&forceAction, "force", false, "Force remove : Remove tainted and outdated files") - cmdCollectionsRemove.PersistentFlags().BoolVar(&all, "all", false, "Delete all the collections") - cmdCollections.AddCommand(cmdCollectionsRemove) - - var cmdCollectionsUpgrade = &cobra.Command{ - Use: "upgrade collection", - Short: "Upgrade given collection(s)", - Long: `Fetch and upgrade given collection(s) from hub`, - Example: `cscli collections upgrade crowdsec/xxx crowdsec/xyz`, - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.COLLECTIONS, args, toComplete) - }, - Run: func(cmd 
*cobra.Command, args []string) { - if all { - cwhub.UpgradeConfig(csConfig, cwhub.COLLECTIONS, "", forceAction) - } else { - if len(args) == 0 { - log.Fatalf("no target collection to upgrade") - } - for _, name := range args { - cwhub.UpgradeConfig(csConfig, cwhub.COLLECTIONS, name, forceAction) - } - } - }, - } - cmdCollectionsUpgrade.PersistentFlags().BoolVarP(&all, "all", "a", false, "Upgrade all the collections") - cmdCollectionsUpgrade.PersistentFlags().BoolVar(&forceAction, "force", false, "Force upgrade : Overwrite tainted and outdated files") - cmdCollections.AddCommand(cmdCollectionsUpgrade) - - var cmdCollectionsInspect = &cobra.Command{ - Use: "inspect collection", - Short: "Inspect given collection", - Long: `Inspect given collection`, - Example: `cscli collections inspect crowdsec/xxx crowdsec/xyz`, - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.COLLECTIONS, args, toComplete) - }, - Run: func(cmd *cobra.Command, args []string) { - for _, name := range args { - InspectItem(name, cwhub.COLLECTIONS) - } - }, - } - cmdCollectionsInspect.PersistentFlags().StringVarP(&prometheusURL, "url", "u", "", "Prometheus url") - cmdCollections.AddCommand(cmdCollectionsInspect) - - var cmdCollectionsList = &cobra.Command{ - Use: "list collection [-a]", - Short: "List all collections", - Long: `List all collections`, - Example: `cscli collections list`, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - ListItems(color.Output, []string{cwhub.COLLECTIONS}, args, false, true, all) - }, - } - cmdCollectionsList.PersistentFlags().BoolVarP(&all, "all", "a", false, "List disabled items as well") - cmdCollections.AddCommand(cmdCollectionsList) - - return cmdCollections -} diff --git a/cmd/crowdsec-cli/completion.go b/cmd/crowdsec-cli/completion.go index fd76b571d05..7b6531f5516 100644 --- a/cmd/crowdsec-cli/completion.go +++ b/cmd/crowdsec-cli/completion.go @@ -7,8 +7,7 @@ import ( ) func NewCompletionCmd() *cobra.Command { - - var completionCmd = &cobra.Command{ + completionCmd := &cobra.Command{ Use: "completion [bash|zsh|powershell|fish]", Short: "Generate completion script", Long: `To load completions: @@ -82,5 +81,6 @@ func NewCompletionCmd() *cobra.Command { } }, } + return completionCmd } diff --git a/cmd/crowdsec-cli/config.go b/cmd/crowdsec-cli/config.go index e60246db790..e88845798e2 100644 --- a/cmd/crowdsec-cli/config.go +++ b/cmd/crowdsec-cli/config.go @@ -4,19 +4,29 @@ import ( "github.com/spf13/cobra" ) -func NewConfigCmd() *cobra.Command { - cmdConfig := &cobra.Command{ +type cliConfig struct { + cfg configGetter +} + +func NewCLIConfig(cfg configGetter) *cliConfig { + return &cliConfig{ + cfg: cfg, + } +} + +func (cli *cliConfig) NewCommand() *cobra.Command { + cmd := &cobra.Command{ Use: "config [command]", Short: "Allows to view current config", Args: cobra.ExactArgs(0), DisableAutoGenTag: true, } - cmdConfig.AddCommand(NewConfigShowCmd()) - cmdConfig.AddCommand(NewConfigShowYAMLCmd()) - cmdConfig.AddCommand(NewConfigBackupCmd()) - cmdConfig.AddCommand(NewConfigRestoreCmd()) - cmdConfig.AddCommand(NewConfigFeatureFlagsCmd()) + cmd.AddCommand(cli.newShowCmd()) + cmd.AddCommand(cli.newShowYAMLCmd()) + cmd.AddCommand(cli.newBackupCmd()) + cmd.AddCommand(cli.newRestoreCmd()) + cmd.AddCommand(cli.newFeatureFlagsCmd()) - return cmdConfig + return cmd } diff --git 
a/cmd/crowdsec-cli/config_backup.go b/cmd/crowdsec-cli/config_backup.go index 717fc990b9b..d23aff80a78 100644 --- a/cmd/crowdsec-cli/config_backup.go +++ b/cmd/crowdsec-cli/config_backup.go @@ -1,6 +1,8 @@ package main import ( + "encoding/json" + "errors" "fmt" "os" "path/filepath" @@ -8,10 +10,86 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -/* Backup crowdsec configurations to directory : +func (cli *cliConfig) backupHub(dirPath string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) + if err != nil { + return err + } + + for _, itemType := range cwhub.ItemTypes { + clog := log.WithField("type", itemType) + + itemMap := hub.GetItemMap(itemType) + if itemMap == nil { + clog.Infof("No %s to backup.", itemType) + continue + } + + itemDirectory := fmt.Sprintf("%s/%s/", dirPath, itemType) + if err = os.MkdirAll(itemDirectory, os.ModePerm); err != nil { + return fmt.Errorf("error while creating %s: %w", itemDirectory, err) + } + + upstreamParsers := []string{} + + for k, v := range itemMap { + clog = clog.WithField("file", v.Name) + if !v.State.Installed { // only backup installed ones + clog.Debugf("[%s]: not installed", k) + continue + } + + // for the local/tainted ones, we back up the full file + if v.State.Tainted || v.State.IsLocal() || !v.State.UpToDate { + // we need to backup stages for parsers + if itemType == cwhub.PARSERS || itemType == cwhub.POSTOVERFLOWS { + fstagedir := fmt.Sprintf("%s%s", itemDirectory, v.Stage) + if err = os.MkdirAll(fstagedir, os.ModePerm); err != nil { + return fmt.Errorf("error while creating stage dir %s: %w", fstagedir, err) + } + } + + clog.Debugf("[%s]: backing up file (tainted:%t local:%t up-to-date:%t)", k, v.State.Tainted, v.State.IsLocal(), v.State.UpToDate) + + tfile := fmt.Sprintf("%s%s/%s", itemDirectory, v.Stage, v.FileName) + if err = CopyFile(v.State.LocalPath, tfile); err != nil { + return fmt.Errorf("failed copy %s %s to %s: %w", itemType, v.State.LocalPath, tfile, err) + } + + clog.Infof("saved local/tainted %s to %s", v.State.LocalPath, tfile) + + continue + } + + clog.Debugf("[%s]: from hub, just backup name (up-to-date:%t)", k, v.State.UpToDate) + clog.Infof("saving, version:%s, up-to-date:%t", v.Version, v.State.UpToDate) + upstreamParsers = append(upstreamParsers, v.Name) + } + // write the upstream items + upstreamParsersFname := fmt.Sprintf("%s/upstream-%s.json", itemDirectory, itemType) + + upstreamParsersContent, err := json.MarshalIndent(upstreamParsers, "", " ") + if err != nil { + return fmt.Errorf("failed to serialize upstream items: %w", err) + } + + err = os.WriteFile(upstreamParsersFname, upstreamParsersContent, 0o644) + if err != nil { + return fmt.Errorf("unable to write to %s %s: %w", itemType, upstreamParsersFname, err) + } + + clog.Infof("Wrote %d entries for %s to %s", len(upstreamParsers), itemType, upstreamParsersFname) + } + + return nil +} + +/* + Backup crowdsec configurations to directory: - Main config (config.yaml) - Profiles config (profiles.yaml) @@ -19,19 +97,22 @@ import ( - Backup of API credentials (local API and online API) - List of scenarios, parsers, postoverflows and collections that are up-to-date - Tainted/local/out-of-date scenarios, parsers, postoverflows and collections +- Acquisition files (acquis.yaml, acquis.d/*.yaml) */ -func backupConfigToDirectory(dirPath string) error { +func (cli *cliConfig) backup(dirPath string) error { var err error + cfg
:= cli.cfg() + if dirPath == "" { - return fmt.Errorf("directory path can't be empty") + return errors.New("directory path can't be empty") } log.Infof("Starting configuration backup") /*if parent directory doesn't exist, bail out. create final dir with Mkdir*/ parentDir := filepath.Dir(dirPath) - if _, err := os.Stat(parentDir); err != nil { + if _, err = os.Stat(parentDir); err != nil { return fmt.Errorf("while checking parent directory %s existence: %w", parentDir, err) } @@ -39,10 +120,10 @@ func backupConfigToDirectory(dirPath string) error { return fmt.Errorf("while creating %s: %w", dirPath, err) } - if csConfig.ConfigPaths.SimulationFilePath != "" { + if cfg.ConfigPaths.SimulationFilePath != "" { backupSimulation := filepath.Join(dirPath, "simulation.yaml") - if err = CopyFile(csConfig.ConfigPaths.SimulationFilePath, backupSimulation); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", csConfig.ConfigPaths.SimulationFilePath, backupSimulation, err) + if err = CopyFile(cfg.ConfigPaths.SimulationFilePath, backupSimulation); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.ConfigPaths.SimulationFilePath, backupSimulation, err) } log.Infof("Saved simulation to %s", backupSimulation) @@ -52,22 +133,22 @@ func backupConfigToDirectory(dirPath string) error { - backup AcquisitionFilePath - backup the other files of acquisition directory */ - if csConfig.Crowdsec != nil && csConfig.Crowdsec.AcquisitionFilePath != "" { + if cfg.Crowdsec != nil && cfg.Crowdsec.AcquisitionFilePath != "" { backupAcquisition := filepath.Join(dirPath, "acquis.yaml") - if err = CopyFile(csConfig.Crowdsec.AcquisitionFilePath, backupAcquisition); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.Crowdsec.AcquisitionFilePath, backupAcquisition, err) + if err = CopyFile(cfg.Crowdsec.AcquisitionFilePath, backupAcquisition); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.Crowdsec.AcquisitionFilePath, backupAcquisition, err) } } acquisBackupDir := filepath.Join(dirPath, "acquis") if err = os.Mkdir(acquisBackupDir, 0o700); err != nil { - return fmt.Errorf("error while creating %s: %s", acquisBackupDir, err) + return fmt.Errorf("error while creating %s: %w", acquisBackupDir, err) } - if csConfig.Crowdsec != nil && len(csConfig.Crowdsec.AcquisitionFiles) > 0 { - for _, acquisFile := range csConfig.Crowdsec.AcquisitionFiles { + if cfg.Crowdsec != nil && len(cfg.Crowdsec.AcquisitionFiles) > 0 { + for _, acquisFile := range cfg.Crowdsec.AcquisitionFiles { /*if it was the default one, it was already backup'ed*/ - if csConfig.Crowdsec.AcquisitionFilePath == acquisFile { + if cfg.Crowdsec.AcquisitionFilePath == acquisFile { continue } @@ -87,65 +168,48 @@ func backupConfigToDirectory(dirPath string) error { if ConfigFilePath != "" { backupMain := fmt.Sprintf("%s/config.yaml", dirPath) if err = CopyFile(ConfigFilePath, backupMain); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", ConfigFilePath, backupMain, err) + return fmt.Errorf("failed copy %s to %s: %w", ConfigFilePath, backupMain, err) } log.Infof("Saved default yaml to %s", backupMain) } - if csConfig.API != nil && csConfig.API.Server != nil && csConfig.API.Server.OnlineClient != nil && csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { + if cfg.API != nil && cfg.API.Server != nil && cfg.API.Server.OnlineClient != nil && cfg.API.Server.OnlineClient.CredentialsFilePath != "" { backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) - if err = 
CopyFile(csConfig.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds, err) + if err = CopyFile(cfg.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds, err) } log.Infof("Saved online API credentials to %s", backupCAPICreds) } - if csConfig.API != nil && csConfig.API.Client != nil && csConfig.API.Client.CredentialsFilePath != "" { + if cfg.API != nil && cfg.API.Client != nil && cfg.API.Client.CredentialsFilePath != "" { backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) - if err = CopyFile(csConfig.API.Client.CredentialsFilePath, backupLAPICreds); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Client.CredentialsFilePath, backupLAPICreds, err) + if err = CopyFile(cfg.API.Client.CredentialsFilePath, backupLAPICreds); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Client.CredentialsFilePath, backupLAPICreds, err) } log.Infof("Saved local API credentials to %s", backupLAPICreds) } - if csConfig.API != nil && csConfig.API.Server != nil && csConfig.API.Server.ProfilesPath != "" { + if cfg.API != nil && cfg.API.Server != nil && cfg.API.Server.ProfilesPath != "" { backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) - if err = CopyFile(csConfig.API.Server.ProfilesPath, backupProfiles); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Server.ProfilesPath, backupProfiles, err) + if err = CopyFile(cfg.API.Server.ProfilesPath, backupProfiles); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Server.ProfilesPath, backupProfiles, err) } log.Infof("Saved profiles to %s", backupProfiles) } - if err = BackupHub(dirPath); err != nil { - return fmt.Errorf("failed to backup hub config: %s", err) + if err = cli.backupHub(dirPath); err != nil { + return fmt.Errorf("failed to backup hub config: %w", err) } return nil } -func runConfigBackup(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadHub(); err != nil { - return err - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Info("Run 'sudo cscli hub update' to get the hub index") - return fmt.Errorf("failed to get Hub index: %w", err) - } - - if err := backupConfigToDirectory(args[0]); err != nil { - return fmt.Errorf("failed to backup config: %w", err) - } - - return nil -} - -func NewConfigBackupCmd() *cobra.Command { - cmdConfigBackup := &cobra.Command{ +func (cli *cliConfig) newBackupCmd() *cobra.Command { + cmd := &cobra.Command{ Use: `backup "directory"`, Short: "Backup current config", Long: `Backup the current crowdsec configuration including : @@ -159,8 +223,14 @@ func NewConfigBackupCmd() *cobra.Command { Example: `cscli config backup ./my-backup`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: runConfigBackup, + RunE: func(_ *cobra.Command, args []string) error { + if err := cli.backup(args[0]); err != nil { + return fmt.Errorf("failed to backup config: %w", err) + } + + return nil + }, } - return cmdConfigBackup + return cmd } diff --git a/cmd/crowdsec-cli/config_feature_flags.go b/cmd/crowdsec-cli/config_feature_flags.go index ed672711fe8..d1dbe2b93b7 100644 --- a/cmd/crowdsec-cli/config_feature_flags.go +++ b/cmd/crowdsec-cli/config_feature_flags.go @@ -2,21 +2,16 @@ package main import ( "fmt" + 
"path/filepath" "github.com/fatih/color" "github.com/spf13/cobra" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/fflag" ) -func runConfigFeatureFlags(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - showRetired, err := flags.GetBool("retired") - if err != nil { - return err - } - +func (cli *cliConfig) featureFlags(showRetired bool) error { green := color.New(color.FgGreen).SprintFunc() red := color.New(color.FgRed).SprintFunc() yellow := color.New(color.FgYellow).SprintFunc() @@ -42,6 +37,7 @@ func runConfigFeatureFlags(cmd *cobra.Command, args []string) error { if feat.State == fflag.RetiredState { fmt.Printf("\n %s %s", magenta("RETIRED"), feat.DeprecationMsg) } + fmt.Println() } @@ -56,10 +52,12 @@ func runConfigFeatureFlags(cmd *cobra.Command, args []string) error { retired = append(retired, feat) continue } + if feat.IsEnabled() { enabled = append(enabled, feat) continue } + disabled = append(disabled, feat) } @@ -87,7 +85,14 @@ func runConfigFeatureFlags(cmd *cobra.Command, args []string) error { fmt.Println("To enable a feature you can: ") fmt.Println(" - set the environment variable CROWDSEC_FEATURE_ to true") - fmt.Printf(" - add the line '- ' to the file %s/feature.yaml\n", csConfig.ConfigPaths.ConfigDir) + + featurePath, err := filepath.Abs(csconfig.GetFeatureFilePath(ConfigFilePath)) + if err != nil { + // we already read the file, shouldn't happen + return err + } + + fmt.Printf(" - add the line '- ' to the file %s\n", featurePath) fmt.Println() if len(enabled) == 0 && len(disabled) == 0 { @@ -109,18 +114,22 @@ func runConfigFeatureFlags(cmd *cobra.Command, args []string) error { return nil } -func NewConfigFeatureFlagsCmd() *cobra.Command { - cmdConfigFeatureFlags := &cobra.Command{ +func (cli *cliConfig) newFeatureFlagsCmd() *cobra.Command { + var showRetired bool + + cmd := &cobra.Command{ Use: "feature-flags", Short: "Displays feature flag status", Long: `Displays the supported feature flags and their current status.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigFeatureFlags, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.featureFlags(showRetired) + }, } - flags := cmdConfigFeatureFlags.Flags() - flags.Bool("retired", false, "Show retired features") + flags := cmd.Flags() + flags.BoolVar(&showRetired, "retired", false, "Show retired features") - return cmdConfigFeatureFlags + return cmd } diff --git a/cmd/crowdsec-cli/config_restore.go b/cmd/crowdsec-cli/config_restore.go index 55ab7aa9bad..c32328485ec 100644 --- a/cmd/crowdsec-cli/config_restore.go +++ b/cmd/crowdsec-cli/config_restore.go @@ -1,26 +1,123 @@ package main import ( + "context" "encoding/json" "fmt" - "io" "os" "path/filepath" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/yaml.v2" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -type OldAPICfg struct { - MachineID string `json:"machine_id"` - Password string `json:"password"` +func (cli *cliConfig) restoreHub(ctx context.Context, dirPath string) error { + cfg := cli.cfg() + + hub, err := require.Hub(cfg, require.RemoteHub(ctx, cfg), nil) + if err != nil { + return err + } + + for _, itype := range cwhub.ItemTypes { + itemDirectory := fmt.Sprintf("%s/%s/", dirPath, itype) + if _, err = os.Stat(itemDirectory); err != nil { + log.Infof("no %s in backup", itype) + continue + } + /*restore the upstream items*/ + 
upstreamListFN := fmt.Sprintf("%s/upstream-%s.json", itemDirectory, itype) + + file, err := os.ReadFile(upstreamListFN) + if err != nil { + return fmt.Errorf("error while opening %s: %w", upstreamListFN, err) + } + + var upstreamList []string + + err = json.Unmarshal(file, &upstreamList) + if err != nil { + return fmt.Errorf("error parsing %s: %w", upstreamListFN, err) + } + + for _, toinstall := range upstreamList { + item := hub.GetItem(itype, toinstall) + if item == nil { + log.Errorf("Item %s/%s not found in hub", itype, toinstall) + continue + } + + if err = item.Install(ctx, false, false); err != nil { + log.Errorf("Error while installing %s: %s", toinstall, err) + } + } + + /*restore the local and tainted items*/ + files, err := os.ReadDir(itemDirectory) + if err != nil { + return fmt.Errorf("failed enumerating files of %s: %w", itemDirectory, err) + } + + for _, file := range files { + // this was the upstream data + if file.Name() == fmt.Sprintf("upstream-%s.json", itype) { + continue + } + + if itype == cwhub.PARSERS || itype == cwhub.POSTOVERFLOWS { + // we expect a stage here + if !file.IsDir() { + continue + } + + stage := file.Name() + stagedir := fmt.Sprintf("%s/%s/%s/", cfg.ConfigPaths.ConfigDir, itype, stage) + log.Debugf("Found stage %s in %s, target directory: %s", stage, itype, stagedir) + + if err = os.MkdirAll(stagedir, os.ModePerm); err != nil { + return fmt.Errorf("error while creating stage directory %s: %w", stagedir, err) + } + + // find items + ifiles, err := os.ReadDir(itemDirectory + "/" + stage + "/") + if err != nil { + return fmt.Errorf("failed enumerating files of %s: %w", itemDirectory+"/"+stage, err) + } + + // finally copy item + for _, tfile := range ifiles { + log.Infof("Going to restore local/tainted [%s]", tfile.Name()) + sourceFile := fmt.Sprintf("%s/%s/%s", itemDirectory, stage, tfile.Name()) + + destinationFile := fmt.Sprintf("%s%s", stagedir, tfile.Name()) + if err = CopyFile(sourceFile, destinationFile); err != nil { + return fmt.Errorf("failed copy %s %s to %s: %w", itype, sourceFile, destinationFile, err) + } + + log.Infof("restored %s to %s", sourceFile, destinationFile) + } + } else { + log.Infof("Going to restore local/tainted [%s]", file.Name()) + sourceFile := fmt.Sprintf("%s/%s", itemDirectory, file.Name()) + destinationFile := fmt.Sprintf("%s/%s/%s", cfg.ConfigPaths.ConfigDir, itype, file.Name()) + + if err = CopyFile(sourceFile, destinationFile); err != nil { + return fmt.Errorf("failed copy %s %s to %s: %w", itype, sourceFile, destinationFile, err) + } + + log.Infof("restored %s to %s", sourceFile, destinationFile) + } + } + } + + return nil } -/* Restore crowdsec configurations to directory : +/* + Restore crowdsec configurations to directory: - Main config (config.yaml) - Profiles config (profiles.yaml) @@ -28,91 +125,66 @@ type OldAPICfg struct { - Backup of API credentials (local API and online API) - List of scenarios, parsers, postoverflows and collections that are up-to-date - Tainted/local/out-of-date scenarios, parsers, postoverflows and collections +- Acquisition files (acquis.yaml, acquis.d/*.yaml) */ -func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { +func (cli *cliConfig) restore(ctx context.Context, dirPath string) error { var err error - if !oldBackup { - backupMain := fmt.Sprintf("%s/config.yaml", dirPath) - if _, err = os.Stat(backupMain); err == nil { - if csConfig.ConfigPaths != nil && csConfig.ConfigPaths.ConfigDir != "" { - if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml",
csConfig.ConfigPaths.ConfigDir)); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupMain, csConfig.ConfigPaths.ConfigDir, err) - } + cfg := cli.cfg() + + backupMain := fmt.Sprintf("%s/config.yaml", dirPath) + if _, err = os.Stat(backupMain); err == nil { + if cfg.ConfigPaths != nil && cfg.ConfigPaths.ConfigDir != "" { + if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml", cfg.ConfigPaths.ConfigDir)); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupMain, cfg.ConfigPaths.ConfigDir, err) } } + } - // Now we have config.yaml, we should regenerate config struct to have rights paths etc - ConfigFilePath = fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir) + // Now we have config.yaml, we should regenerate config struct to have rights paths etc + ConfigFilePath = fmt.Sprintf("%s/config.yaml", cfg.ConfigPaths.ConfigDir) - initConfig() + log.Debug("Reloading configuration") - backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) - if _, err = os.Stat(backupCAPICreds); err == nil { - if err = CopyFile(backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath, err) - } - } + csConfig, _, err = loadConfigFor("config") + if err != nil { + return fmt.Errorf("failed to reload configuration: %w", err) + } - backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) - if _, err = os.Stat(backupLAPICreds); err == nil { - if err = CopyFile(backupLAPICreds, csConfig.API.Client.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupLAPICreds, csConfig.API.Client.CredentialsFilePath, err) - } - } + cfg = cli.cfg() - backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) - if _, err = os.Stat(backupProfiles); err == nil { - if err = CopyFile(backupProfiles, csConfig.API.Server.ProfilesPath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupProfiles, csConfig.API.Server.ProfilesPath, err) - } + backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) + if _, err = os.Stat(backupCAPICreds); err == nil { + if err = CopyFile(backupCAPICreds, cfg.API.Server.OnlineClient.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupCAPICreds, cfg.API.Server.OnlineClient.CredentialsFilePath, err) } - } else { - var oldAPICfg OldAPICfg - backupOldAPICfg := fmt.Sprintf("%s/api_creds.json", dirPath) + } - jsonFile, err := os.Open(backupOldAPICfg) - if err != nil { - log.Warningf("failed to open %s : %s", backupOldAPICfg, err) - } else { - byteValue, _ := io.ReadAll(jsonFile) - err = json.Unmarshal(byteValue, &oldAPICfg) - if err != nil { - return fmt.Errorf("failed to load json file %s : %s", backupOldAPICfg, err) - } + backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) + if _, err = os.Stat(backupLAPICreds); err == nil { + if err = CopyFile(backupLAPICreds, cfg.API.Client.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupLAPICreds, cfg.API.Client.CredentialsFilePath, err) + } + } - apiCfg := csconfig.ApiCredentialsCfg{ - Login: oldAPICfg.MachineID, - Password: oldAPICfg.Password, - URL: CAPIBaseURL, - } - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - return fmt.Errorf("unable to dump api credentials: %s", err) - } - apiConfigDumpFile := fmt.Sprintf("%s/online_api_credentials.yaml", csConfig.ConfigPaths.ConfigDir) 
- if csConfig.API.Server.OnlineClient != nil && csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { - apiConfigDumpFile = csConfig.API.Server.OnlineClient.CredentialsFilePath - } - err = os.WriteFile(apiConfigDumpFile, apiConfigDump, 0o644) - if err != nil { - return fmt.Errorf("write api credentials in '%s' failed: %s", apiConfigDumpFile, err) - } - log.Infof("Saved API credentials to %s", apiConfigDumpFile) + backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) + if _, err = os.Stat(backupProfiles); err == nil { + if err = CopyFile(backupProfiles, cfg.API.Server.ProfilesPath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupProfiles, cfg.API.Server.ProfilesPath, err) } } backupSimulation := fmt.Sprintf("%s/simulation.yaml", dirPath) if _, err = os.Stat(backupSimulation); err == nil { - if err = CopyFile(backupSimulation, csConfig.ConfigPaths.SimulationFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupSimulation, csConfig.ConfigPaths.SimulationFilePath, err) + if err = CopyFile(backupSimulation, cfg.ConfigPaths.SimulationFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupSimulation, cfg.ConfigPaths.SimulationFilePath, err) } } /*if there is a acquisition dir, restore its content*/ - if csConfig.Crowdsec.AcquisitionDirPath != "" { - if err = os.Mkdir(csConfig.Crowdsec.AcquisitionDirPath, 0o700); err != nil { - return fmt.Errorf("error while creating %s : %s", csConfig.Crowdsec.AcquisitionDirPath, err) + if cfg.Crowdsec.AcquisitionDirPath != "" { + if err = os.MkdirAll(cfg.Crowdsec.AcquisitionDirPath, 0o700); err != nil { + return fmt.Errorf("error while creating %s: %w", cfg.Crowdsec.AcquisitionDirPath, err) } } @@ -121,16 +193,16 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { if _, err = os.Stat(backupAcquisition); err == nil { log.Debugf("restoring backup'ed %s", backupAcquisition) - if err = CopyFile(backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath, err) + if err = CopyFile(backupAcquisition, cfg.Crowdsec.AcquisitionFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupAcquisition, cfg.Crowdsec.AcquisitionFilePath, err) } } - // if there is files in the acquis backup dir, restore them + // if there are files in the acquis backup dir, restore them acquisBackupDir := filepath.Join(dirPath, "acquis", "*.yaml") if acquisFiles, err := filepath.Glob(acquisBackupDir); err == nil { for _, acquisFile := range acquisFiles { - targetFname, err := filepath.Abs(csConfig.Crowdsec.AcquisitionDirPath + "/" + filepath.Base(acquisFile)) + targetFname, err := filepath.Abs(cfg.Crowdsec.AcquisitionDirPath + "/" + filepath.Base(acquisFile)) if err != nil { return fmt.Errorf("while saving %s to %s: %w", acquisFile, targetFname, err) } @@ -138,17 +210,17 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { log.Debugf("restoring %s to %s", acquisFile, targetFname) if err = CopyFile(acquisFile, targetFname); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", acquisFile, targetFname, err) + return fmt.Errorf("failed copy %s to %s: %w", acquisFile, targetFname, err) } } } - if csConfig.Crowdsec != nil && len(csConfig.Crowdsec.AcquisitionFiles) > 0 { - for _, acquisFile := range csConfig.Crowdsec.AcquisitionFiles { + if cfg.Crowdsec != nil && len(cfg.Crowdsec.AcquisitionFiles) > 0 { + for _, acquisFile := range 
cfg.Crowdsec.AcquisitionFiles { log.Infof("backup filepath from dir -> %s", acquisFile) // if it was the default one, it has already been backed up - if csConfig.Crowdsec.AcquisitionFilePath == acquisFile { + if cfg.Crowdsec.AcquisitionFilePath == acquisFile { log.Infof("skip this one") continue } @@ -159,46 +231,22 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { } if err = CopyFile(acquisFile, targetFname); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", acquisFile, targetFname, err) + return fmt.Errorf("failed copy %s to %s: %w", acquisFile, targetFname, err) } log.Infof("Saved acquis %s to %s", acquisFile, targetFname) } } - if err = RestoreHub(dirPath); err != nil { - return fmt.Errorf("failed to restore hub config : %s", err) + if err = cli.restoreHub(ctx, dirPath); err != nil { + return fmt.Errorf("failed to restore hub config: %w", err) } return nil } -func runConfigRestore(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - oldBackup, err := flags.GetBool("old-backup") - if err != nil { - return err - } - - if err := csConfig.LoadHub(); err != nil { - return err - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Info("Run 'sudo cscli hub update' to get the hub index") - return fmt.Errorf("failed to get Hub index: %w", err) - } - - if err := restoreConfigFromDirectory(args[0], oldBackup); err != nil { - return fmt.Errorf("failed to restore config from %s: %w", args[0], err) - } - - return nil -} - -func NewConfigRestoreCmd() *cobra.Command { - cmdConfigRestore := &cobra.Command{ +func (cli *cliConfig) newRestoreCmd() *cobra.Command { + cmd := &cobra.Command{ Use: `restore "directory"`, Short: `Restore config in backup "directory"`, Long: `Restore the crowdsec configuration from specified backup "directory" including: @@ -211,11 +259,16 @@ func NewConfigRestoreCmd() *cobra.Command { - Backup of API credentials (local API and online API)`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: runConfigRestore, - } + RunE: func(cmd *cobra.Command, args []string) error { + dirPath := args[0] - flags := cmdConfigRestore.Flags() - flags.BoolP("old-backup", "", false, "To use when you are upgrading crowdsec v0.X to v1.X and you need to restore backup from v0.X") + if err := cli.restore(cmd.Context(), dirPath); err != nil { + return fmt.Errorf("failed to restore config from %s: %w", dirPath, err) + } + + return nil + }, + } - return cmdConfigRestore + return cmd } diff --git a/cmd/crowdsec-cli/config_show.go b/cmd/crowdsec-cli/config_show.go index 2e1fc7092d7..2d3ac488ba2 100644 --- a/cmd/crowdsec-cli/config_show.go +++ b/cmd/crowdsec-cli/config_show.go @@ -6,16 +6,19 @@ import ( "os" "text/template" - "github.com/antonmedv/expr" + "github.com/expr-lang/expr" + "github.com/sanity-io/litter" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" ) -func showConfigKey(key string) error { +func (cli *cliConfig) showKey(key string) error { + cfg := cli.cfg() + type Env struct { Config *csconfig.Config } @@ -23,41 +26,43 @@ func showConfigKey(key string) error { opts := []expr.Option{} opts = append(opts, exprhelpers.GetExprOptions(map[string]interface{}{})...) opts = append(opts, expr.Env(Env{})) + program, err := expr.Compile(key, opts...) 
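+ // the key is compiled as an expr expression and evaluated against the loaded configuration, e.g. "Config.API.Server.ListenURI"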
if err != nil { return err } - output, err := expr.Run(program, Env{Config: csConfig}) + output, err := expr.Run(program, Env{Config: cfg}) if err != nil { return err } - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human", "raw": + // Don't use litter for strings, it adds quotes + // that would break compatibility with previous versions switch output.(type) { case string: - fmt.Printf("%s\n", output) - case int: - fmt.Printf("%d\n", output) + fmt.Println(output) default: - fmt.Printf("%v\n", output) + litter.Dump(output) } case "json": data, err := json.MarshalIndent(output, "", " ") if err != nil { - return fmt.Errorf("failed to marshal configuration: %w", err) + return fmt.Errorf("failed to serialize configuration: %w", err) } - fmt.Printf("%s\n", string(data)) + fmt.Println(string(data)) } + return nil } -var configShowTemplate = `Global: +func (cli *cliConfig) template() string { + return `Global: {{- if .ConfigPaths }} - - Configuration Folder : {{.ConfigPaths.ConfigDir}} - Configuration Folder : {{.ConfigPaths.ConfigDir}} - Data Folder : {{.ConfigPaths.DataDir}} - Hub Folder : {{.ConfigPaths.HubDir}} @@ -71,7 +76,7 @@ var configShowTemplate = `Global: {{- end }} {{- if .Crowdsec }} -Crowdsec: +Crowdsec{{if and .Crowdsec.Enable (not (ValueBool .Crowdsec.Enable))}} (disabled){{end}}: - Acquisition File : {{.Crowdsec.AcquisitionFilePath}} - Parsers routines : {{.Crowdsec.ParserRoutinesCount}} {{- if .Crowdsec.AcquisitionDirPath }} @@ -83,7 +88,6 @@ Crowdsec: cscli: - Output : {{.Cscli.Output}} - Hub Branch : {{.Cscli.HubBranch}} - - Hub Folder : {{.Cscli.HubDir}} {{- end }} {{- if .API }} @@ -97,8 +101,9 @@ API Client: {{- end }} {{- if .API.Server }} -Local API Server: +Local API Server{{if and .API.Server.Enable (not (ValueBool .API.Server.Enable))}} (disabled){{end}}: - Listen URL : {{.API.Server.ListenURI}} + - Listen Socket : {{.API.Server.ListenSocket}} - Profile File : {{.API.Server.ProfilesPath}} {{- if .API.Server.TLS }} @@ -164,6 +169,12 @@ Central API: - User : {{.DbConfig.User}} - DB Name : {{.DbConfig.DbName}} {{- end }} +{{- if .DbConfig.MaxOpenConns }} + - Max Open Conns : {{.DbConfig.MaxOpenConns}} +{{- end }} +{{- if ne .DbConfig.DecisionBulkSize 0 }} + - Decision Bulk Size : {{.DbConfig.DecisionBulkSize}} +{{- end }} {{- if .DbConfig.Flush }} {{- if .DbConfig.Flush.MaxAge }} - Flush age : {{.DbConfig.Flush.MaxAge}} @@ -174,64 +185,74 @@ Central API: {{- end }} {{- end }} ` +} -func runConfigShow(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - if err := csConfig.LoadAPIClient(); err != nil { - log.Errorf("failed to load API client configuration: %s", err) - // don't return, we can still show the configuration - } - - key, err := flags.GetString("key") - if err != nil { - return err - } - - if key != "" { - return showConfigKey(key) - } +func (cli *cliConfig) show() error { + cfg := cli.cfg() - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human": - tmp, err := template.New("config").Parse(configShowTemplate) + // The tests on .Enable look funny because the option has a true default which has + // not been set yet (we don't really load the LAPI) and go templates don't dereference + // pointers in boolean tests. Prefix notation is the cherry on top. 
+ funcs := template.FuncMap{ + // can't use generics here + "ValueBool": func(b *bool) bool { return b != nil && *b }, + } + + tmp, err := template.New("config").Funcs(funcs).Parse(cli.template()) if err != nil { return err } - err = tmp.Execute(os.Stdout, csConfig) + + err = tmp.Execute(os.Stdout, cfg) if err != nil { return err } case "json": - data, err := json.MarshalIndent(csConfig, "", " ") + data, err := json.MarshalIndent(cfg, "", " ") if err != nil { - return fmt.Errorf("failed to marshal configuration: %w", err) + return fmt.Errorf("failed to serialize configuration: %w", err) } - fmt.Printf("%s\n", string(data)) + fmt.Println(string(data)) case "raw": - data, err := yaml.Marshal(csConfig) + data, err := yaml.Marshal(cfg) if err != nil { - return fmt.Errorf("failed to marshal configuration: %w", err) + return fmt.Errorf("failed to serialize configuration: %w", err) } - fmt.Printf("%s\n", string(data)) + fmt.Println(string(data)) } + return nil } -func NewConfigShowCmd() *cobra.Command { - cmdConfigShow := &cobra.Command{ +func (cli *cliConfig) newShowCmd() *cobra.Command { + var key string + + cmd := &cobra.Command{ Use: "show", Short: "Displays current config", Long: `Displays the current cli configuration.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigShow, + RunE: func(_ *cobra.Command, _ []string) error { + if err := cli.cfg().LoadAPIClient(); err != nil { + log.Errorf("failed to load API client configuration: %s", err) + // don't return, we can still show the configuration + } + + if key != "" { + return cli.showKey(key) + } + + return cli.show() + }, } - flags := cmdConfigShow.Flags() - flags.StringP("key", "", "", "Display only this value (Config.API.Server.ListenURI)") + flags := cmd.Flags() + flags.StringVarP(&key, "key", "", "", "Display only this value (Config.API.Server.ListenURI)") - return cmdConfigShow + return cmd } diff --git a/cmd/crowdsec-cli/config_showyaml.go b/cmd/crowdsec-cli/config_showyaml.go index 82bc67ffcb8..52daee6a65e 100644 --- a/cmd/crowdsec-cli/config_showyaml.go +++ b/cmd/crowdsec-cli/config_showyaml.go @@ -6,19 +6,21 @@ import ( "github.com/spf13/cobra" ) -func runConfigShowYAML(cmd *cobra.Command, args []string) error { +func (cli *cliConfig) showYAML() error { fmt.Println(mergedConfig) return nil } -func NewConfigShowYAMLCmd() *cobra.Command { - cmdConfigShow := &cobra.Command{ +func (cli *cliConfig) newShowYAMLCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "show-yaml", Short: "Displays merged config.yaml + config.yaml.local", Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigShowYAML, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.showYAML() + }, } - return cmdConfigShow + return cmd } diff --git a/cmd/crowdsec-cli/console.go b/cmd/crowdsec-cli/console.go deleted file mode 100644 index 83886267da2..00000000000 --- a/cmd/crowdsec-cli/console.go +++ /dev/null @@ -1,334 +0,0 @@ -package main - -import ( - "context" - "encoding/csv" - "encoding/json" - "errors" - "fmt" - "io/fs" - "net/url" - "os" - - "github.com/fatih/color" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v3" - - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" - "github.com/crowdsecurity/go-cs-lib/pkg/version" - - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/fflag" - "github.com/crowdsecurity/crowdsec/pkg/types" 
-) - -func NewConsoleCmd() *cobra.Command { - var cmdConsole = &cobra.Command{ - Use: "console [action]", - Short: "Manage interaction with Crowdsec console (https://app.crowdsec.net)", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadAPIServer(); err != nil || csConfig.DisableAPI { - var fdErr *fs.PathError - if errors.As(err, &fdErr) { - log.Fatalf("Unable to load Local API : %s", fdErr) - } - if err != nil { - log.Fatalf("Unable to load required Local API Configuration : %s", err) - } - log.Fatal("Local API is disabled, please run this command on the local API machine") - } - if csConfig.DisableAPI { - log.Fatal("Local API is disabled, please run this command on the local API machine") - } - if csConfig.API.Server.OnlineClient == nil { - log.Fatalf("No configuration for Central API (CAPI) in '%s'", *csConfig.FilePath) - } - if csConfig.API.Server.OnlineClient.Credentials == nil { - log.Fatal("You must configure Central API (CAPI) with `cscli capi register` before accessing console features.") - } - return nil - }, - } - - name := "" - overwrite := false - tags := []string{} - - cmdEnroll := &cobra.Command{ - Use: "enroll [enroll-key]", - Short: "Enroll this instance to https://app.crowdsec.net [requires local API]", - Long: ` -Enroll this instance to https://app.crowdsec.net - -You can get your enrollment key by creating an account on https://app.crowdsec.net. -After running this command your will need to validate the enrollment in the webapp.`, - Example: `cscli console enroll YOUR-ENROLL-KEY - cscli console enroll --name [instance_name] YOUR-ENROLL-KEY - cscli console enroll --name [instance_name] --tags [tag_1] --tags [tag_2] YOUR-ENROLL-KEY -`, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - password := strfmt.Password(csConfig.API.Server.OnlineClient.Credentials.Password) - apiURL, err := url.Parse(csConfig.API.Server.OnlineClient.Credentials.URL) - if err != nil { - log.Fatalf("Could not parse CAPI URL : %s", err) - } - - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Fatalf("Failed to load hub index : %s", err) - log.Info("Run 'sudo cscli hub update' to get the hub index") - } - - scenarios, err := cwhub.GetInstalledScenariosAsString() - if err != nil { - log.Fatalf("failed to get scenarios : %s", err) - } - - if len(scenarios) == 0 { - scenarios = make([]string, 0) - } - - c, _ := apiclient.NewClient(&apiclient.Config{ - MachineID: csConfig.API.Server.OnlineClient.Credentials.Login, - Password: password, - Scenarios: scenarios, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiURL, - VersionPrefix: "v3", - }) - resp, err := c.Auth.EnrollWatcher(context.Background(), args[0], name, tags, overwrite) - if err != nil { - log.Fatalf("Could not enroll instance: %s", err) - } - if resp.Response.StatusCode == 200 && !overwrite { - log.Warning("Instance already enrolled. You can use '--overwrite' to force enroll") - return - } - - SetConsoleOpts(csconfig.CONSOLE_CONFIGS, true) - if err := csConfig.API.Server.DumpConsoleConfig(); err != nil { - log.Fatalf("failed writing console config : %s", err) - } - log.Infof("Enabled tainted&manual alerts sharing, see 'cscli console status'.") - log.Infof("Watcher successfully enrolled. 
Visit https://app.crowdsec.net to accept it.") - log.Infof("Please restart crowdsec after accepting the enrollment.") - }, - } - cmdEnroll.Flags().StringVarP(&name, "name", "n", "", "Name to display in the console") - cmdEnroll.Flags().BoolVarP(&overwrite, "overwrite", "", false, "Force enroll the instance") - cmdEnroll.Flags().StringSliceVarP(&tags, "tags", "t", tags, "Tags to display in the console") - cmdConsole.AddCommand(cmdEnroll) - - var enableAll, disableAll bool - - cmdEnable := &cobra.Command{ - Use: "enable [option]", - Short: "Enable a console option", - Example: "sudo cscli console enable tainted", - Long: ` -Enable given information push to the central API. Allows to empower the console`, - ValidArgs: csconfig.CONSOLE_CONFIGS, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if enableAll { - SetConsoleOpts(csconfig.CONSOLE_CONFIGS, true) - log.Infof("All features have been enabled successfully") - } else { - if len(args) == 0 { - log.Fatalf("You must specify at least one feature to enable") - } - SetConsoleOpts(args, true) - log.Infof("%v have been enabled", args) - } - if err := csConfig.API.Server.DumpConsoleConfig(); err != nil { - log.Fatalf("failed writing console config : %s", err) - } - log.Infof(ReloadMessage()) - }, - } - cmdEnable.Flags().BoolVarP(&enableAll, "all", "a", false, "Enable all console options") - cmdConsole.AddCommand(cmdEnable) - - cmdDisable := &cobra.Command{ - Use: "disable [option]", - Short: "Disable a console option", - Example: "sudo cscli console disable tainted", - Long: ` -Disable given information push to the central API.`, - ValidArgs: csconfig.CONSOLE_CONFIGS, - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if disableAll { - SetConsoleOpts(csconfig.CONSOLE_CONFIGS, false) - } else { - SetConsoleOpts(args, false) - } - - if err := csConfig.API.Server.DumpConsoleConfig(); err != nil { - log.Fatalf("failed writing console config : %s", err) - } - if disableAll { - log.Infof("All features have been disabled") - } else { - log.Infof("%v have been disabled", args) - } - log.Infof(ReloadMessage()) - }, - } - cmdDisable.Flags().BoolVarP(&disableAll, "all", "a", false, "Disable all console options") - cmdConsole.AddCommand(cmdDisable) - - cmdConsoleStatus := &cobra.Command{ - Use: "status [option]", - Short: "Shows status of one or all console options", - Example: `sudo cscli console status`, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - switch csConfig.Cscli.Output { - case "human": - cmdConsoleStatusTable(color.Output, *csConfig) - case "json": - data, err := json.MarshalIndent(csConfig.API.Server.ConsoleConfig, "", " ") - if err != nil { - log.Fatalf("failed to marshal configuration: %s", err) - } - fmt.Printf("%s\n", string(data)) - case "raw": - csvwriter := csv.NewWriter(os.Stdout) - err := csvwriter.Write([]string{"option", "enabled"}) - if err != nil { - log.Fatal(err) - } - - rows := [][]string{ - {csconfig.SEND_MANUAL_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareManualDecisions)}, - {csconfig.SEND_CUSTOM_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios)}, - {csconfig.SEND_TAINTED_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios)}, - {csconfig.SEND_CONTEXT, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareContext)}, - {csconfig.CONSOLE_MANAGEMENT, fmt.Sprintf("%t", 
*csConfig.API.Server.ConsoleConfig.ConsoleManagement)}, - } - for _, row := range rows { - err = csvwriter.Write(row) - if err != nil { - log.Fatal(err) - } - } - csvwriter.Flush() - } - }, - } - cmdConsole.AddCommand(cmdConsoleStatus) - - return cmdConsole -} - -func SetConsoleOpts(args []string, wanted bool) { - for _, arg := range args { - switch arg { - case csconfig.CONSOLE_MANAGEMENT: - if !fflag.PapiClient.IsEnabled() { - continue - } - /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ConsoleManagement != nil { - if *csConfig.API.Server.ConsoleConfig.ConsoleManagement == wanted { - log.Debugf("%s already set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) - } else { - log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) - *csConfig.API.Server.ConsoleConfig.ConsoleManagement = wanted - } - } else { - log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) - csConfig.API.Server.ConsoleConfig.ConsoleManagement = ptr.Of(wanted) - } - if csConfig.API.Server.OnlineClient.Credentials != nil { - changed := false - if wanted && csConfig.API.Server.OnlineClient.Credentials.PapiURL == "" { - changed = true - csConfig.API.Server.OnlineClient.Credentials.PapiURL = types.PAPIBaseURL - } else if !wanted && csConfig.API.Server.OnlineClient.Credentials.PapiURL != "" { - changed = true - csConfig.API.Server.OnlineClient.Credentials.PapiURL = "" - } - if changed { - fileContent, err := yaml.Marshal(csConfig.API.Server.OnlineClient.Credentials) - if err != nil { - log.Fatalf("Cannot marshal credentials: %s", err) - } - log.Infof("Updating credentials file: %s", csConfig.API.Server.OnlineClient.CredentialsFilePath) - err = os.WriteFile(csConfig.API.Server.OnlineClient.CredentialsFilePath, fileContent, 0600) - if err != nil { - log.Fatalf("Cannot write credentials file: %s", err) - } - } - } - case csconfig.SEND_CUSTOM_SCENARIOS: - /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareCustomScenarios != nil { - if *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) - *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios = wanted - } - } else { - log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) - csConfig.API.Server.ConsoleConfig.ShareCustomScenarios = ptr.Of(wanted) - } - case csconfig.SEND_TAINTED_SCENARIOS: - /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios != nil { - if *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) - *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios = wanted - } - } else { - log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) - csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios = ptr.Of(wanted) - } - case csconfig.SEND_MANUAL_SCENARIOS: - /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareManualDecisions != nil { - if *csConfig.API.Server.ConsoleConfig.ShareManualDecisions == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) - 
*csConfig.API.Server.ConsoleConfig.ShareManualDecisions = wanted - } - } else { - log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) - csConfig.API.Server.ConsoleConfig.ShareManualDecisions = ptr.Of(wanted) - } - case csconfig.SEND_CONTEXT: - /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareContext != nil { - if *csConfig.API.Server.ConsoleConfig.ShareContext == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_CONTEXT, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted) - *csConfig.API.Server.ConsoleConfig.ShareContext = wanted - } - } else { - log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted) - csConfig.API.Server.ConsoleConfig.ShareContext = ptr.Of(wanted) - } - default: - log.Fatalf("unknown flag %s", arg) - } - } - -} diff --git a/cmd/crowdsec-cli/console_table.go b/cmd/crowdsec-cli/console_table.go deleted file mode 100644 index f6778d6257f..00000000000 --- a/cmd/crowdsec-cli/console_table.go +++ /dev/null @@ -1,60 +0,0 @@ -package main - -import ( - "io" - - "github.com/aquasecurity/table" - "github.com/enescakir/emoji" - - "github.com/crowdsecurity/crowdsec/pkg/csconfig" -) - -func cmdConsoleStatusTable(out io.Writer, csConfig csconfig.Config) { - t := newTable(out) - t.SetRowLines(false) - - t.SetHeaders("Option Name", "Activated", "Description") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - for _, option := range csconfig.CONSOLE_CONFIGS { - switch option { - case csconfig.SEND_CUSTOM_SCENARIOS: - activated := string(emoji.CrossMark) - if *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios { - activated = string(emoji.CheckMarkButton) - } - - t.AddRow(option, activated, "Send alerts from custom scenarios to the console") - - case csconfig.SEND_MANUAL_SCENARIOS: - activated := string(emoji.CrossMark) - if *csConfig.API.Server.ConsoleConfig.ShareManualDecisions { - activated = string(emoji.CheckMarkButton) - } - - t.AddRow(option, activated, "Send manual decisions to the console") - - case csconfig.SEND_TAINTED_SCENARIOS: - activated := string(emoji.CrossMark) - if *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios { - activated = string(emoji.CheckMarkButton) - } - - t.AddRow(option, activated, "Send alerts from tainted scenarios to the console") - case csconfig.SEND_CONTEXT: - activated := string(emoji.CrossMark) - if *csConfig.API.Server.ConsoleConfig.ShareContext { - activated = string(emoji.CheckMarkButton) - } - t.AddRow(option, activated, "Send context with alerts to the console") - case csconfig.CONSOLE_MANAGEMENT: - activated := string(emoji.CrossMark) - if *csConfig.API.Server.ConsoleConfig.ConsoleManagement { - activated = string(emoji.CheckMarkButton) - } - t.AddRow(option, activated, "Receive decisions from console") - } - } - - t.Render() -} diff --git a/cmd/crowdsec-cli/copyfile.go b/cmd/crowdsec-cli/copyfile.go index 4de6cd6e24a..272fb3f7851 100644 --- a/cmd/crowdsec-cli/copyfile.go +++ b/cmd/crowdsec-cli/copyfile.go @@ -9,7 +9,6 @@ import ( log "github.com/sirupsen/logrus" ) - /*help to copy the file, ioutil doesn't offer the feature*/ func copyFileContents(src, dst string) (err error) { @@ -18,56 +17,66 @@ func copyFileContents(src, dst string) (err error) { return } defer in.Close() + out, err := os.Create(dst) if err != nil { return } + defer func() { cerr := out.Close() if err == nil { err = cerr } }() + if _, err = io.Copy(out, in); err != nil { return } + err = out.Sync() + return } /*copy the file, 
ioutile doesn't offer the feature*/ -func CopyFile(sourceSymLink, destinationFile string) (err error) { +func CopyFile(sourceSymLink, destinationFile string) error { sourceFile, err := filepath.EvalSymlinks(sourceSymLink) if err != nil { log.Infof("Not a symlink : %s", err) + sourceFile = sourceSymLink } sourceFileStat, err := os.Stat(sourceFile) if err != nil { - return + return err } + if !sourceFileStat.Mode().IsRegular() { // cannot copy non-regular files (e.g., directories, // symlinks, devices, etc.) return fmt.Errorf("copyFile: non-regular source file %s (%q)", sourceFileStat.Name(), sourceFileStat.Mode().String()) } + destinationFileStat, err := os.Stat(destinationFile) if err != nil { if !os.IsNotExist(err) { - return + return err } } else { if !(destinationFileStat.Mode().IsRegular()) { return fmt.Errorf("copyFile: non-regular destination file %s (%q)", destinationFileStat.Name(), destinationFileStat.Mode().String()) } + if os.SameFile(sourceFileStat, destinationFileStat) { - return + return err } } + if err = os.Link(sourceFile, destinationFile); err != nil { err = copyFileContents(sourceFile, destinationFile) } - return -} + return err +} diff --git a/cmd/crowdsec-cli/cstable/cstable.go b/cmd/crowdsec-cli/cstable/cstable.go new file mode 100644 index 00000000000..85ba491f4e8 --- /dev/null +++ b/cmd/crowdsec-cli/cstable/cstable.go @@ -0,0 +1,161 @@ +package cstable + +// transisional file to keep (minimal) backwards compatibility with the old table +// we can migrate the code to the new dependency later, it can already use the Writer interface + +import ( + "fmt" + "io" + "os" + + "github.com/jedib0t/go-pretty/v6/table" + "github.com/jedib0t/go-pretty/v6/text" + isatty "github.com/mattn/go-isatty" +) + +func shouldWeColorize(wantColor string) bool { + switch wantColor { + case "yes": + return true + case "no": + return false + default: + return isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) + } +} + +type Table struct { + Writer table.Writer + output io.Writer + align []text.Align + alignHeader []text.Align +} + +func New(out io.Writer, wantColor string) *Table { + if out == nil { + panic("newTable: out is nil") + } + + t := table.NewWriter() + + // colorize output, use unicode box characters + fancy := shouldWeColorize(wantColor) + + colorOptions := table.ColorOptions{} + + if fancy { + colorOptions.Header = text.Colors{text.Italic} + colorOptions.Border = text.Colors{text.FgHiBlack} + colorOptions.Separator = text.Colors{text.FgHiBlack} + } + + // no upper/lower case transformations + format := table.FormatOptions{} + + box := table.StyleBoxDefault + if fancy { + box = table.StyleBoxRounded + } + + style := table.Style{ + Box: box, + Color: colorOptions, + Format: format, + HTML: table.DefaultHTMLOptions, + Options: table.OptionsDefault, + Title: table.TitleOptionsDefault, + } + + t.SetStyle(style) + + return &Table{ + Writer: t, + output: out, + align: make([]text.Align, 0), + alignHeader: make([]text.Align, 0), + } +} + +func NewLight(output io.Writer, wantColor string) *Table { + t := New(output, wantColor) + s := t.Writer.Style() + s.Box.Left = "" + s.Box.LeftSeparator = "" + s.Box.TopLeft = "" + s.Box.BottomLeft = "" + s.Box.Right = "" + s.Box.RightSeparator = "" + s.Box.TopRight = "" + s.Box.BottomRight = "" + s.Options.SeparateRows = false + s.Options.SeparateFooter = false + s.Options.SeparateHeader = true + s.Options.SeparateColumns = false + + return t +} + +// +// wrapper methods for backwards compatibility +// + +// setColumnConfigs 
must be called right before rendering, +// to allow for setting the alignment like the old API +func (t *Table) setColumnConfigs() { + configs := []table.ColumnConfig{} + // the go-pretty table does not expose the names or number of columns + for i := range len(t.align) { + configs = append(configs, table.ColumnConfig{ + Number: i + 1, + AlignHeader: t.alignHeader[i], + Align: t.align[i], + WidthMax: 60, + WidthMaxEnforcer: text.WrapSoft, + }) + } + + t.Writer.SetColumnConfigs(configs) +} + +func (t *Table) Render() { + // change default options for backwards compatibility. + // we do this late to allow changing the alignment like the old API + t.setColumnConfigs() + fmt.Fprintln(t.output, t.Writer.Render()) +} + +func (t *Table) SetHeaders(str ...string) { + row := table.Row{} + t.align = make([]text.Align, len(str)) + t.alignHeader = make([]text.Align, len(str)) + + for i, v := range str { + row = append(row, v) + t.align[i] = text.AlignLeft + t.alignHeader[i] = text.AlignCenter + } + + t.Writer.AppendHeader(row) +} + +func (t *Table) AddRow(str ...string) { + row := table.Row{} + for _, v := range str { + row = append(row, v) + } + + t.Writer.AppendRow(row) +} + +func (t *Table) SetRowLines(rowLines bool) { + t.Writer.Style().Options.SeparateRows = rowLines +} + +func (t *Table) SetAlignment(align ...text.Align) { + // align can be shorter than t.align, it will leave the default value + copy(t.align, align) +} + +func (t *Table) SetHeaderAlignment(align ...text.Align) { + copy(t.alignHeader, align) +} diff --git a/cmd/crowdsec-cli/dashboard.go b/cmd/crowdsec-cli/dashboard.go index c6069c71cb2..41db9e6cbf2 100644 --- a/cmd/crowdsec-cli/dashboard.go +++ b/cmd/crowdsec-cli/dashboard.go @@ -1,3 +1,5 @@ +//go:build linux + package main import ( @@ -10,6 +12,7 @@ import ( "path/filepath" "strconv" "strings" + "syscall" "unicode" "github.com/AlecAivazis/survey/v2" @@ -17,16 +20,21 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/crowdsecurity/go-cs-lib/version" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/metabase" ) var ( metabaseUser = "crowdsec@crowdsec.net" metabasePassword string - metabaseDbPath string + metabaseDBPath string metabaseConfigPath string metabaseConfigFolder = "metabase/" metabaseConfigFile = "metabase.yaml" + metabaseImage = "metabase/metabase:v0.46.6.1" /**/ metabaseListenAddress = "127.0.0.1" metabaseListenPort = "3000" @@ -35,12 +43,21 @@ var ( forceYes bool - /*informations needed to setup a random password on user's behalf*/ + // information needed to set up a random password on user's behalf ) -func NewDashboardCmd() *cobra.Command { - /* ---- UPDATE COMMAND */ - var cmdDashboard = &cobra.Command{ +type cliDashboard struct { + cfg configGetter +} + +func NewCLIDashboard(cfg configGetter) *cliDashboard { + return &cliDashboard{ + cfg: cfg, + } +} + +func (cli *cliDashboard) NewCommand() *cobra.Command { + cmd := &cobra.Command{ Use: "dashboard [command]", Short: "Manage your metabase dashboard container [requires local API]", Long: `Install/Start/Stop/Remove a metabase container exposing dashboard and metrics. 
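The cstable package added above is a thin wrapper around go-pretty that stays source-compatible with the old aquasecurity/table API (SetHeaders, AddRow, SetAlignment, Render). A minimal usage sketch, assuming only the wrapper as introduced in this patch; the row contents are illustrative:

package main

import (
	"os"

	"github.com/jedib0t/go-pretty/v6/text"

	"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable"
)

func main() {
	// "auto" is neither "yes" nor "no", so shouldWeColorize falls back to
	// isatty detection and output is colorized only on a terminal
	t := cstable.New(os.Stdout, "auto")
	t.SetHeaders("Option Name", "Activated")

	// alignment can be changed any time before Render(); the wrapper only
	// applies it in setColumnConfigs(), right before rendering
	t.SetAlignment(text.AlignLeft, text.AlignCenter)

	t.AddRow("manual", "true")
	t.Render()
}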
@@ -54,23 +71,28 @@ cscli dashboard start cscli dashboard stop cscli dashboard remove `, - PersistentPreRun: func(cmd *cobra.Command, args []string) { - if err := metabase.TestAvailability(); err != nil { - log.Fatalf("%s", err) + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + if version.System == "docker" { + return errors.New("cscli dashboard is not supported whilst running CrowdSec within a container please see: https://github.com/crowdsecurity/example-docker-compose/tree/main/basic") + } + + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { + return err } - if err := csConfig.LoadAPIServer(); err != nil || csConfig.DisableAPI { - log.Fatal("Local API is disabled, please run this command on the local API machine") + if err := metabase.TestAvailability(); err != nil { + return err } - metabaseConfigFolderPath := filepath.Join(csConfig.ConfigPaths.ConfigDir, metabaseConfigFolder) + metabaseConfigFolderPath := filepath.Join(cfg.ConfigPaths.ConfigDir, metabaseConfigFolder) metabaseConfigPath = filepath.Join(metabaseConfigFolderPath, metabaseConfigFile) if err := os.MkdirAll(metabaseConfigFolderPath, os.ModePerm); err != nil { - log.Fatal(err) + return err } - if err := csConfig.LoadDBConfig(); err != nil { - log.Errorf("This command requires direct database access (must be run on the local API machine)") - log.Fatal(err) + + if err := require.DB(cfg); err != nil { + return err } /* @@ -84,23 +106,26 @@ cscli dashboard remove metabaseContainerID = oldContainerID } } + + log.Warn("cscli dashboard will be deprecated in version 1.7.0, read more at https://docs.crowdsec.net/blog/cscli_dashboard_deprecation/") + + return nil }, } - cmdDashboard.AddCommand(NewDashboardSetupCmd()) - cmdDashboard.AddCommand(NewDashboardStartCmd()) - cmdDashboard.AddCommand(NewDashboardStopCmd()) - cmdDashboard.AddCommand(NewDashboardShowPasswordCmd()) - cmdDashboard.AddCommand(NewDashboardRemoveCmd()) + cmd.AddCommand(cli.newSetupCmd()) + cmd.AddCommand(cli.newStartCmd()) + cmd.AddCommand(cli.newStopCmd()) + cmd.AddCommand(cli.newShowPasswordCmd()) + cmd.AddCommand(cli.newRemoveCmd()) - return cmdDashboard + return cmd } - -func NewDashboardSetupCmd() *cobra.Command { +func (cli *cliDashboard) newSetupCmd() *cobra.Command { var force bool - var cmdDashSetup = &cobra.Command{ + cmd := &cobra.Command{ Use: "setup", Short: "Setup a metabase container.", Long: `Perform a metabase docker setup, download standard dashboards, create a fresh user and start the container`, @@ -111,82 +136,38 @@ cscli dashboard setup cscli dashboard setup --listen 0.0.0.0 cscli dashboard setup -l 0.0.0.0 -p 443 --password `, - Run: func(cmd *cobra.Command, args []string) { - if metabaseDbPath == "" { - metabaseDbPath = csConfig.ConfigPaths.DataDir + RunE: func(_ *cobra.Command, _ []string) error { + if metabaseDBPath == "" { + metabaseDBPath = cli.cfg().ConfigPaths.DataDir } if metabasePassword == "" { isValid := passwordIsValid(metabasePassword) for !isValid { - metabasePassword = generatePassword(16) + metabasePassword = idgen.GeneratePassword(16) isValid = passwordIsValid(metabasePassword) } } - var answer bool - if valid, err := checkSystemMemory(); err == nil && !valid { - if !forceYes { - prompt := &survey.Confirm{ - Message: "Metabase requires 1-2GB of RAM, your system is below this requirement continue ?", - Default: true, - } - if err := survey.AskOne(prompt, &answer); err != nil { - log.Warnf("unable to ask about RAM check: %s", err) - } - if !answer { - log.Fatal("Unable to continue due to RAM 
requirement") - } - } else { - log.Warnf("Metabase requires 1-2GB of RAM, your system is below this requirement") - } - } - groupExist := false - dockerGroup, err := user.LookupGroup(crowdsecGroup) - if err == nil { - groupExist = true - } - if !forceYes && !groupExist { - prompt := &survey.Confirm{ - Message: fmt.Sprintf("For metabase docker to be able to access SQLite file we need to add a new group called '%s' to the system, is it ok for you ?", crowdsecGroup), - Default: true, - } - if err := survey.AskOne(prompt, &answer); err != nil { - log.Fatalf("unable to ask to force: %s", err) - } - } - if !answer && !forceYes && !groupExist { - log.Fatalf("unable to continue without creating '%s' group", crowdsecGroup) + if err := checkSystemMemory(&forceYes); err != nil { + return err } - if !groupExist { - groupAddCmd, err := exec.LookPath("groupadd") - if err != nil { - log.Fatalf("unable to find 'groupadd' command, can't continue") - } - - groupAdd := &exec.Cmd{Path: groupAddCmd, Args: []string{groupAddCmd, crowdsecGroup}} - if err := groupAdd.Run(); err != nil { - log.Fatalf("unable to add group '%s': %s", dockerGroup, err) - } - dockerGroup, err = user.LookupGroup(crowdsecGroup) - if err != nil { - log.Fatalf("unable to lookup '%s' group: %+v", dockerGroup, err) - } + warnIfNotLoopback(metabaseListenAddress) + if err := disclaimer(&forceYes); err != nil { + return err } - intID, err := strconv.Atoi(dockerGroup.Gid) + dockerGroup, err := checkGroups(&forceYes) if err != nil { - log.Fatalf("unable to convert group ID to int: %s", err) + return err } - if err := os.Chown(csConfig.DbConfig.DbPath, 0, intID); err != nil { - log.Fatalf("unable to chown sqlite db file '%s': %s", csConfig.DbConfig.DbPath, err) + if err = cli.chownDatabase(dockerGroup.Gid); err != nil { + return err } - - mb, err := metabase.SetupMetabase(csConfig.API.Server.DbConfig, metabaseListenAddress, metabaseListenPort, metabaseUser, metabasePassword, metabaseDbPath, dockerGroup.Gid, metabaseContainerID) + mb, err := metabase.SetupMetabase(cli.cfg().API.Server.DbConfig, metabaseListenAddress, metabaseListenPort, metabaseUser, metabasePassword, metabaseDBPath, dockerGroup.Gid, metabaseContainerID, metabaseImage) if err != nil { - log.Fatal(err) + return err } - if err := mb.DumpConfig(metabaseConfigPath); err != nil { - log.Fatal(err) + return err } log.Infof("Metabase is ready") @@ -194,79 +175,96 @@ cscli dashboard setup -l 0.0.0.0 -p 443 --password fmt.Printf("\tURL : '%s'\n", mb.Config.ListenURL) fmt.Printf("\tusername : '%s'\n", mb.Config.Username) fmt.Printf("\tpassword : '%s'\n", mb.Config.Password) + + return nil }, } - cmdDashSetup.Flags().BoolVarP(&force, "force", "f", false, "Force setup : override existing files") - cmdDashSetup.Flags().StringVarP(&metabaseDbPath, "dir", "d", "", "Shared directory with metabase container") - cmdDashSetup.Flags().StringVarP(&metabaseListenAddress, "listen", "l", metabaseListenAddress, "Listen address of container") - cmdDashSetup.Flags().StringVarP(&metabaseListenPort, "port", "p", metabaseListenPort, "Listen port of container") - cmdDashSetup.Flags().BoolVarP(&forceYes, "yes", "y", false, "force yes") - //cmdDashSetup.Flags().StringVarP(&metabaseUser, "user", "u", "crowdsec@crowdsec.net", "metabase user") - cmdDashSetup.Flags().StringVar(&metabasePassword, "password", "", "metabase password") - - return cmdDashSetup + + flags := cmd.Flags() + flags.BoolVarP(&force, "force", "f", false, "Force setup : override existing files") + flags.StringVarP(&metabaseDBPath, "dir", "d", "", 
"Shared directory with metabase container") + flags.StringVarP(&metabaseListenAddress, "listen", "l", metabaseListenAddress, "Listen address of container") + flags.StringVar(&metabaseImage, "metabase-image", metabaseImage, "Metabase image to use") + flags.StringVarP(&metabaseListenPort, "port", "p", metabaseListenPort, "Listen port of container") + flags.BoolVarP(&forceYes, "yes", "y", false, "force yes") + // flags.StringVarP(&metabaseUser, "user", "u", "crowdsec@crowdsec.net", "metabase user") + flags.StringVar(&metabasePassword, "password", "", "metabase password") + + return cmd } -func NewDashboardStartCmd() *cobra.Command { - var cmdDashStart = &cobra.Command{ +func (cli *cliDashboard) newStartCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "start", Short: "Start the metabase container.", Long: `Stats the metabase container using docker.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { + RunE: func(_ *cobra.Command, _ []string) error { mb, err := metabase.NewMetabase(metabaseConfigPath, metabaseContainerID) if err != nil { - log.Fatal(err) + return err + } + warnIfNotLoopback(mb.Config.ListenAddr) + if err := disclaimer(&forceYes); err != nil { + return err } if err := mb.Container.Start(); err != nil { - log.Fatalf("Failed to start metabase container : %s", err) + return fmt.Errorf("failed to start metabase container : %s", err) } log.Infof("Started metabase") - log.Infof("url : http://%s:%s", metabaseListenAddress, metabaseListenPort) + log.Infof("url : http://%s:%s", mb.Config.ListenAddr, mb.Config.ListenPort) + + return nil }, } - return cmdDashStart + + cmd.Flags().BoolVarP(&forceYes, "yes", "y", false, "force yes") + + return cmd } -func NewDashboardStopCmd() *cobra.Command { - var cmdDashStop = &cobra.Command{ +func (cli *cliDashboard) newStopCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "stop", Short: "Stops the metabase container.", Long: `Stops the metabase container using docker.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { + RunE: func(_ *cobra.Command, _ []string) error { if err := metabase.StopContainer(metabaseContainerID); err != nil { - log.Fatalf("unable to stop container '%s': %s", metabaseContainerID, err) + return fmt.Errorf("unable to stop container '%s': %s", metabaseContainerID, err) } + return nil }, } - return cmdDashStop -} + return cmd +} -func NewDashboardShowPasswordCmd() *cobra.Command { - var cmdDashShowPassword = &cobra.Command{Use: "show-password", +func (cli *cliDashboard) newShowPasswordCmd() *cobra.Command { + cmd := &cobra.Command{Use: "show-password", Short: "displays password of metabase.", Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { + RunE: func(_ *cobra.Command, _ []string) error { m := metabase.Metabase{} if err := m.LoadConfig(metabaseConfigPath); err != nil { - log.Fatal(err) + return err } log.Printf("'%s'", m.Config.Password) + + return nil }, } - return cmdDashShowPassword -} + return cmd +} -func NewDashboardRemoveCmd() *cobra.Command { +func (cli *cliDashboard) newRemoveCmd() *cobra.Command { var force bool - var cmdDashRemove = &cobra.Command{ + cmd := &cobra.Command{ Use: "remove", Short: "removes the metabase container.", Long: `removes the metabase container using docker.`, @@ -276,66 +274,77 @@ func NewDashboardRemoveCmd() *cobra.Command { cscli dashboard remove cscli dashboard remove --force `, - Run: func(cmd *cobra.Command, args []string) { - answer := 
true + RunE: func(_ *cobra.Command, _ []string) error { if !forceYes { + var answer bool prompt := &survey.Confirm{ Message: "Do you really want to remove crowdsec dashboard? (all your changes will be lost)", Default: true, } if err := survey.AskOne(prompt, &answer); err != nil { - log.Fatalf("unable to ask to force: %s", err) + return fmt.Errorf("unable to ask to force: %s", err) + } + if !answer { + return errors.New("user stated no to continue") } } - if answer { - if metabase.IsContainerExist(metabaseContainerID) { - log.Debugf("Stopping container %s", metabaseContainerID) - if err := metabase.StopContainer(metabaseContainerID); err != nil { - log.Warningf("unable to stop container '%s': %s", metabaseContainerID, err) - } - dockerGroup, err := user.LookupGroup(crowdsecGroup) - if err == nil { // if group exist, remove it - groupDelCmd, err := exec.LookPath("groupdel") - if err != nil { - log.Fatalf("unable to find 'groupdel' command, can't continue") - } - - groupDel := &exec.Cmd{Path: groupDelCmd, Args: []string{groupDelCmd, crowdsecGroup}} - if err := groupDel.Run(); err != nil { - log.Errorf("unable to delete group '%s': %s", dockerGroup, err) - } + if metabase.IsContainerExist(metabaseContainerID) { + log.Debugf("Stopping container %s", metabaseContainerID) + if err := metabase.StopContainer(metabaseContainerID); err != nil { + log.Warningf("unable to stop container '%s': %s", metabaseContainerID, err) + } + dockerGroup, err := user.LookupGroup(crowdsecGroup) + if err == nil { // if group exist, remove it + groupDelCmd, err := exec.LookPath("groupdel") + if err != nil { + return errors.New("unable to find 'groupdel' command, can't continue") } - log.Debugf("Removing container %s", metabaseContainerID) - if err := metabase.RemoveContainer(metabaseContainerID); err != nil { - log.Warningf("unable to remove container '%s': %s", metabaseContainerID, err) + + groupDel := &exec.Cmd{Path: groupDelCmd, Args: []string{groupDelCmd, crowdsecGroup}} + if err := groupDel.Run(); err != nil { + log.Warnf("unable to delete group '%s': %s", dockerGroup, err) } - log.Infof("container %s stopped & removed", metabaseContainerID) } - log.Debugf("Removing metabase db %s", csConfig.ConfigPaths.DataDir) - if err := metabase.RemoveDatabase(csConfig.ConfigPaths.DataDir); err != nil { - log.Warningf("failed to remove metabase internal db : %s", err) + log.Debugf("Removing container %s", metabaseContainerID) + if err := metabase.RemoveContainer(metabaseContainerID); err != nil { + log.Warnf("unable to remove container '%s': %s", metabaseContainerID, err) + } + log.Infof("container %s stopped & removed", metabaseContainerID) + } + log.Debugf("Removing metabase db %s", cli.cfg().ConfigPaths.DataDir) + if err := metabase.RemoveDatabase(cli.cfg().ConfigPaths.DataDir); err != nil { + log.Warnf("failed to remove metabase internal db : %s", err) + } + if force { + m := metabase.Metabase{} + if err := m.LoadConfig(metabaseConfigPath); err != nil { + return err } - if force { - if err := metabase.RemoveImageContainer(); err != nil { - if !strings.Contains(err.Error(), "No such image") { - log.Fatalf("removing docker image: %s", err) - } + if err := metabase.RemoveImageContainer(m.Config.Image); err != nil { + if !strings.Contains(err.Error(), "No such image") { + return fmt.Errorf("removing docker image: %s", err) } } } + + return nil }, } - cmdDashRemove.Flags().BoolVarP(&force, "force", "f", false, "Remove also the metabase image") - cmdDashRemove.Flags().BoolVarP(&forceYes, "yes", "y", false, "force yes") - return 
cmdDashRemove + flags := cmd.Flags() + flags.BoolVarP(&force, "force", "f", false, "Remove also the metabase image") + flags.BoolVarP(&forceYes, "yes", "y", false, "force yes") + + return cmd } func passwordIsValid(password string) bool { hasDigit := false + for _, j := range password { if unicode.IsDigit(j) { hasDigit = true + break } } @@ -343,17 +352,134 @@ func passwordIsValid(password string) bool { if !hasDigit || len(password) < 6 { return false } - return true + return true } -func checkSystemMemory() (bool, error) { +func checkSystemMemory(forceYes *bool) error { totMem := memory.TotalMemory() - if totMem == 0 { - return true, errors.New("Unable to get system total memory") + if totMem >= uint64(math.Pow(2, 30)) { + return nil + } + + if !*forceYes { + var answer bool + + prompt := &survey.Confirm{ + Message: "Metabase requires 1-2GB of RAM, your system is below this requirement continue ?", + Default: true, + } + if err := survey.AskOne(prompt, &answer); err != nil { + return fmt.Errorf("unable to ask about RAM check: %s", err) + } + + if !answer { + return errors.New("user stated no to continue") + } + + return nil + } + + log.Warn("Metabase requires 1-2GB of RAM, your system is below this requirement") + + return nil +} + +func warnIfNotLoopback(addr string) { + if addr == "127.0.0.1" || addr == "::1" { + return } - if uint64(math.Pow(2, 30)) >= totMem { - return false, nil + + log.Warnf("You are potentially exposing your metabase port to the internet (addr: %s), please consider using a reverse proxy", addr) +} + +func disclaimer(forceYes *bool) error { + if !*forceYes { + var answer bool + + prompt := &survey.Confirm{ + Message: "CrowdSec takes no responsibility for the security of your metabase instance. Do you accept these responsibilities ?", + Default: true, + } + + if err := survey.AskOne(prompt, &answer); err != nil { + return fmt.Errorf("unable to ask to question: %s", err) + } + + if !answer { + return errors.New("user stated no to responsibilities") + } + + return nil } - return true, nil + + log.Warn("CrowdSec takes no responsibility for the security of your metabase instance. 
You used force yes, so you accept this disclaimer") + + return nil +} + +func checkGroups(forceYes *bool) (*user.Group, error) { + dockerGroup, err := user.LookupGroup(crowdsecGroup) + if err == nil { + return dockerGroup, nil + } + + if !*forceYes { + var answer bool + + prompt := &survey.Confirm{ + Message: fmt.Sprintf("For metabase docker to be able to access SQLite file we need to add a new group called '%s' to the system, is it ok for you ?", crowdsecGroup), + Default: true, + } + + if err := survey.AskOne(prompt, &answer); err != nil { + return dockerGroup, fmt.Errorf("unable to ask to question: %s", err) + } + + if !answer { + return dockerGroup, fmt.Errorf("unable to continue without creating '%s' group", crowdsecGroup) + } + } + + groupAddCmd, err := exec.LookPath("groupadd") + if err != nil { + return dockerGroup, errors.New("unable to find 'groupadd' command, can't continue") + } + + groupAdd := &exec.Cmd{Path: groupAddCmd, Args: []string{groupAddCmd, crowdsecGroup}} + if err := groupAdd.Run(); err != nil { + return dockerGroup, fmt.Errorf("unable to add group '%s': %s", dockerGroup, err) + } + + return user.LookupGroup(crowdsecGroup) +} + +func (cli *cliDashboard) chownDatabase(gid string) error { + cfg := cli.cfg() + intID, err := strconv.Atoi(gid) + + if err != nil { + return fmt.Errorf("unable to convert group ID to int: %s", err) + } + + if stat, err := os.Stat(cfg.DbConfig.DbPath); !os.IsNotExist(err) { + info := stat.Sys() + if err := os.Chown(cfg.DbConfig.DbPath, int(info.(*syscall.Stat_t).Uid), intID); err != nil { + return fmt.Errorf("unable to chown sqlite db file '%s': %s", cfg.DbConfig.DbPath, err) + } + } + + if cfg.DbConfig.Type == "sqlite" && cfg.DbConfig.UseWal != nil && *cfg.DbConfig.UseWal { + for _, ext := range []string{"-wal", "-shm"} { + file := cfg.DbConfig.DbPath + ext + if stat, err := os.Stat(file); !os.IsNotExist(err) { + info := stat.Sys() + if err := os.Chown(file, int(info.(*syscall.Stat_t).Uid), intID); err != nil { + return fmt.Errorf("unable to chown sqlite db file '%s': %s", file, err) + } + } + } + } + + return nil } diff --git a/cmd/crowdsec-cli/dashboard_unsupported.go b/cmd/crowdsec-cli/dashboard_unsupported.go new file mode 100644 index 00000000000..cc80abd2528 --- /dev/null +++ b/cmd/crowdsec-cli/dashboard_unsupported.go @@ -0,0 +1,32 @@ +//go:build !linux + +package main + +import ( + "runtime" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +type cliDashboard struct{ + cfg configGetter +} + +func NewCLIDashboard(cfg configGetter) *cliDashboard { + return &cliDashboard{ + cfg: cfg, + } +} + +func (cli cliDashboard) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "dashboard", + DisableAutoGenTag: true, + Run: func(_ *cobra.Command, _ []string) { + log.Infof("Dashboard command is disabled on %s", runtime.GOOS) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/decisions.go b/cmd/crowdsec-cli/decisions.go deleted file mode 100644 index ce3d0e46e2b..00000000000 --- a/cmd/crowdsec-cli/decisions.go +++ /dev/null @@ -1,482 +0,0 @@ -package main - -import ( - "context" - "encoding/csv" - "encoding/json" - "fmt" - "net/url" - "os" - "strconv" - "strings" - "time" - - "github.com/fatih/color" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/go-cs-lib/pkg/version" - - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -var Client 
*apiclient.ApiClient - -func DecisionsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { - /*here we cheat a bit : to make it more readable for the user, we dedup some entries*/ - spamLimit := make(map[string]bool) - skipped := 0 - - for aIdx := 0; aIdx < len(*alerts); aIdx++ { - alertItem := (*alerts)[aIdx] - newDecisions := make([]*models.Decision, 0) - for _, decisionItem := range alertItem.Decisions { - spamKey := fmt.Sprintf("%t:%s:%s:%s", *decisionItem.Simulated, *decisionItem.Type, *decisionItem.Scope, *decisionItem.Value) - if _, ok := spamLimit[spamKey]; ok { - skipped++ - continue - } - spamLimit[spamKey] = true - newDecisions = append(newDecisions, decisionItem) - } - alertItem.Decisions = newDecisions - } - if csConfig.Cscli.Output == "raw" { - csvwriter := csv.NewWriter(os.Stdout) - header := []string{"id", "source", "ip", "reason", "action", "country", "as", "events_count", "expiration", "simulated", "alert_id"} - if printMachine { - header = append(header, "machine") - } - err := csvwriter.Write(header) - if err != nil { - return err - } - for _, alertItem := range *alerts { - for _, decisionItem := range alertItem.Decisions { - raw := []string{ - fmt.Sprintf("%d", decisionItem.ID), - *decisionItem.Origin, - *decisionItem.Scope + ":" + *decisionItem.Value, - *decisionItem.Scenario, - *decisionItem.Type, - alertItem.Source.Cn, - alertItem.Source.GetAsNumberName(), - fmt.Sprintf("%d", *alertItem.EventsCount), - *decisionItem.Duration, - fmt.Sprintf("%t", *decisionItem.Simulated), - fmt.Sprintf("%d", alertItem.ID), - } - if printMachine { - raw = append(raw, alertItem.MachineID) - } - - err := csvwriter.Write(raw) - if err != nil { - return err - } - } - } - csvwriter.Flush() - } else if csConfig.Cscli.Output == "json" { - x, _ := json.MarshalIndent(alerts, "", " ") - fmt.Printf("%s", string(x)) - } else if csConfig.Cscli.Output == "human" { - if len(*alerts) == 0 { - fmt.Println("No active decisions") - return nil - } - decisionsTable(color.Output, alerts, printMachine) - if skipped > 0 { - fmt.Printf("%d duplicated entries skipped\n", skipped) - } - } - return nil -} - -func NewDecisionsCmd() *cobra.Command { - var cmdDecisions = &cobra.Command{ - Use: "decisions [action]", - Short: "Manage decisions", - Long: `Add/List/Delete/Import decisions from LAPI`, - Example: `cscli decisions [action] [filter]`, - Aliases: []string{"decision"}, - /*TBD example*/ - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadAPIClient(); err != nil { - return fmt.Errorf("loading api client: %w", err) - } - password := strfmt.Password(csConfig.API.Client.Credentials.Password) - apiurl, err := url.Parse(csConfig.API.Client.Credentials.URL) - if err != nil { - return fmt.Errorf("parsing api url %s: %w", csConfig.API.Client.Credentials.URL, err) - } - Client, err = apiclient.NewClient(&apiclient.Config{ - MachineID: csConfig.API.Client.Credentials.Login, - Password: password, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiurl, - VersionPrefix: "v1", - }) - if err != nil { - return fmt.Errorf("creating api client: %w", err) - } - return nil - }, - } - - cmdDecisions.AddCommand(NewDecisionsListCmd()) - cmdDecisions.AddCommand(NewDecisionsAddCmd()) - cmdDecisions.AddCommand(NewDecisionsDeleteCmd()) - cmdDecisions.AddCommand(NewDecisionsImportCmd()) - - return cmdDecisions -} - -func NewDecisionsListCmd() *cobra.Command { - var filter = apiclient.AlertsListOpts{ - 
ValueEquals: new(string), - ScopeEquals: new(string), - ScenarioEquals: new(string), - OriginEquals: new(string), - IPEquals: new(string), - RangeEquals: new(string), - Since: new(string), - Until: new(string), - TypeEquals: new(string), - IncludeCAPI: new(bool), - Limit: new(int), - } - NoSimu := new(bool) - contained := new(bool) - var printMachine bool - - var cmdDecisionsList = &cobra.Command{ - Use: "list [options]", - Short: "List decisions from LAPI", - Example: `cscli decisions list -i 1.2.3.4 -cscli decisions list -r 1.2.3.0/24 -cscli decisions list -s crowdsecurity/ssh-bf -cscli decisions list -t ban -`, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - var err error - /*take care of shorthand options*/ - if err = manageCliDecisionAlerts(filter.IPEquals, filter.RangeEquals, filter.ScopeEquals, filter.ValueEquals); err != nil { - return err - } - filter.ActiveDecisionEquals = new(bool) - *filter.ActiveDecisionEquals = true - if NoSimu != nil && *NoSimu { - filter.IncludeSimulated = new(bool) - } - /* nullify the empty entries to avoid bad filter */ - if *filter.Until == "" { - filter.Until = nil - } else if strings.HasSuffix(*filter.Until, "d") { - /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ - realDuration := strings.TrimSuffix(*filter.Until, "d") - days, err := strconv.Atoi(realDuration) - if err != nil { - printHelp(cmd) - return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *filter.Until) - } - *filter.Until = fmt.Sprintf("%d%s", days*24, "h") - } - - if *filter.Since == "" { - filter.Since = nil - } else if strings.HasSuffix(*filter.Since, "d") { - /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ - realDuration := strings.TrimSuffix(*filter.Since, "d") - days, err := strconv.Atoi(realDuration) - if err != nil { - printHelp(cmd) - return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *filter.Since) - } - *filter.Since = fmt.Sprintf("%d%s", days*24, "h") - } - if *filter.IncludeCAPI { - *filter.Limit = 0 - } - if *filter.TypeEquals == "" { - filter.TypeEquals = nil - } - if *filter.ValueEquals == "" { - filter.ValueEquals = nil - } - if *filter.ScopeEquals == "" { - filter.ScopeEquals = nil - } - if *filter.ScenarioEquals == "" { - filter.ScenarioEquals = nil - } - if *filter.IPEquals == "" { - filter.IPEquals = nil - } - if *filter.RangeEquals == "" { - filter.RangeEquals = nil - } - - if *filter.OriginEquals == "" { - filter.OriginEquals = nil - } - - if contained != nil && *contained { - filter.Contains = new(bool) - } - - alerts, _, err := Client.Alerts.List(context.Background(), filter) - if err != nil { - return fmt.Errorf("unable to retrieve decisions: %w", err) - } - - err = DecisionsToTable(alerts, printMachine) - if err != nil { - return fmt.Errorf("unable to print decisions: %w", err) - } - - return nil - }, - } - cmdDecisionsList.Flags().SortFlags = false - cmdDecisionsList.Flags().BoolVarP(filter.IncludeCAPI, "all", "a", false, "Include decisions from Central API") - cmdDecisionsList.Flags().StringVar(filter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)") - cmdDecisionsList.Flags().StringVar(filter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)") - cmdDecisionsList.Flags().StringVarP(filter.TypeEquals, "type", "t", "", "restrict to this decision type (ie. 
ban,captcha)") - cmdDecisionsList.Flags().StringVar(filter.ScopeEquals, "scope", "", "restrict to this scope (ie. ip,range,session)") - cmdDecisionsList.Flags().StringVar(filter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) - cmdDecisionsList.Flags().StringVarP(filter.ValueEquals, "value", "v", "", "restrict to this value (ie. 1.2.3.4,userName)") - cmdDecisionsList.Flags().StringVarP(filter.ScenarioEquals, "scenario", "s", "", "restrict to this scenario (ie. crowdsecurity/ssh-bf)") - cmdDecisionsList.Flags().StringVarP(filter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value )") - cmdDecisionsList.Flags().StringVarP(filter.RangeEquals, "range", "r", "", "restrict to alerts from this source range (shorthand for --scope range --value )") - cmdDecisionsList.Flags().IntVarP(filter.Limit, "limit", "l", 100, "number of alerts to get (use 0 to remove the limit)") - cmdDecisionsList.Flags().BoolVar(NoSimu, "no-simu", false, "exclude decisions in simulation mode") - cmdDecisionsList.Flags().BoolVarP(&printMachine, "machine", "m", false, "print machines that triggered decisions") - cmdDecisionsList.Flags().BoolVar(contained, "contained", false, "query decisions contained by range") - - return cmdDecisionsList -} - -func NewDecisionsAddCmd() *cobra.Command { - var ( - addIP string - addRange string - addDuration string - addValue string - addScope string - addReason string - addType string - ) - - var cmdDecisionsAdd = &cobra.Command{ - Use: "add [options]", - Short: "Add decision to LAPI", - Example: `cscli decisions add --ip 1.2.3.4 -cscli decisions add --range 1.2.3.0/24 -cscli decisions add --ip 1.2.3.4 --duration 24h --type captcha -cscli decisions add --scope username --value foobar -`, - /*TBD : fix long and example*/ - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - var err error - alerts := models.AddAlertsRequest{} - origin := types.CscliOrigin - capacity := int32(0) - leakSpeed := "0" - eventsCount := int32(1) - empty := "" - simulated := false - startAt := time.Now().UTC().Format(time.RFC3339) - stopAt := time.Now().UTC().Format(time.RFC3339) - createdAt := time.Now().UTC().Format(time.RFC3339) - - /*take care of shorthand options*/ - if err := manageCliDecisionAlerts(&addIP, &addRange, &addScope, &addValue); err != nil { - return err - } - - if addIP != "" { - addValue = addIP - addScope = types.Ip - } else if addRange != "" { - addValue = addRange - addScope = types.Range - } else if addValue == "" { - printHelp(cmd) - return fmt.Errorf("Missing arguments, a value is required (--ip, --range or --scope and --value)") - } - - if addReason == "" { - addReason = fmt.Sprintf("manual '%s' from '%s'", addType, csConfig.API.Client.Credentials.Login) - } - decision := models.Decision{ - Duration: &addDuration, - Scope: &addScope, - Value: &addValue, - Type: &addType, - Scenario: &addReason, - Origin: &origin, - } - alert := models.Alert{ - Capacity: &capacity, - Decisions: []*models.Decision{&decision}, - Events: []*models.Event{}, - EventsCount: &eventsCount, - Leakspeed: &leakSpeed, - Message: &addReason, - ScenarioHash: &empty, - Scenario: &addReason, - ScenarioVersion: &empty, - Simulated: &simulated, - //setting empty scope/value broke plugins, and it didn't seem to be needed anymore w/ latest papi changes - Source: &models.Source{ - AsName: empty, - AsNumber: empty, - Cn: empty, - IP: 
addValue, - Range: "", - Scope: &addScope, - Value: &addValue, - }, - StartAt: &startAt, - StopAt: &stopAt, - CreatedAt: createdAt, - } - alerts = append(alerts, &alert) - - _, _, err = Client.Alerts.Add(context.Background(), alerts) - if err != nil { - return err - } - - log.Info("Decision successfully added") - return nil - }, - } - - cmdDecisionsAdd.Flags().SortFlags = false - cmdDecisionsAdd.Flags().StringVarP(&addIP, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") - cmdDecisionsAdd.Flags().StringVarP(&addRange, "range", "r", "", "Range source ip (shorthand for --scope range --value )") - cmdDecisionsAdd.Flags().StringVarP(&addDuration, "duration", "d", "4h", "Decision duration (ie. 1h,4h,30m)") - cmdDecisionsAdd.Flags().StringVarP(&addValue, "value", "v", "", "The value (ie. --scope username --value foobar)") - cmdDecisionsAdd.Flags().StringVar(&addScope, "scope", types.Ip, "Decision scope (ie. ip,range,username)") - cmdDecisionsAdd.Flags().StringVarP(&addReason, "reason", "R", "", "Decision reason (ie. scenario-name)") - cmdDecisionsAdd.Flags().StringVarP(&addType, "type", "t", "ban", "Decision type (ie. ban,captcha,throttle)") - - return cmdDecisionsAdd -} - -func NewDecisionsDeleteCmd() *cobra.Command { - var delFilter = apiclient.DecisionsDeleteOpts{ - ScopeEquals: new(string), - ValueEquals: new(string), - TypeEquals: new(string), - IPEquals: new(string), - RangeEquals: new(string), - ScenarioEquals: new(string), - OriginEquals: new(string), - } - var delDecisionId string - var delDecisionAll bool - contained := new(bool) - - var cmdDecisionsDelete = &cobra.Command{ - Use: "delete [options]", - Short: "Delete decisions", - DisableAutoGenTag: true, - Aliases: []string{"remove"}, - Example: `cscli decisions delete -r 1.2.3.0/24 -cscli decisions delete -i 1.2.3.4 -cscli decisions delete --id 42 -cscli decisions delete --type captcha -`, - /*TBD : refaire le Long/Example*/ - PreRunE: func(cmd *cobra.Command, args []string) error { - if delDecisionAll { - return nil - } - if *delFilter.ScopeEquals == "" && *delFilter.ValueEquals == "" && - *delFilter.TypeEquals == "" && *delFilter.IPEquals == "" && - *delFilter.RangeEquals == "" && *delFilter.ScenarioEquals == "" && - *delFilter.OriginEquals == "" && delDecisionId == "" { - cmd.Usage() - return fmt.Errorf("at least one filter or --all must be specified") - } - - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - var err error - var decisions *models.DeleteDecisionResponse - - /*take care of shorthand options*/ - if err = manageCliDecisionAlerts(delFilter.IPEquals, delFilter.RangeEquals, delFilter.ScopeEquals, delFilter.ValueEquals); err != nil { - return err - } - if *delFilter.ScopeEquals == "" { - delFilter.ScopeEquals = nil - } - if *delFilter.OriginEquals == "" { - delFilter.OriginEquals = nil - } - if *delFilter.ValueEquals == "" { - delFilter.ValueEquals = nil - } - if *delFilter.ScenarioEquals == "" { - delFilter.ScenarioEquals = nil - } - if *delFilter.TypeEquals == "" { - delFilter.TypeEquals = nil - } - if *delFilter.IPEquals == "" { - delFilter.IPEquals = nil - } - if *delFilter.RangeEquals == "" { - delFilter.RangeEquals = nil - } - if contained != nil && *contained { - delFilter.Contains = new(bool) - } - - if delDecisionId == "" { - decisions, _, err = Client.Decisions.Delete(context.Background(), delFilter) - if err != nil { - return fmt.Errorf("Unable to delete decisions: %v", err) - } - } else { - if _, err = strconv.Atoi(delDecisionId); err != nil { - return fmt.Errorf("id '%s' 
is not an integer: %v", delDecisionId, err) - } - decisions, _, err = Client.Decisions.DeleteOne(context.Background(), delDecisionId) - if err != nil { - return fmt.Errorf("Unable to delete decision: %v", err) - } - } - log.Infof("%s decision(s) deleted", decisions.NbDeleted) - return nil - }, - } - - cmdDecisionsDelete.Flags().SortFlags = false - cmdDecisionsDelete.Flags().StringVarP(delFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") - cmdDecisionsDelete.Flags().StringVarP(delFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value )") - cmdDecisionsDelete.Flags().StringVarP(delFilter.TypeEquals, "type", "t", "", "the decision type (ie. ban,captcha)") - cmdDecisionsDelete.Flags().StringVarP(delFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") - cmdDecisionsDelete.Flags().StringVarP(delFilter.ScenarioEquals, "scenario", "s", "", "the scenario name (ie. crowdsecurity/ssh-bf)") - cmdDecisionsDelete.Flags().StringVar(delFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) - - cmdDecisionsDelete.Flags().StringVar(&delDecisionId, "id", "", "decision id") - cmdDecisionsDelete.Flags().BoolVar(&delDecisionAll, "all", false, "delete all decisions") - cmdDecisionsDelete.Flags().BoolVar(contained, "contained", false, "query decisions contained by range") - - return cmdDecisionsDelete -} diff --git a/cmd/crowdsec-cli/doc.go b/cmd/crowdsec-cli/doc.go new file mode 100644 index 00000000000..f68d535db03 --- /dev/null +++ b/cmd/crowdsec-cli/doc.go @@ -0,0 +1,61 @@ +package main + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/spf13/cobra" + "github.com/spf13/cobra/doc" +) + +type cliDoc struct{} + +func NewCLIDoc() *cliDoc { + return &cliDoc{} +} + +func (cli cliDoc) NewCommand(rootCmd *cobra.Command) *cobra.Command { + var target string + + const defaultTarget = "./doc" + + cmd := &cobra.Command{ + Use: "doc", + Short: "Generate the documentation related to cscli commands. 
Target directory must exist.", + Args: cobra.NoArgs, + Hidden: true, + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + if err := doc.GenMarkdownTreeCustom(rootCmd, target, cli.filePrepender, cli.linkHandler); err != nil { + return fmt.Errorf("failed to generate cscli documentation: %w", err) + } + + fmt.Println("Documentation generated in", target) + + return nil + }, + } + + flags := cmd.Flags() + flags.StringVar(&target, "target", defaultTarget, "The target directory where the documentation will be generated") + + return cmd +} + +func (cli cliDoc) filePrepender(filename string) string { + const header = `--- +id: %s +title: %s +--- +` + + name := filepath.Base(filename) + base := strings.TrimSuffix(name, filepath.Ext(name)) + + return fmt.Sprintf(header, base, strings.ReplaceAll(base, "_", " ")) +} + +func (cli cliDoc) linkHandler(name string) string { + return fmt.Sprintf("/cscli/%s", name) +} diff --git a/cmd/crowdsec-cli/explain.go b/cmd/crowdsec-cli/explain.go deleted file mode 100644 index d9b1ae31d91..00000000000 --- a/cmd/crowdsec-cli/explain.go +++ /dev/null @@ -1,202 +0,0 @@ -package main - -import ( - "bufio" - "fmt" - "io" - "os" - "os/exec" - "path/filepath" - - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/crowdsec/pkg/hubtest" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -func runExplain(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - logFile, err := flags.GetString("file") - if err != nil { - return err - } - - dsn, err := flags.GetString("dsn") - if err != nil { - return err - } - - logLine, err := flags.GetString("log") - if err != nil { - return err - } - - logType, err := flags.GetString("type") - if err != nil { - return err - } - - opts := hubtest.DumpOpts{} - - opts.Details, err = flags.GetBool("verbose") - if err != nil { - return err - } - - opts.SkipOk, err = flags.GetBool("failures") - if err != nil { - return err - } - - opts.ShowNotOkParsers, err = flags.GetBool("only-successful-parsers") - opts.ShowNotOkParsers = !opts.ShowNotOkParsers - if err != nil { - return err - } - - crowdsec, err := flags.GetString("crowdsec") - if err != nil { - return err - } - - fileInfo, _ := os.Stdin.Stat() - - if logType == "" || (logLine == "" && logFile == "" && dsn == "") { - printHelp(cmd) - fmt.Println() - fmt.Printf("Please provide --type flag\n") - os.Exit(1) - } - - if logFile == "-" && ((fileInfo.Mode() & os.ModeCharDevice) == os.ModeCharDevice) { - return fmt.Errorf("the option -f - is intended to work with pipes") - } - - var f *os.File - - // using empty string fallback to /tmp - dir, err := os.MkdirTemp("", "cscli_explain") - if err != nil { - return fmt.Errorf("couldn't create a temporary directory to store cscli explain result: %s", err) - } - tmpFile := "" - // we create a temporary log file if a log line/stdin has been provided - if logLine != "" || logFile == "-" { - tmpFile = filepath.Join(dir, "cscli_test_tmp.log") - f, err = os.Create(tmpFile) - if err != nil { - return err - } - - if logLine != "" { - _, err = f.WriteString(logLine) - if err != nil { - return err - } - } else if logFile == "-" { - reader := bufio.NewReader(os.Stdin) - errCount := 0 - for { - input, err := reader.ReadBytes('\n') - if err != nil && err == io.EOF { - break - } - _, err = f.Write(input) - if err != nil { - errCount++ - } - } - if errCount > 0 { - log.Warnf("Failed to write %d lines to tmp file", errCount) - } - } - f.Close() - // this is the file that was going to be read 
by crowdsec anyway - logFile = tmpFile - } - - if logFile != "" { - absolutePath, err := filepath.Abs(logFile) - if err != nil { - return fmt.Errorf("unable to get absolute path of '%s', exiting", logFile) - } - dsn = fmt.Sprintf("file://%s", absolutePath) - lineCount := types.GetLineCountForFile(absolutePath) - if lineCount > 100 { - log.Warnf("log file contains %d lines. This may take lot of resources.", lineCount) - } - } - - if dsn == "" { - return fmt.Errorf("no acquisition (--file or --dsn) provided, can't run cscli test") - } - - cmdArgs := []string{"-c", ConfigFilePath, "-type", logType, "-dsn", dsn, "-dump-data", dir, "-no-api"} - crowdsecCmd := exec.Command(crowdsec, cmdArgs...) - output, err := crowdsecCmd.CombinedOutput() - if err != nil { - fmt.Println(string(output)) - return fmt.Errorf("fail to run crowdsec for test: %v", err) - } - - // rm the temporary log file if only a log line/stdin was provided - if tmpFile != "" { - if err := os.Remove(tmpFile); err != nil { - return fmt.Errorf("unable to remove tmp log file '%s': %+v", tmpFile, err) - } - } - parserDumpFile := filepath.Join(dir, hubtest.ParserResultFileName) - bucketStateDumpFile := filepath.Join(dir, hubtest.BucketPourResultFileName) - - parserDump, err := hubtest.LoadParserDump(parserDumpFile) - if err != nil { - return fmt.Errorf("unable to load parser dump result: %s", err) - } - - bucketStateDump, err := hubtest.LoadBucketPourDump(bucketStateDumpFile) - if err != nil { - return fmt.Errorf("unable to load bucket dump result: %s", err) - } - - hubtest.DumpTree(*parserDump, *bucketStateDump, opts) - - if err := os.RemoveAll(dir); err != nil { - return fmt.Errorf("unable to delete temporary directory '%s': %s", dir, err) - } - - return nil -} - -func NewExplainCmd() *cobra.Command { - cmdExplain := &cobra.Command{ - Use: "explain", - Short: "Explain log pipeline", - Long: ` -Explain log pipeline - `, - Example: ` -cscli explain --file ./myfile.log --type nginx -cscli explain --log "Sep 19 18:33:22 scw-d95986 sshd[24347]: pam_unix(sshd:auth): authentication failure; logname= uid=0 euid=0 tty=ssh ruser= rhost=1.2.3.4" --type syslog -cscli explain --dsn "file://myfile.log" --type nginx -tail -n 5 myfile.log | cscli explain --type nginx -f - - `, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - RunE: runExplain, - } - - flags := cmdExplain.Flags() - - flags.StringP("file", "f", "", "Log file to test") - flags.StringP("dsn", "d", "", "DSN to test") - flags.StringP("log", "l", "", "Log line to test") - flags.StringP("type", "t", "", "Type of the acquisition to test") - flags.BoolP("verbose", "v", false, "Display individual changes") - flags.Bool("failures", false, "Only show failed lines") - flags.Bool("only-successful-parsers", false, "Only show successful parsers") - flags.String("crowdsec", "crowdsec", "Path to crowdsec") - - return cmdExplain -} diff --git a/cmd/crowdsec-cli/hub.go b/cmd/crowdsec-cli/hub.go deleted file mode 100644 index 4fec8fc8d59..00000000000 --- a/cmd/crowdsec-cli/hub.go +++ /dev/null @@ -1,164 +0,0 @@ -package main - -import ( - "errors" - "fmt" - - "github.com/fatih/color" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/crowdsec/pkg/cwhub" -) - -func NewHubCmd() *cobra.Command { - var cmdHub = &cobra.Command{ - Use: "hub [action]", - Short: "Manage Hub", - Long: ` -Hub management - -List/update parsers/scenarios/postoverflows/collections from [Crowdsec Hub](https://hub.crowdsec.net). 
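The deleted explain.go above funnels a single log line or piped stdin into a temporary file, then hands crowdsec a file:// DSN pointing at it together with -dump-data. A minimal, stdlib-only sketch of that step; the helper name stdinToDSN and its messages are illustrative, not part of the original code:

```go
package main

import (
	"fmt"
	"io"
	"os"
	"path/filepath"
)

// stdinToDSN copies whatever is piped on stdin into a temporary log file
// and returns a file:// DSN pointing at it, plus a cleanup function,
// similar to what cscli explain does before invoking crowdsec.
func stdinToDSN(stdin io.Reader) (dsn string, cleanup func() error, err error) {
	dir, err := os.MkdirTemp("", "cscli_explain")
	if err != nil {
		return "", nil, fmt.Errorf("creating temp dir: %w", err)
	}

	tmpFile := filepath.Join(dir, "explain_tmp.log")

	f, err := os.Create(tmpFile)
	if err != nil {
		return "", nil, err
	}

	if _, err := io.Copy(f, stdin); err != nil {
		f.Close()
		return "", nil, err
	}

	if err := f.Close(); err != nil {
		return "", nil, err
	}

	abs, err := filepath.Abs(tmpFile)
	if err != nil {
		return "", nil, err
	}

	return "file://" + abs, func() error { return os.RemoveAll(dir) }, nil
}

func main() {
	dsn, cleanup, err := stdinToDSN(os.Stdin)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	defer cleanup()

	fmt.Println("would run: crowdsec -dsn", dsn, "-dump-data <dir> -no-api")
}
```

The cleanup closure mirrors the os.RemoveAll(dir) call the original performs once the parser and bucket dumps have been read.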
-The Hub is managed by cscli, to get the latest hub files from [Crowdsec Hub](https://hub.crowdsec.net), you need to update. - `, - Example: ` -cscli hub list # List all installed configurations -cscli hub update # Download list of available configurations from the hub - `, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if csConfig.Cscli == nil { - return fmt.Errorf("you must configure cli before interacting with hub") - } - - return nil - }, - } - cmdHub.PersistentFlags().StringVarP(&cwhub.HubBranch, "branch", "b", "", "Use given branch from hub") - - cmdHub.AddCommand(NewHubListCmd()) - cmdHub.AddCommand(NewHubUpdateCmd()) - cmdHub.AddCommand(NewHubUpgradeCmd()) - - return cmdHub -} - -func NewHubListCmd() *cobra.Command { - var cmdHubList = &cobra.Command{ - Use: "list [-a]", - Short: "List installed configs", - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Info("Run 'sudo cscli hub update' to get the hub index") - log.Fatalf("Failed to get Hub index : %v", err) - } - //use LocalSync to get warnings about tainted / outdated items - _, warn := cwhub.LocalSync(csConfig.Hub) - for _, v := range warn { - log.Info(v) - } - cwhub.DisplaySummary() - ListItems(color.Output, []string{ - cwhub.COLLECTIONS, cwhub.PARSERS, cwhub.SCENARIOS, cwhub.PARSERS_OVFLW, - }, args, true, false, all) - }, - } - cmdHubList.PersistentFlags().BoolVarP(&all, "all", "a", false, "List disabled items as well") - - return cmdHubList -} - -func NewHubUpdateCmd() *cobra.Command { - var cmdHubUpdate = &cobra.Command{ - Use: "update", - Short: "Fetch available configs from hub", - Long: ` -Fetches the [.index.json](https://github.com/crowdsecurity/hub/blob/master/.index.json) file from hub, containing the list of available configs. -`, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if csConfig.Cscli == nil { - return fmt.Errorf("you must configure cli before interacting with hub") - } - - if err := cwhub.SetHubBranch(); err != nil { - return fmt.Errorf("error while setting hub branch: %s", err) - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - if err := cwhub.UpdateHubIdx(csConfig.Hub); err != nil { - if errors.Is(err, cwhub.ErrIndexNotFound) { - log.Warnf("Could not find index file for branch '%s', using 'master'", cwhub.HubBranch) - cwhub.HubBranch = "master" - if err := cwhub.UpdateHubIdx(csConfig.Hub); err != nil { - log.Fatalf("Failed to get Hub index after retry : %v", err) - } - } else { - log.Fatalf("Failed to get Hub index : %v", err) - } - } - //use LocalSync to get warnings about tainted / outdated items - _, warn := cwhub.LocalSync(csConfig.Hub) - for _, v := range warn { - log.Info(v) - } - }, - } - - return cmdHubUpdate -} - -func NewHubUpgradeCmd() *cobra.Command { - var cmdHubUpgrade = &cobra.Command{ - Use: "upgrade", - Short: "Upgrade all configs installed from hub", - Long: ` -Upgrade all configs installed from Crowdsec Hub. Run 'sudo cscli hub update' if you want the latest versions available. 
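NewHubUpdateCmd above retries the index download on master when the configured branch has no index file. The same fallback reads more clearly in isolation; fetchIndex below is a stand-in for cwhub.UpdateHubIdx, not a real crowdsec function:

```go
package main

import (
	"errors"
	"fmt"
)

var errIndexNotFound = errors.New("index not found")

// fetchIndex is a stand-in for cwhub.UpdateHubIdx: it downloads the hub
// index for a given branch and fails with errIndexNotFound when the
// branch has no index file.
func fetchIndex(branch string) error {
	if branch != "master" {
		return errIndexNotFound
	}
	return nil
}

// updateWithFallback mirrors the retry logic of the deleted hub update
// command: try the requested branch first, then fall back to master,
// and only for the "index not found" error.
func updateWithFallback(branch string) error {
	err := fetchIndex(branch)
	if err == nil {
		return nil
	}

	if !errors.Is(err, errIndexNotFound) {
		return err
	}

	fmt.Printf("no index for branch %q, falling back to master\n", branch)

	return fetchIndex("master")
}

func main() {
	if err := updateWithFallback("v1.6"); err != nil {
		fmt.Println("update failed:", err)
	}
}
```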
-`, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if csConfig.Cscli == nil { - return fmt.Errorf("you must configure cli before interacting with hub") - } - - if err := cwhub.SetHubBranch(); err != nil { - return fmt.Errorf("error while setting hub branch: %s", err) - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Info("Run 'sudo cscli hub update' to get the hub index") - log.Fatalf("Failed to get Hub index : %v", err) - } - - log.Infof("Upgrading collections") - cwhub.UpgradeConfig(csConfig, cwhub.COLLECTIONS, "", forceAction) - log.Infof("Upgrading parsers") - cwhub.UpgradeConfig(csConfig, cwhub.PARSERS, "", forceAction) - log.Infof("Upgrading scenarios") - cwhub.UpgradeConfig(csConfig, cwhub.SCENARIOS, "", forceAction) - log.Infof("Upgrading postoverflows") - cwhub.UpgradeConfig(csConfig, cwhub.PARSERS_OVFLW, "", forceAction) - }, - } - cmdHubUpgrade.PersistentFlags().BoolVar(&forceAction, "force", false, "Force upgrade : Overwrite tainted and outdated files") - - return cmdHubUpgrade -} diff --git a/cmd/crowdsec-cli/hubtest.go b/cmd/crowdsec-cli/hubtest.go deleted file mode 100644 index 97bb8c8dd65..00000000000 --- a/cmd/crowdsec-cli/hubtest.go +++ /dev/null @@ -1,595 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "math" - "os" - "path/filepath" - "strings" - - "github.com/AlecAivazis/survey/v2" - "github.com/enescakir/emoji" - "github.com/fatih/color" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v2" - - "github.com/crowdsecurity/crowdsec/pkg/hubtest" -) - -var ( - HubTest hubtest.HubTest -) - -func NewHubTestCmd() *cobra.Command { - var hubPath string - var crowdsecPath string - var cscliPath string - - var cmdHubTest = &cobra.Command{ - Use: "hubtest", - Short: "Run functional tests on hub configurations", - Long: "Run functional tests on hub configurations (parsers, scenarios, collections...)", - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - var err error - HubTest, err = hubtest.NewHubTest(hubPath, crowdsecPath, cscliPath) - if err != nil { - return fmt.Errorf("unable to load hubtest: %+v", err) - } - - return nil - }, - } - cmdHubTest.PersistentFlags().StringVar(&hubPath, "hub", ".", "Path to hub folder") - cmdHubTest.PersistentFlags().StringVar(&crowdsecPath, "crowdsec", "crowdsec", "Path to crowdsec") - cmdHubTest.PersistentFlags().StringVar(&cscliPath, "cscli", "cscli", "Path to cscli") - - cmdHubTest.AddCommand(NewHubTestCreateCmd()) - cmdHubTest.AddCommand(NewHubTestRunCmd()) - cmdHubTest.AddCommand(NewHubTestCleanCmd()) - cmdHubTest.AddCommand(NewHubTestInfoCmd()) - cmdHubTest.AddCommand(NewHubTestListCmd()) - cmdHubTest.AddCommand(NewHubTestCoverageCmd()) - cmdHubTest.AddCommand(NewHubTestEvalCmd()) - cmdHubTest.AddCommand(NewHubTestExplainCmd()) - - return cmdHubTest -} - - -func NewHubTestCreateCmd() *cobra.Command { - parsers := []string{} - postoverflows := []string{} - scenarios := []string{} - var ignoreParsers bool - var labels map[string]string - var logType string - - var cmdHubTestCreate = &cobra.Command{ - Use: "create", - Short: "create [test_name]", - Example: `cscli hubtest create my-awesome-test --type syslog -cscli hubtest create my-nginx-custom-test --type nginx -cscli hubtest create my-scenario-test 
--parsers crowdsecurity/nginx --scenarios crowdsecurity/http-probing`, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - testName := args[0] - testPath := filepath.Join(HubTest.HubTestPath, testName) - if _, err := os.Stat(testPath); os.IsExist(err) { - return fmt.Errorf("test '%s' already exists in '%s', exiting", testName, testPath) - } - - if logType == "" { - return fmt.Errorf("please provide a type (--type) for the test") - } - - if err := os.MkdirAll(testPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", testPath, err) - } - - // create empty log file - logFileName := fmt.Sprintf("%s.log", testName) - logFilePath := filepath.Join(testPath, logFileName) - logFile, err := os.Create(logFilePath) - if err != nil { - return err - } - logFile.Close() - - // create empty parser assertion file - parserAssertFilePath := filepath.Join(testPath, hubtest.ParserAssertFileName) - parserAssertFile, err := os.Create(parserAssertFilePath) - if err != nil { - return err - } - parserAssertFile.Close() - - // create empty scenario assertion file - scenarioAssertFilePath := filepath.Join(testPath, hubtest.ScenarioAssertFileName) - scenarioAssertFile, err := os.Create(scenarioAssertFilePath) - if err != nil { - return err - } - scenarioAssertFile.Close() - - parsers = append(parsers, "crowdsecurity/syslog-logs") - parsers = append(parsers, "crowdsecurity/dateparse-enrich") - - if len(scenarios) == 0 { - scenarios = append(scenarios, "") - } - - if len(postoverflows) == 0 { - postoverflows = append(postoverflows, "") - } - - configFileData := &hubtest.HubTestItemConfig{ - Parsers: parsers, - Scenarios: scenarios, - PostOVerflows: postoverflows, - LogFile: logFileName, - LogType: logType, - IgnoreParsers: ignoreParsers, - Labels: labels, - } - - configFilePath := filepath.Join(testPath, "config.yaml") - fd, err := os.OpenFile(configFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) - if err != nil { - return fmt.Errorf("open: %s", err) - } - data, err := yaml.Marshal(configFileData) - if err != nil { - return fmt.Errorf("marshal: %s", err) - } - _, err = fd.Write(data) - if err != nil { - return fmt.Errorf("write: %s", err) - } - if err := fd.Close(); err != nil { - return fmt.Errorf("close: %s", err) - } - fmt.Println() - fmt.Printf(" Test name : %s\n", testName) - fmt.Printf(" Test path : %s\n", testPath) - fmt.Printf(" Log file : %s (please fill it with logs)\n", logFilePath) - fmt.Printf(" Parser assertion file : %s (please fill it with assertion)\n", parserAssertFilePath) - fmt.Printf(" Scenario assertion file : %s (please fill it with assertion)\n", scenarioAssertFilePath) - fmt.Printf(" Configuration File : %s (please fill it with parsers, scenarios...)\n", configFilePath) - - return nil - }, - } - cmdHubTestCreate.PersistentFlags().StringVarP(&logType, "type", "t", "", "Log type of the test") - cmdHubTestCreate.Flags().StringSliceVarP(&parsers, "parsers", "p", parsers, "Parsers to add to test") - cmdHubTestCreate.Flags().StringSliceVar(&postoverflows, "postoverflows", postoverflows, "Postoverflows to add to test") - cmdHubTestCreate.Flags().StringSliceVarP(&scenarios, "scenarios", "s", scenarios, "Scenarios to add to test") - cmdHubTestCreate.PersistentFlags().BoolVar(&ignoreParsers, "ignore-parsers", false, "Don't run test on parsers") - - return cmdHubTestCreate -} - - -func NewHubTestRunCmd() *cobra.Command { - var noClean bool - var runAll bool - var forceClean bool - - var cmdHubTestRun = 
&cobra.Command{ - Use: "run", - Short: "run [test_name]", - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - if !runAll && len(args) == 0 { - printHelp(cmd) - return fmt.Errorf("Please provide test to run or --all flag") - } - - if runAll { - if err := HubTest.LoadAllTests(); err != nil { - return fmt.Errorf("unable to load all tests: %+v", err) - } - } else { - for _, testName := range args { - _, err := HubTest.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("unable to load test '%s': %s", testName, err) - } - } - } - - for _, test := range HubTest.Tests { - if csConfig.Cscli.Output == "human" { - log.Infof("Running test '%s'", test.Name) - } - err := test.Run() - if err != nil { - log.Errorf("running test '%s' failed: %+v", test.Name, err) - } - } - - return nil - }, - PersistentPostRunE: func(cmd *cobra.Command, args []string) error { - success := true - testResult := make(map[string]bool) - for _, test := range HubTest.Tests { - if test.AutoGen { - if test.ParserAssert.AutoGenAssert { - log.Warningf("Assert file '%s' is empty, generating assertion:", test.ParserAssert.File) - fmt.Println() - fmt.Println(test.ParserAssert.AutoGenAssertData) - } - if test.ScenarioAssert.AutoGenAssert { - log.Warningf("Assert file '%s' is empty, generating assertion:", test.ScenarioAssert.File) - fmt.Println() - fmt.Println(test.ScenarioAssert.AutoGenAssertData) - } - if !noClean { - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) - } - } - fmt.Printf("\nPlease fill your assert file(s) for test '%s', exiting\n", test.Name) - os.Exit(1) - } - testResult[test.Name] = test.Success - if test.Success { - if csConfig.Cscli.Output == "human" { - log.Infof("Test '%s' passed successfully (%d assertions)\n", test.Name, test.ParserAssert.NbAssert+test.ScenarioAssert.NbAssert) - } - if !noClean { - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) - } - } - } else { - success = false - cleanTestEnv := false - if csConfig.Cscli.Output == "human" { - if len(test.ParserAssert.Fails) > 0 { - fmt.Println() - log.Errorf("Parser test '%s' failed (%d errors)\n", test.Name, len(test.ParserAssert.Fails)) - for _, fail := range test.ParserAssert.Fails { - fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) - fmt.Printf(" Actual expression values:\n") - for key, value := range fail.Debug { - fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) - } - fmt.Println() - } - } - if len(test.ScenarioAssert.Fails) > 0 { - fmt.Println() - log.Errorf("Scenario test '%s' failed (%d errors)\n", test.Name, len(test.ScenarioAssert.Fails)) - for _, fail := range test.ScenarioAssert.Fails { - fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) - fmt.Printf(" Actual expression values:\n") - for key, value := range fail.Debug { - fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) - } - fmt.Println() - } - } - if !forceClean && !noClean { - prompt := &survey.Confirm{ - Message: fmt.Sprintf("\nDo you want to remove runtime folder for test '%s'? 
(default: Yes)", test.Name), - Default: true, - } - if err := survey.AskOne(prompt, &cleanTestEnv); err != nil { - return fmt.Errorf("unable to ask to remove runtime folder: %s", err) - } - } - } - - if cleanTestEnv || forceClean { - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) - } - } - } - } - if csConfig.Cscli.Output == "human" { - hubTestResultTable(color.Output, testResult) - } else if csConfig.Cscli.Output == "json" { - jsonResult := make(map[string][]string, 0) - jsonResult["success"] = make([]string, 0) - jsonResult["fail"] = make([]string, 0) - for testName, success := range testResult { - if success { - jsonResult["success"] = append(jsonResult["success"], testName) - } else { - jsonResult["fail"] = append(jsonResult["fail"], testName) - } - } - jsonStr, err := json.Marshal(jsonResult) - if err != nil { - return fmt.Errorf("unable to json test result: %s", err) - } - fmt.Println(string(jsonStr)) - } - - if !success { - os.Exit(1) - } - - return nil - }, - } - cmdHubTestRun.Flags().BoolVar(&noClean, "no-clean", false, "Don't clean runtime environment if test succeed") - cmdHubTestRun.Flags().BoolVar(&forceClean, "clean", false, "Clean runtime environment if test fail") - cmdHubTestRun.Flags().BoolVar(&runAll, "all", false, "Run all tests") - - return cmdHubTestRun -} - - -func NewHubTestCleanCmd() *cobra.Command { - var cmdHubTestClean = &cobra.Command{ - Use: "clean", - Short: "clean [test_name]", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - for _, testName := range args { - test, err := HubTest.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("unable to load test '%s': %s", testName, err) - } - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) - } - } - - return nil - }, - } - - return cmdHubTestClean -} - - -func NewHubTestInfoCmd() *cobra.Command { - var cmdHubTestInfo = &cobra.Command{ - Use: "info", - Short: "info [test_name]", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - for _, testName := range args { - test, err := HubTest.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("unable to load test '%s': %s", testName, err) - } - fmt.Println() - fmt.Printf(" Test name : %s\n", test.Name) - fmt.Printf(" Test path : %s\n", test.Path) - fmt.Printf(" Log file : %s\n", filepath.Join(test.Path, test.Config.LogFile)) - fmt.Printf(" Parser assertion file : %s\n", filepath.Join(test.Path, hubtest.ParserAssertFileName)) - fmt.Printf(" Scenario assertion file : %s\n", filepath.Join(test.Path, hubtest.ScenarioAssertFileName)) - fmt.Printf(" Configuration File : %s\n", filepath.Join(test.Path, "config.yaml")) - } - - return nil - }, - } - - return cmdHubTestInfo -} - - -func NewHubTestListCmd() *cobra.Command { - var cmdHubTestList = &cobra.Command{ - Use: "list", - Short: "list", - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - if err := HubTest.LoadAllTests(); err != nil { - return fmt.Errorf("unable to load all tests: %s", err) - } - - switch csConfig.Cscli.Output { - case "human": - hubTestListTable(color.Output, HubTest.Tests) - case "json": - j, err := json.MarshalIndent(HubTest.Tests, " ", " ") - if err != nil { - return err - } - fmt.Println(string(j)) - default: - return fmt.Errorf("only human/json output modes are supported") - } - - return nil - }, - } - - 
return cmdHubTestList -} - - -func NewHubTestCoverageCmd() *cobra.Command { - var showParserCov bool - var showScenarioCov bool - var showOnlyPercent bool - - var cmdHubTestCoverage = &cobra.Command{ - Use: "coverage", - Short: "coverage", - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - if err := HubTest.LoadAllTests(); err != nil { - return fmt.Errorf("unable to load all tests: %+v", err) - } - var err error - scenarioCoverage := []hubtest.ScenarioCoverage{} - parserCoverage := []hubtest.ParserCoverage{} - scenarioCoveragePercent := 0 - parserCoveragePercent := 0 - - // if both are false (flag by default), show both - showAll := !showScenarioCov && !showParserCov - - if showParserCov || showAll { - parserCoverage, err = HubTest.GetParsersCoverage() - if err != nil { - return fmt.Errorf("while getting parser coverage: %s", err) - } - parserTested := 0 - for _, test := range parserCoverage { - if test.TestsCount > 0 { - parserTested += 1 - } - } - parserCoveragePercent = int(math.Round((float64(parserTested) / float64(len(parserCoverage)) * 100))) - } - - if showScenarioCov || showAll { - scenarioCoverage, err = HubTest.GetScenariosCoverage() - if err != nil { - return fmt.Errorf("while getting scenario coverage: %s", err) - } - scenarioTested := 0 - for _, test := range scenarioCoverage { - if test.TestsCount > 0 { - scenarioTested += 1 - } - } - scenarioCoveragePercent = int(math.Round((float64(scenarioTested) / float64(len(scenarioCoverage)) * 100))) - } - - if showOnlyPercent { - if showAll { - fmt.Printf("parsers=%d%%\nscenarios=%d%%", parserCoveragePercent, scenarioCoveragePercent) - } else if showParserCov { - fmt.Printf("parsers=%d%%", parserCoveragePercent) - } else if showScenarioCov { - fmt.Printf("scenarios=%d%%", scenarioCoveragePercent) - } - os.Exit(0) - } - - if csConfig.Cscli.Output == "human" { - if showParserCov || showAll { - hubTestParserCoverageTable(color.Output, parserCoverage) - } - - if showScenarioCov || showAll { - hubTestScenarioCoverageTable(color.Output, scenarioCoverage) - } - fmt.Println() - if showParserCov || showAll { - fmt.Printf("PARSERS : %d%% of coverage\n", parserCoveragePercent) - } - if showScenarioCov || showAll { - fmt.Printf("SCENARIOS : %d%% of coverage\n", scenarioCoveragePercent) - } - } else if csConfig.Cscli.Output == "json" { - dump, err := json.MarshalIndent(parserCoverage, "", " ") - if err != nil { - return err - } - fmt.Printf("%s", dump) - dump, err = json.MarshalIndent(scenarioCoverage, "", " ") - if err != nil { - return err - } - fmt.Printf("%s", dump) - } else { - return fmt.Errorf("only human/json output modes are supported") - } - - return nil - }, - } - cmdHubTestCoverage.PersistentFlags().BoolVar(&showOnlyPercent, "percent", false, "Show only percentages of coverage") - cmdHubTestCoverage.PersistentFlags().BoolVar(&showParserCov, "parsers", false, "Show only parsers coverage") - cmdHubTestCoverage.PersistentFlags().BoolVar(&showScenarioCov, "scenarios", false, "Show only scenarios coverage") - - return cmdHubTestCoverage -} - - -func NewHubTestEvalCmd() *cobra.Command { - var evalExpression string - var cmdHubTestEval = &cobra.Command{ - Use: "eval", - Short: "eval [test_name]", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - for _, testName := range args { - test, err := HubTest.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("can't load test: %+v", err) - } - err = 
test.ParserAssert.LoadTest(test.ParserResultFile) - if err != nil { - return fmt.Errorf("can't load test results from '%s': %+v", test.ParserResultFile, err) - } - output, err := test.ParserAssert.EvalExpression(evalExpression) - if err != nil { - return err - } - fmt.Print(output) - } - - return nil - }, - } - cmdHubTestEval.PersistentFlags().StringVarP(&evalExpression, "expr", "e", "", "Expression to eval") - - return cmdHubTestEval -} - - -func NewHubTestExplainCmd() *cobra.Command { - var cmdHubTestExplain = &cobra.Command{ - Use: "explain", - Short: "explain [test_name]", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - for _, testName := range args { - test, err := HubTest.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("can't load test: %+v", err) - } - err = test.ParserAssert.LoadTest(test.ParserResultFile) - if err != nil { - err := test.Run() - if err != nil { - return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) - } - err = test.ParserAssert.LoadTest(test.ParserResultFile) - if err != nil { - return fmt.Errorf("unable to load parser result after run: %s", err) - } - } - - err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile) - if err != nil { - err := test.Run() - if err != nil { - return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) - } - err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile) - if err != nil { - return fmt.Errorf("unable to load scenario result after run: %s", err) - } - } - opts := hubtest.DumpOpts{} - hubtest.DumpTree(*test.ParserAssert.TestData, *test.ScenarioAssert.PourData, opts) - } - - return nil - }, - } - - return cmdHubTestExplain -} diff --git a/cmd/crowdsec-cli/hubtest_table.go b/cmd/crowdsec-cli/hubtest_table.go deleted file mode 100644 index 9f28c36992d..00000000000 --- a/cmd/crowdsec-cli/hubtest_table.go +++ /dev/null @@ -1,80 +0,0 @@ -package main - -import ( - "fmt" - "io" - - "github.com/aquasecurity/table" - "github.com/enescakir/emoji" - - "github.com/crowdsecurity/crowdsec/pkg/hubtest" -) - -func hubTestResultTable(out io.Writer, testResult map[string]bool) { - t := newLightTable(out) - t.SetHeaders("Test", "Result") - t.SetHeaderAlignment(table.AlignLeft) - t.SetAlignment(table.AlignLeft) - - for testName, success := range testResult { - status := emoji.CheckMarkButton.String() - if !success { - status = emoji.CrossMark.String() - } - - t.AddRow(testName, status) - } - - t.Render() -} - -func hubTestListTable(out io.Writer, tests []*hubtest.HubTestItem) { - t := newLightTable(out) - t.SetHeaders("Name", "Path") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft) - - for _, test := range tests { - t.AddRow(test.Name, test.Path) - } - - t.Render() -} - -func hubTestParserCoverageTable(out io.Writer, coverage []hubtest.ParserCoverage) { - t := newLightTable(out) - t.SetHeaders("Parser", "Status", "Number of tests") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - parserTested := 0 - for _, test := range coverage { - status := emoji.RedCircle.String() - if test.TestsCount > 0 { - status = emoji.GreenCircle.String() - parserTested++ - } - t.AddRow(test.Parser, status, fmt.Sprintf("%d times (across %d tests)", test.TestsCount, len(test.PresentIn))) - } - - t.Render() -} - -func hubTestScenarioCoverageTable(out io.Writer, 
coverage []hubtest.ScenarioCoverage) { - t := newLightTable(out) - t.SetHeaders("Scenario", "Status", "Number of tests") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - parserTested := 0 - for _, test := range coverage { - status := emoji.RedCircle.String() - if test.TestsCount > 0 { - status = emoji.GreenCircle.String() - parserTested++ - } - t.AddRow(test.Scenario, status, fmt.Sprintf("%d times (across %d tests)", test.TestsCount, len(test.PresentIn))) - } - - t.Render() -} diff --git a/cmd/crowdsec-cli/idgen/machineid.go b/cmd/crowdsec-cli/idgen/machineid.go new file mode 100644 index 00000000000..4bd356b3abc --- /dev/null +++ b/cmd/crowdsec-cli/idgen/machineid.go @@ -0,0 +1,48 @@ +package idgen + +import ( + "fmt" + "strings" + + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/machineid" +) + +// Returns a unique identifier for each crowdsec installation, using an +// identifier of the OS installation where available, otherwise a random +// string. +func generateMachineIDPrefix() (string, error) { + prefix, err := machineid.ID() + if err == nil { + return prefix, nil + } + + log.Debugf("failed to get machine-id with usual files: %s", err) + + bID, err := uuid.NewRandom() + if err == nil { + return bID.String(), nil + } + + return "", fmt.Errorf("generating machine id: %w", err) +} + +// Generate a unique identifier, composed by a prefix and a random suffix. +// The prefix can be provided by a parameter to use in test environments. +func GenerateMachineID(prefix string) (string, error) { + var err error + if prefix == "" { + prefix, err = generateMachineIDPrefix() + } + + if err != nil { + return "", err + } + + prefix = strings.ReplaceAll(prefix, "-", "")[:32] + suffix := GeneratePassword(16) + + return prefix + suffix, nil +} diff --git a/cmd/crowdsec-cli/idgen/password.go b/cmd/crowdsec-cli/idgen/password.go new file mode 100644 index 00000000000..e0faa4daacc --- /dev/null +++ b/cmd/crowdsec-cli/idgen/password.go @@ -0,0 +1,32 @@ +package idgen + +import ( + saferand "crypto/rand" + "math/big" + + log "github.com/sirupsen/logrus" +) + +const PasswordLength = 64 + +func GeneratePassword(length int) string { + upper := "ABCDEFGHIJKLMNOPQRSTUVWXY" + lower := "abcdefghijklmnopqrstuvwxyz" + digits := "0123456789" + + charset := upper + lower + digits + charsetLength := len(charset) + + buf := make([]byte, length) + + for i := range length { + rInt, err := saferand.Int(saferand.Reader, big.NewInt(int64(charsetLength))) + if err != nil { + log.Fatalf("failed getting data from prng for password generation : %s", err) + } + + buf[i] = charset[rInt.Int64()] + } + + return string(buf) +} diff --git a/cmd/crowdsec-cli/lapi.go b/cmd/crowdsec-cli/lapi.go deleted file mode 100644 index e4353ac19f5..00000000000 --- a/cmd/crowdsec-cli/lapi.go +++ /dev/null @@ -1,560 +0,0 @@ -package main - -import ( - "context" - "fmt" - "net/url" - "os" - "sort" - "strings" - - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "golang.org/x/exp/slices" - "gopkg.in/yaml.v2" - - "github.com/crowdsecurity/go-cs-lib/pkg/version" - - "github.com/crowdsecurity/crowdsec/pkg/alertcontext" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" - "github.com/crowdsecurity/crowdsec/pkg/models" - 
"github.com/crowdsecurity/crowdsec/pkg/parser" -) - -var LAPIURLPrefix string = "v1" - -func runLapiStatus(cmd *cobra.Command, args []string) error { - var err error - - password := strfmt.Password(csConfig.API.Client.Credentials.Password) - apiurl, err := url.Parse(csConfig.API.Client.Credentials.URL) - login := csConfig.API.Client.Credentials.Login - if err != nil { - log.Fatalf("parsing api url ('%s'): %s", apiurl, err) - } - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Info("Run 'sudo cscli hub update' to get the hub index") - log.Fatalf("Failed to load hub index : %s", err) - } - scenarios, err := cwhub.GetInstalledScenariosAsString() - if err != nil { - log.Fatalf("failed to get scenarios : %s", err) - } - - Client, err = apiclient.NewDefaultClient(apiurl, - LAPIURLPrefix, - fmt.Sprintf("crowdsec/%s", version.String()), - nil) - if err != nil { - log.Fatalf("init default client: %s", err) - } - t := models.WatcherAuthRequest{ - MachineID: &login, - Password: &password, - Scenarios: scenarios, - } - log.Infof("Loaded credentials from %s", csConfig.API.Client.CredentialsFilePath) - log.Infof("Trying to authenticate with username %s on %s", login, apiurl) - _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) - if err != nil { - log.Fatalf("Failed to authenticate to Local API (LAPI) : %s", err) - } else { - log.Infof("You can successfully interact with Local API (LAPI)") - } - - return nil -} - -func runLapiRegister(cmd *cobra.Command, args []string) error { - var err error - - flags := cmd.Flags() - - apiURL, err := flags.GetString("url") - if err != nil { - return err - } - - outputFile, err := flags.GetString("file") - if err != nil { - return err - } - - lapiUser, err := flags.GetString("machine") - if err != nil { - return err - } - - if lapiUser == "" { - lapiUser, err = generateID("") - if err != nil { - log.Fatalf("unable to generate machine id: %s", err) - } - } - password := strfmt.Password(generatePassword(passwordLength)) - if apiURL == "" { - if csConfig.API.Client != nil && csConfig.API.Client.Credentials != nil && csConfig.API.Client.Credentials.URL != "" { - apiURL = csConfig.API.Client.Credentials.URL - } else { - log.Fatalf("No Local API URL. 
Please provide it in your configuration or with the -u parameter") - } - } - /*URL needs to end with /, but user doesn't care*/ - if !strings.HasSuffix(apiURL, "/") { - apiURL += "/" - } - /*URL needs to start with http://, but user doesn't care*/ - if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") { - apiURL = "http://" + apiURL - } - apiurl, err := url.Parse(apiURL) - if err != nil { - log.Fatalf("parsing api url: %s", err) - } - _, err = apiclient.RegisterClient(&apiclient.Config{ - MachineID: lapiUser, - Password: password, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiurl, - VersionPrefix: LAPIURLPrefix, - }, nil) - - if err != nil { - log.Fatalf("api client register: %s", err) - } - - log.Printf("Successfully registered to Local API (LAPI)") - - var dumpFile string - if outputFile != "" { - dumpFile = outputFile - } else if csConfig.API.Client.CredentialsFilePath != "" { - dumpFile = csConfig.API.Client.CredentialsFilePath - } else { - dumpFile = "" - } - apiCfg := csconfig.ApiCredentialsCfg{ - Login: lapiUser, - Password: password.String(), - URL: apiURL, - } - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - log.Fatalf("unable to marshal api credentials: %s", err) - } - if dumpFile != "" { - err = os.WriteFile(dumpFile, apiConfigDump, 0644) - if err != nil { - log.Fatalf("write api credentials in '%s' failed: %s", dumpFile, err) - } - log.Printf("Local API credentials dumped to '%s'", dumpFile) - } else { - fmt.Printf("%s\n", string(apiConfigDump)) - } - log.Warning(ReloadMessage()) - - return nil -} - -func NewLapiStatusCmd() *cobra.Command { - cmdLapiStatus := &cobra.Command{ - Use: "status", - Short: "Check authentication to Local API (LAPI)", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - RunE: runLapiStatus, - } - - return cmdLapiStatus -} - -func NewLapiRegisterCmd() *cobra.Command { - cmdLapiRegister := &cobra.Command{ - Use: "register", - Short: "Register a machine to Local API (LAPI)", - Long: `Register your machine to the Local API (LAPI). -Keep in mind the machine needs to be validated by an administrator on LAPI side to be effective.`, - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - RunE: runLapiRegister, - } - - flags := cmdLapiRegister.Flags() - flags.StringP("url", "u", "", "URL of the API (ie. 
http://127.0.0.1)") - flags.StringP("file", "f", "", "output file destination") - flags.String("machine", "", "Name of the machine to register with") - - return cmdLapiRegister -} - -func NewLapiCmd() *cobra.Command { - var cmdLapi = &cobra.Command{ - Use: "lapi [action]", - Short: "Manage interaction with Local API (LAPI)", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadAPIClient(); err != nil { - return fmt.Errorf("loading api client: %w", err) - } - return nil - }, - } - - cmdLapi.AddCommand(NewLapiRegisterCmd()) - cmdLapi.AddCommand(NewLapiStatusCmd()) - cmdLapi.AddCommand(NewLapiContextCmd()) - - return cmdLapi -} - -func NewLapiContextCmd() *cobra.Command { - cmdContext := &cobra.Command{ - Use: "context [command]", - Short: "Manage context to send with alerts", - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadCrowdsec(); err != nil { - fileNotFoundMessage := fmt.Sprintf("failed to open context file: open %s: no such file or directory", csConfig.Crowdsec.ConsoleContextPath) - if err.Error() != fileNotFoundMessage { - log.Fatalf("Unable to load CrowdSec Agent: %s", err) - } - } - if csConfig.DisableAgent { - log.Fatalf("Agent is disabled and lapi context can only be used on the agent") - } - - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - printHelp(cmd) - }, - } - - var keyToAdd string - var valuesToAdd []string - cmdContextAdd := &cobra.Command{ - Use: "add", - Short: "Add context to send with alerts. You must specify the output key with the expr value you want", - Example: `cscli lapi context add --key source_ip --value evt.Meta.source_ip -cscli lapi context add --key file_source --value evt.Line.Src - `, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if err := alertcontext.ValidateContextExpr(keyToAdd, valuesToAdd); err != nil { - log.Fatalf("invalid context configuration :%s", err) - } - if _, ok := csConfig.Crowdsec.ContextToSend[keyToAdd]; !ok { - csConfig.Crowdsec.ContextToSend[keyToAdd] = make([]string, 0) - log.Infof("key '%s' added", keyToAdd) - } - data := csConfig.Crowdsec.ContextToSend[keyToAdd] - for _, val := range valuesToAdd { - if !slices.Contains(data, val) { - log.Infof("value '%s' added to key '%s'", val, keyToAdd) - data = append(data, val) - } - csConfig.Crowdsec.ContextToSend[keyToAdd] = data - } - if err := csConfig.Crowdsec.DumpContextConfigFile(); err != nil { - log.Fatalf(err.Error()) - } - }, - } - cmdContextAdd.Flags().StringVarP(&keyToAdd, "key", "k", "", "The key of the different values to send") - cmdContextAdd.Flags().StringSliceVar(&valuesToAdd, "value", []string{}, "The expr fields to associate with the key") - cmdContextAdd.MarkFlagRequired("key") - cmdContextAdd.MarkFlagRequired("value") - cmdContext.AddCommand(cmdContextAdd) - - cmdContextStatus := &cobra.Command{ - Use: "status", - Short: "List context to send with alerts", - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if len(csConfig.Crowdsec.ContextToSend) == 0 { - fmt.Println("No context found on this agent. 
You can use 'cscli lapi context add' to add context to your alerts.") - return - } - - dump, err := yaml.Marshal(csConfig.Crowdsec.ContextToSend) - if err != nil { - log.Fatalf("unable to show context status: %s", err) - } - - fmt.Println(string(dump)) - - }, - } - cmdContext.AddCommand(cmdContextStatus) - - var detectAll bool - cmdContextDetect := &cobra.Command{ - Use: "detect", - Short: "Detect available fields from the installed parsers", - Example: `cscli lapi context detect --all -cscli lapi context detect crowdsecurity/sshd-logs - `, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - var err error - - if !detectAll && len(args) == 0 { - log.Infof("Please provide parsers to detect or --all flag.") - printHelp(cmd) - } - - // to avoid all the log.Info from the loaders functions - log.SetLevel(log.ErrorLevel) - - err = exprhelpers.Init(nil) - if err != nil { - log.Fatalf("Failed to init expr helpers : %s", err) - } - - // Populate cwhub package tools - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Fatalf("Failed to load hub index : %s", err) - } - - csParsers := parser.NewParsers() - if csParsers, err = parser.LoadParsers(csConfig, csParsers); err != nil { - log.Fatalf("unable to load parsers: %s", err) - } - - fieldByParsers := make(map[string][]string) - for _, node := range csParsers.Nodes { - if !detectAll && !slices.Contains(args, node.Name) { - continue - } - if !detectAll { - args = removeFromSlice(node.Name, args) - } - fieldByParsers[node.Name] = make([]string, 0) - fieldByParsers[node.Name] = detectNode(node, *csParsers.Ctx) - - subNodeFields := detectSubNode(node, *csParsers.Ctx) - for _, field := range subNodeFields { - if !slices.Contains(fieldByParsers[node.Name], field) { - fieldByParsers[node.Name] = append(fieldByParsers[node.Name], field) - } - } - - } - - fmt.Printf("Acquisition :\n\n") - fmt.Printf(" - evt.Line.Module\n") - fmt.Printf(" - evt.Line.Raw\n") - fmt.Printf(" - evt.Line.Src\n") - fmt.Println() - - parsersKey := make([]string, 0) - for k := range fieldByParsers { - parsersKey = append(parsersKey, k) - } - sort.Strings(parsersKey) - - for _, k := range parsersKey { - if len(fieldByParsers[k]) == 0 { - continue - } - fmt.Printf("%s :\n\n", k) - values := fieldByParsers[k] - sort.Strings(values) - for _, value := range values { - fmt.Printf(" - %s\n", value) - } - fmt.Println() - } - - if len(args) > 0 { - for _, parserNotFound := range args { - log.Errorf("parser '%s' not found, can't detect fields", parserNotFound) - } - } - }, - } - cmdContextDetect.Flags().BoolVarP(&detectAll, "all", "a", false, "Detect evt field for all installed parser") - cmdContext.AddCommand(cmdContextDetect) - - var keysToDelete []string - var valuesToDelete []string - cmdContextDelete := &cobra.Command{ - Use: "delete", - Short: "Delete context to send with alerts", - Example: `cscli lapi context delete --key source_ip -cscli lapi context delete --value evt.Line.Src - `, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if len(keysToDelete) == 0 && len(valuesToDelete) == 0 { - log.Fatalf("please provide at least a key or a value to delete") - } - - for _, key := range keysToDelete { - if _, ok := csConfig.Crowdsec.ContextToSend[key]; ok { - delete(csConfig.Crowdsec.ContextToSend, key) - log.Infof("key '%s' has been removed", key) - } else { - log.Warningf("key '%s' doesn't exist", key) - } - } - - for _, value := range valuesToDelete { - valueFound := false - for key, context := range 
csConfig.Crowdsec.ContextToSend { - if slices.Contains(context, value) { - valueFound = true - csConfig.Crowdsec.ContextToSend[key] = removeFromSlice(value, context) - log.Infof("value '%s' has been removed from key '%s'", value, key) - } - if len(csConfig.Crowdsec.ContextToSend[key]) == 0 { - delete(csConfig.Crowdsec.ContextToSend, key) - } - } - if !valueFound { - log.Warningf("value '%s' not found", value) - } - } - - if err := csConfig.Crowdsec.DumpContextConfigFile(); err != nil { - log.Fatalf(err.Error()) - } - - }, - } - cmdContextDelete.Flags().StringSliceVarP(&keysToDelete, "key", "k", []string{}, "The keys to delete") - cmdContextDelete.Flags().StringSliceVar(&valuesToDelete, "value", []string{}, "The expr fields to delete") - cmdContext.AddCommand(cmdContextDelete) - - return cmdContext -} - -func detectStaticField(GrokStatics []parser.ExtraField) []string { - ret := make([]string, 0) - for _, static := range GrokStatics { - if static.Parsed != "" { - fieldName := fmt.Sprintf("evt.Parsed.%s", static.Parsed) - if !slices.Contains(ret, fieldName) { - ret = append(ret, fieldName) - } - } - if static.Meta != "" { - fieldName := fmt.Sprintf("evt.Meta.%s", static.Meta) - if !slices.Contains(ret, fieldName) { - ret = append(ret, fieldName) - } - } - if static.TargetByName != "" { - fieldName := static.TargetByName - if !strings.HasPrefix(fieldName, "evt.") { - fieldName = "evt." + fieldName - } - if !slices.Contains(ret, fieldName) { - ret = append(ret, fieldName) - } - } - } - - return ret -} - -func detectNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { - var ret = make([]string, 0) - if node.Grok.RunTimeRegexp != nil { - for _, capturedField := range node.Grok.RunTimeRegexp.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) - if !slices.Contains(ret, fieldName) { - ret = append(ret, fieldName) - } - } - } - - if node.Grok.RegexpName != "" { - grokCompiled, err := parserCTX.Grok.Get(node.Grok.RegexpName) - if err != nil { - log.Warningf("Can't get subgrok: %s", err) - } - for _, capturedField := range grokCompiled.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) - if !slices.Contains(ret, fieldName) { - ret = append(ret, fieldName) - } - } - } - - if len(node.Grok.Statics) > 0 { - staticsField := detectStaticField(node.Grok.Statics) - for _, staticField := range staticsField { - if !slices.Contains(ret, staticField) { - ret = append(ret, staticField) - } - } - } - - if len(node.Statics) > 0 { - staticsField := detectStaticField(node.Statics) - for _, staticField := range staticsField { - if !slices.Contains(ret, staticField) { - ret = append(ret, staticField) - } - } - } - - return ret -} - -func detectSubNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { - var ret = make([]string, 0) - - for _, subnode := range node.LeavesNodes { - if subnode.Grok.RunTimeRegexp != nil { - for _, capturedField := range subnode.Grok.RunTimeRegexp.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) - if !slices.Contains(ret, fieldName) { - ret = append(ret, fieldName) - } - } - } - if subnode.Grok.RegexpName != "" { - grokCompiled, err := parserCTX.Grok.Get(subnode.Grok.RegexpName) - if err != nil { - log.Warningf("Can't get subgrok: %s", err) - } - for _, capturedField := range grokCompiled.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) - if !slices.Contains(ret, fieldName) { - ret = append(ret, fieldName) - } - } - } - - if len(subnode.Grok.Statics) > 0 { - staticsField := 
detectStaticField(subnode.Grok.Statics) - for _, staticField := range staticsField { - if !slices.Contains(ret, staticField) { - ret = append(ret, staticField) - } - } - } - - if len(subnode.Statics) > 0 { - staticsField := detectStaticField(subnode.Statics) - for _, staticField := range staticsField { - if !slices.Contains(ret, staticField) { - ret = append(ret, staticField) - } - } - } - } - - return ret -} diff --git a/cmd/crowdsec-cli/machines.go b/cmd/crowdsec-cli/machines.go deleted file mode 100644 index 21594310230..00000000000 --- a/cmd/crowdsec-cli/machines.go +++ /dev/null @@ -1,431 +0,0 @@ -package main - -import ( - saferand "crypto/rand" - "encoding/csv" - "encoding/json" - "fmt" - "io" - "math/big" - "os" - "strings" - "time" - - "github.com/AlecAivazis/survey/v2" - "github.com/fatih/color" - "github.com/go-openapi/strfmt" - "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "golang.org/x/exp/slices" - "gopkg.in/yaml.v2" - - "github.com/crowdsecurity/machineid" - - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/database/ent" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -var ( - passwordLength = 64 -) - -func generatePassword(length int) string { - upper := "ABCDEFGHIJKLMNOPQRSTUVWXY" - lower := "abcdefghijklmnopqrstuvwxyz" - digits := "0123456789" - - charset := upper + lower + digits - charsetLength := len(charset) - - buf := make([]byte, length) - for i := 0; i < length; i++ { - rInt, err := saferand.Int(saferand.Reader, big.NewInt(int64(charsetLength))) - if err != nil { - log.Fatalf("failed getting data from prng for password generation : %s", err) - } - buf[i] = charset[rInt.Int64()] - } - - return string(buf) -} - -// Returns a unique identifier for each crowdsec installation, using an -// identifier of the OS installation where available, otherwise a random -// string. -func generateIDPrefix() (string, error) { - prefix, err := machineid.ID() - if err == nil { - return prefix, nil - } - log.Debugf("failed to get machine-id with usual files: %s", err) - - bId, err := uuid.NewRandom() - if err == nil { - return bId.String(), nil - } - return "", fmt.Errorf("generating machine id: %w", err) -} - -// Generate a unique identifier, composed by a prefix and a random suffix. -// The prefix can be provided by a parameter to use in test environments. -func generateID(prefix string) (string, error) { - var err error - if prefix == "" { - prefix, err = generateIDPrefix() - } - if err != nil { - return "", err - } - prefix = strings.ReplaceAll(prefix, "-", "")[:32] - suffix := generatePassword(16) - return prefix + suffix, nil -} - -// getLastHeartbeat returns the last heartbeat timestamp of a machine -// and a boolean indicating if the machine is considered active or not. 
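Both the deleted generatePassword and the new idgen.GeneratePassword draw characters with crypto/rand instead of math/rand, so credentials come from a CSPRNG, and rand.Int returns a uniform value in [0, max), which avoids modulo bias. A standalone sketch of the same technique; the charset here is an arbitrary example, not the one used by crowdsec:

```go
package main

import (
	"crypto/rand"
	"fmt"
	"log"
	"math/big"
)

// randomString draws each character uniformly from charset using the
// crypto/rand CSPRNG; rand.Int returns a uniform value in [0, max),
// so there is no modulo bias.
func randomString(charset string, length int) string {
	max := big.NewInt(int64(len(charset)))
	buf := make([]byte, length)

	for i := range buf {
		n, err := rand.Int(rand.Reader, max)
		if err != nil {
			log.Fatalf("reading from prng: %s", err)
		}
		buf[i] = charset[n.Int64()]
	}

	return string(buf)
}

func main() {
	const charset = "ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnpqrstuvwxyz23456789"
	fmt.Println(randomString(charset, 16))
}
```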
-func getLastHeartbeat(m *ent.Machine) (string, bool) { - if m.LastHeartbeat == nil { - return "-", false - } - - elapsed := time.Now().UTC().Sub(*m.LastHeartbeat) - - hb := elapsed.Truncate(time.Second).String() - if elapsed > 2*time.Minute { - return hb, false - } - - return hb, true -} - -func getAgents(out io.Writer, dbClient *database.Client) error { - machines, err := dbClient.ListMachines() - if err != nil { - return fmt.Errorf("unable to list machines: %s", err) - } - if csConfig.Cscli.Output == "human" { - getAgentsTable(out, machines) - } else if csConfig.Cscli.Output == "json" { - enc := json.NewEncoder(out) - enc.SetIndent("", " ") - if err := enc.Encode(machines); err != nil { - return fmt.Errorf("failed to marshal") - } - return nil - } else if csConfig.Cscli.Output == "raw" { - csvwriter := csv.NewWriter(out) - err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "auth_type", "last_heartbeat"}) - if err != nil { - return fmt.Errorf("failed to write header: %s", err) - } - for _, m := range machines { - var validated string - if m.IsValidated { - validated = "true" - } else { - validated = "false" - } - hb, _ := getLastHeartbeat(m) - err := csvwriter.Write([]string{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, m.AuthType, hb}) - if err != nil { - return fmt.Errorf("failed to write raw output: %w", err) - } - } - csvwriter.Flush() - } else { - log.Errorf("unknown output '%s'", csConfig.Cscli.Output) - } - return nil -} - -func NewMachinesListCmd() *cobra.Command { - cmdMachinesList := &cobra.Command{ - Use: "list", - Short: "List machines", - Long: `List `, - Example: `cscli machines list`, - Args: cobra.MaximumNArgs(1), - DisableAutoGenTag: true, - PreRunE: func(cmd *cobra.Command, args []string) error { - var err error - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) - } - - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - err := getAgents(color.Output, dbClient) - if err != nil { - return fmt.Errorf("unable to list machines: %s", err) - } - - return nil - }, - } - - return cmdMachinesList -} - -func NewMachinesAddCmd() *cobra.Command { - cmdMachinesAdd := &cobra.Command{ - Use: "add", - Short: "add machine to the database.", - DisableAutoGenTag: true, - Long: `Register a new machine in the database. 
cscli should be on the same machine as LAPI.`, - Example: ` -cscli machines add --auto -cscli machines add MyTestMachine --auto -cscli machines add MyTestMachine --password MyPassword -`, - PreRunE: func(cmd *cobra.Command, args []string) error { - var err error - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) - } - - return nil - }, - RunE: runMachinesAdd, - } - - flags := cmdMachinesAdd.Flags() - flags.StringP("password", "p", "", "machine password to login to the API") - flags.StringP("file", "f", "", "output file destination (defaults to "+csconfig.DefaultConfigPath("local_api_credentials.yaml")+")") - flags.StringP("url", "u", "", "URL of the local API") - flags.BoolP("interactive", "i", false, "interfactive mode to enter the password") - flags.BoolP("auto", "a", false, "automatically generate password (and username if not provided)") - flags.Bool("force", false, "will force add the machine if it already exist") - - return cmdMachinesAdd -} - -func runMachinesAdd(cmd *cobra.Command, args []string) error { - var dumpFile string - var err error - - flags := cmd.Flags() - - machinePassword, err := flags.GetString("password") - if err != nil { - return err - } - - outputFile, err := flags.GetString("file") - if err != nil { - return err - } - - apiURL, err := flags.GetString("url") - if err != nil { - return err - } - - interactive, err := flags.GetBool("interactive") - if err != nil { - return err - } - - autoAdd, err := flags.GetBool("auto") - if err != nil { - return err - } - - forceAdd, err := flags.GetBool("force") - if err != nil { - return err - } - - var machineID string - - // create machineID if not specified by user - if len(args) == 0 { - if !autoAdd { - printHelp(cmd) - return nil - } - machineID, err = generateID("") - if err != nil { - return fmt.Errorf("unable to generate machine id: %s", err) - } - } else { - machineID = args[0] - } - - /*check if file already exists*/ - if outputFile != "" { - dumpFile = outputFile - } else if csConfig.API.Client != nil && csConfig.API.Client.CredentialsFilePath != "" { - dumpFile = csConfig.API.Client.CredentialsFilePath - } - - // create a password if it's not specified by user - if machinePassword == "" && !interactive { - if !autoAdd { - printHelp(cmd) - return nil - } - machinePassword = generatePassword(passwordLength) - } else if machinePassword == "" && interactive { - qs := &survey.Password{ - Message: "Please provide a password for the machine", - } - survey.AskOne(qs, &machinePassword) - } - password := strfmt.Password(machinePassword) - _, err = dbClient.CreateMachine(&machineID, &password, "", true, forceAdd, types.PasswordAuthType) - if err != nil { - return fmt.Errorf("unable to create machine: %s", err) - } - log.Infof("Machine '%s' successfully added to the local API", machineID) - - if apiURL == "" { - if csConfig.API.Client != nil && csConfig.API.Client.Credentials != nil && csConfig.API.Client.Credentials.URL != "" { - apiURL = csConfig.API.Client.Credentials.URL - } else if csConfig.API.Server != nil && csConfig.API.Server.ListenURI != "" { - apiURL = "http://" + csConfig.API.Server.ListenURI - } else { - return fmt.Errorf("unable to dump an api URL. 
Please provide it in your configuration or with the -u parameter") - } - } - apiCfg := csconfig.ApiCredentialsCfg{ - Login: machineID, - Password: password.String(), - URL: apiURL, - } - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - return fmt.Errorf("unable to marshal api credentials: %s", err) - } - if dumpFile != "" && dumpFile != "-" { - err = os.WriteFile(dumpFile, apiConfigDump, 0644) - if err != nil { - return fmt.Errorf("write api credentials in '%s' failed: %s", dumpFile, err) - } - log.Printf("API credentials dumped to '%s'", dumpFile) - } else { - fmt.Printf("%s\n", string(apiConfigDump)) - } - - return nil -} - -func NewMachinesDeleteCmd() *cobra.Command { - cmdMachinesDelete := &cobra.Command{ - Use: "delete [machine_name]...", - Short: "delete machines", - Example: `cscli machines delete "machine1" "machine2"`, - Args: cobra.MinimumNArgs(1), - Aliases: []string{"remove"}, - DisableAutoGenTag: true, - PreRunE: func(cmd *cobra.Command, args []string) error { - var err error - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) - } - return nil - }, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - var err error - dbClient, err = getDBClient() - if err != nil { - cobra.CompError("unable to create new database client: " + err.Error()) - return nil, cobra.ShellCompDirectiveNoFileComp - } - machines, err := dbClient.ListMachines() - if err != nil { - cobra.CompError("unable to list machines " + err.Error()) - } - ret := make([]string, 0) - for _, machine := range machines { - if strings.Contains(machine.MachineId, toComplete) && !slices.Contains(args, machine.MachineId) { - ret = append(ret, machine.MachineId) - } - } - return ret, cobra.ShellCompDirectiveNoFileComp - }, - RunE: runMachinesDelete, - } - - return cmdMachinesDelete -} - -func runMachinesDelete(cmd *cobra.Command, args []string) error { - for _, machineID := range args { - err := dbClient.DeleteWatcher(machineID) - if err != nil { - log.Errorf("unable to delete machine '%s': %s", machineID, err) - return nil - } - log.Infof("machine '%s' deleted successfully", machineID) - } - - return nil -} - -func NewMachinesValidateCmd() *cobra.Command { - cmdMachinesValidate := &cobra.Command{ - Use: "validate", - Short: "validate a machine to access the local API", - Long: `validate a machine to access the local API.`, - Example: `cscli machines validate "machine_name"`, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - PreRunE: func(cmd *cobra.Command, args []string) error { - var err error - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) - } - - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - machineID := args[0] - if err := dbClient.ValidateMachine(machineID); err != nil { - return fmt.Errorf("unable to validate machine '%s': %s", machineID, err) - } - log.Infof("machine '%s' validated successfully", machineID) - - return nil - }, - } - - return cmdMachinesValidate -} - -func NewMachinesCmd() *cobra.Command { - var cmdMachines = &cobra.Command{ - Use: "machines [action]", - Short: "Manage local API machines [requires local API]", - Long: `To list/add/delete/validate machines. -Note: This command requires database direct access, so is intended to be run on the local API machine. 
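Both runLapiRegister and runMachinesAdd above finish by marshalling the credentials to YAML and either printing them or writing the credentials file. A minimal sketch with a local struct standing in for csconfig.ApiCredentialsCfg; note this example uses mode 0600 where the original writes 0644:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"gopkg.in/yaml.v2"
)

// apiCredentials mirrors the fields dumped by the deleted commands.
type apiCredentials struct {
	Login    string `yaml:"login"`
	Password string `yaml:"password"`
	URL      string `yaml:"url"`
}

func main() {
	creds := apiCredentials{
		Login:    "mymachine",
		Password: "s3cret",
		URL:      "http://127.0.0.1:8080/",
	}

	out, err := yaml.Marshal(creds)
	if err != nil {
		log.Fatalf("unable to marshal api credentials: %s", err)
	}

	// "-" as a destination means stdout, like runMachinesAdd.
	dumpFile := "-"
	if dumpFile == "" || dumpFile == "-" {
		fmt.Print(string(out))
		return
	}

	if err := os.WriteFile(dumpFile, out, 0o600); err != nil {
		log.Fatalf("writing %s: %s", dumpFile, err)
	}
}
```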
-`, - Example: `cscli machines [action]`, - DisableAutoGenTag: true, - Aliases: []string{"machine"}, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadAPIServer(); err != nil || csConfig.DisableAPI { - if err != nil { - log.Errorf("local api : %s", err) - } - return fmt.Errorf("local API is disabled, please run this command on the local API machine") - } - - return nil - }, - } - - cmdMachines.AddCommand(NewMachinesListCmd()) - cmdMachines.AddCommand(NewMachinesAddCmd()) - cmdMachines.AddCommand(NewMachinesDeleteCmd()) - cmdMachines.AddCommand(NewMachinesValidateCmd()) - - return cmdMachines -} diff --git a/cmd/crowdsec-cli/machines_table.go b/cmd/crowdsec-cli/machines_table.go deleted file mode 100644 index e166fb785a6..00000000000 --- a/cmd/crowdsec-cli/machines_table.go +++ /dev/null @@ -1,35 +0,0 @@ -package main - -import ( - "io" - "time" - - "github.com/aquasecurity/table" - "github.com/enescakir/emoji" - - "github.com/crowdsecurity/crowdsec/pkg/database/ent" -) - -func getAgentsTable(out io.Writer, machines []*ent.Machine) { - t := newLightTable(out) - t.SetHeaders("Name", "IP Address", "Last Update", "Status", "Version", "Auth Type", "Last Heartbeat") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - for _, m := range machines { - var validated string - if m.IsValidated { - validated = emoji.CheckMark.String() - } else { - validated = emoji.Prohibited.String() - } - - hb, active := getLastHeartbeat(m) - if !active { - hb = emoji.Warning.String() + " " + hb - } - t.AddRow(m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, m.AuthType, hb) - } - - t.Render() -} diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index 7a7814e63e0..1cca03b1d3d 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -3,67 +3,125 @@ package main import ( "fmt" "os" - "path" "path/filepath" - "strings" + "slices" + "time" "github.com/fatih/color" cc "github.com/ivanpirog/coloredcobra" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "github.com/spf13/cobra/doc" - "golang.org/x/exp/slices" + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clialert" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clibouncer" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clicapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliconsole" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clidecision" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliexplain" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihub" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihubtest" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliitem" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clilapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climachine" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climetrics" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clinotifications" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clipapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clisimulation" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clisupport" "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" - 
"github.com/crowdsecurity/crowdsec/pkg/cwversion" - "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/fflag" ) -var trace_lvl, dbg_lvl, nfo_lvl, wrn_lvl, err_lvl bool +var ( + ConfigFilePath string + csConfig *csconfig.Config +) -var ConfigFilePath string -var csConfig *csconfig.Config -var dbClient *database.Client +type configGetter func() *csconfig.Config -var OutputFormat string -var OutputColor string +var mergedConfig string -var downloadOnly bool -var forceAction bool -var purge bool -var all bool +type cliRoot struct { + logTrace bool + logDebug bool + logInfo bool + logWarn bool + logErr bool + outputColor string + outputFormat string + // flagBranch overrides the value in csConfig.Cscli.HubBranch + flagBranch string +} -var prometheusURL string +func newCliRoot() *cliRoot { + return &cliRoot{} +} -var mergedConfig string +// cfg() is a helper function to get the configuration loaded from config.yaml, +// we pass it to subcommands because the file is not read until the Execute() call +func (cli *cliRoot) cfg() *csconfig.Config { + return csConfig +} -func initConfig() { - var err error - if trace_lvl { - log.SetLevel(log.TraceLevel) - } else if dbg_lvl { - log.SetLevel(log.DebugLevel) - } else if nfo_lvl { - log.SetLevel(log.InfoLevel) - } else if wrn_lvl { - log.SetLevel(log.WarnLevel) - } else if err_lvl { - log.SetLevel(log.ErrorLevel) +// wantedLogLevel returns the log level requested in the command line flags. +func (cli *cliRoot) wantedLogLevel() log.Level { + switch { + case cli.logTrace: + return log.TraceLevel + case cli.logDebug: + return log.DebugLevel + case cli.logInfo: + return log.InfoLevel + case cli.logWarn: + return log.WarnLevel + case cli.logErr: + return log.ErrorLevel + default: + return log.InfoLevel } +} + +// loadConfigFor loads the configuration file for the given sub-command. +// If the sub-command does not need it, it returns a default configuration. +func loadConfigFor(command string) (*csconfig.Config, string, error) { + noNeedConfig := []string{ + "doc", + "help", + "completion", + "version", + "hubtest", + } + + if !slices.Contains(noNeedConfig, command) { + log.Debugf("Using %s as configuration file", ConfigFilePath) - if !slices.Contains(NoNeedConfig, os.Args[1]) { - csConfig, mergedConfig, err = csconfig.NewConfig(ConfigFilePath, false, false, true) + config, merged, err := csconfig.NewConfig(ConfigFilePath, false, false, true) if err != nil { - log.Fatal(err) + return nil, "", err } - log.Debugf("Using %s as configuration file", ConfigFilePath) - if err := csConfig.LoadCSCLI(); err != nil { - log.Fatal(err) + + // set up directory for trace files + if err := trace.Init(filepath.Join(config.ConfigPaths.DataDir, "trace")); err != nil { + return nil, "", fmt.Errorf("while setting up trace directory: %w", err) } - } else { - csConfig = csconfig.NewDefaultConfig() + + return config, merged, nil + } + + return csconfig.NewDefaultConfig(), "", nil +} + +// initialize is called before the subcommand is executed. 
+func (cli *cliRoot) initialize() error { + var err error + + log.SetLevel(cli.wantedLogLevel()) + + csConfig, mergedConfig, err = loadConfigFor(os.Args[1]) + if err != nil { + return err } // recap of the enabled feature flags, because logging @@ -72,22 +130,24 @@ func initConfig() { log.Debugf("Enabled feature flags: %s", fflist) } - if csConfig.Cscli == nil { - log.Fatalf("missing 'cscli' configuration in '%s', exiting", ConfigFilePath) + if cli.flagBranch != "" { + csConfig.Cscli.HubBranch = cli.flagBranch } - if cwhub.HubBranch == "" && csConfig.Cscli.HubBranch != "" { - cwhub.HubBranch = csConfig.Cscli.HubBranch - } - if OutputFormat != "" { - csConfig.Cscli.Output = OutputFormat - if OutputFormat != "json" && OutputFormat != "raw" && OutputFormat != "human" { - log.Fatalf("output format %s unknown", OutputFormat) - } + if cli.outputFormat != "" { + csConfig.Cscli.Output = cli.outputFormat } + if csConfig.Cscli.Output == "" { csConfig.Cscli.Output = "human" } + + if csConfig.Cscli.Output != "human" && csConfig.Cscli.Output != "json" && csConfig.Cscli.Output != "raw" { + return fmt.Errorf("output format '%s' not supported: must be one of human, json, raw", csConfig.Cscli.Output) + } + + log.SetFormatter(&log.TextFormatter{DisableTimestamp: true}) + if csConfig.Cscli.Output == "json" { log.SetFormatter(&log.JSONFormatter{}) log.SetLevel(log.ErrorLevel) @@ -95,58 +155,57 @@ func initConfig() { log.SetLevel(log.ErrorLevel) } - if OutputColor != "" { - csConfig.Cscli.Color = OutputColor - if OutputColor != "yes" && OutputColor != "no" && OutputColor != "auto" { - log.Fatalf("output color %s unknown", OutputColor) + if cli.outputColor != "" { + csConfig.Cscli.Color = cli.outputColor + + if cli.outputColor != "yes" && cli.outputColor != "no" && cli.outputColor != "auto" { + return fmt.Errorf("output color '%s' not supported: must be one of yes, no, auto", cli.outputColor) } } -} - -var validArgs = []string{ - "scenarios", "parsers", "collections", "capi", "lapi", "postoverflows", "machines", - "metrics", "bouncers", "alerts", "decisions", "simulation", "hub", "dashboard", - "config", "completion", "version", "console", "notifications", "support", -} -func prepender(filename string) string { - const header = `--- -id: %s -title: %s ---- -` - name := filepath.Base(filename) - base := strings.TrimSuffix(name, path.Ext(name)) - return fmt.Sprintf(header, base, strings.ReplaceAll(base, "_", " ")) + return nil } -func linkHandler(name string) string { - return fmt.Sprintf("/cscli/%s", name) +func (cli *cliRoot) colorize(cmd *cobra.Command) { + cc.Init(&cc.Config{ + RootCmd: cmd, + Headings: cc.Yellow, + Commands: cc.Green + cc.Bold, + CmdShortDescr: cc.Cyan, + Example: cc.Italic, + ExecName: cc.Bold, + Aliases: cc.Bold + cc.Italic, + FlagsDataType: cc.White, + Flags: cc.Green, + FlagsDescr: cc.Cyan, + NoExtraNewlines: true, + NoBottomNewline: true, + }) + cmd.SetOut(color.Output) } -var ( - NoNeedConfig = []string{ - "help", - "completion", - "version", - "hubtest", - } -) - -func main() { +func (cli *cliRoot) NewCommand() (*cobra.Command, error) { // set the formatter asap and worry about level later - logFormatter := &log.TextFormatter{TimestampFormat: "02-01-2006 15:04:05", FullTimestamp: true} + logFormatter := &log.TextFormatter{TimestampFormat: time.RFC3339, FullTimestamp: true} log.SetFormatter(logFormatter) if err := fflag.RegisterAllFeatures(); err != nil { - log.Fatalf("failed to register features: %s", err) + return nil, fmt.Errorf("failed to register features: %w", err) } if err := 
csconfig.LoadFeatureFlagsEnv(log.StandardLogger()); err != nil { - log.Fatalf("failed to set feature flags from env: %s", err) + return nil, fmt.Errorf("failed to set feature flags from env: %w", err) } - var rootCmd = &cobra.Command{ + // list of valid subcommands for the shell completion + validArgs := []string{ + "alerts", "appsec-configs", "appsec-rules", "bouncers", "capi", "collections", + "completion", "config", "console", "contexts", "dashboard", "decisions", "explain", + "hub", "hubtest", "lapi", "machines", "metrics", "notifications", "parsers", + "postoverflows", "scenarios", "simulation", "support", "version", + } + + cmd := &cobra.Command{ Use: "cscli", Short: "cscli allows you to manage crowdsec", Long: `cscli is the main command to interact with your crowdsec service, scenarios & db. @@ -158,59 +217,25 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall /*TBD examples*/ } - cc.Init(&cc.Config{ - RootCmd: rootCmd, - Headings: cc.Yellow, - Commands: cc.Green + cc.Bold, - CmdShortDescr: cc.Cyan, - Example: cc.Italic, - ExecName: cc.Bold, - Aliases: cc.Bold + cc.Italic, - FlagsDataType: cc.White, - Flags: cc.Green, - FlagsDescr: cc.Cyan, - }) - rootCmd.SetOut(color.Output) + cli.colorize(cmd) - var cmdDocGen = &cobra.Command{ - Use: "doc", - Short: "Generate the documentation in `./doc/`. Directory must exist.", - Args: cobra.ExactArgs(0), - Hidden: true, - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - if err := doc.GenMarkdownTreeCustom(rootCmd, "./doc/", prepender, linkHandler); err != nil { - return fmt.Errorf("Failed to generate cobra doc: %s", err) - } - return nil - }, - } - rootCmd.AddCommand(cmdDocGen) - /*usage*/ - var cmdVersion = &cobra.Command{ - Use: "version", - Short: "Display version and exit.", - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - cwversion.Show() - }, - } - rootCmd.AddCommand(cmdVersion) - - rootCmd.PersistentFlags().StringVarP(&ConfigFilePath, "config", "c", csconfig.DefaultConfigPath("config.yaml"), "path to crowdsec config file") - rootCmd.PersistentFlags().StringVarP(&OutputFormat, "output", "o", "", "Output format: human, json, raw") - rootCmd.PersistentFlags().StringVarP(&OutputColor, "color", "", "auto", "Output color: yes, no, auto") - rootCmd.PersistentFlags().BoolVar(&dbg_lvl, "debug", false, "Set logging to debug") - rootCmd.PersistentFlags().BoolVar(&nfo_lvl, "info", false, "Set logging to info") - rootCmd.PersistentFlags().BoolVar(&wrn_lvl, "warning", false, "Set logging to warning") - rootCmd.PersistentFlags().BoolVar(&err_lvl, "error", false, "Set logging to error") - rootCmd.PersistentFlags().BoolVar(&trace_lvl, "trace", false, "Set logging to trace") - - rootCmd.PersistentFlags().StringVar(&cwhub.HubBranch, "branch", "", "Override hub branch on github") - if err := rootCmd.PersistentFlags().MarkHidden("branch"); err != nil { - log.Fatalf("failed to hide flag: %s", err) - } + /*don't sort flags so we can enforce order*/ + cmd.Flags().SortFlags = false + + pflags := cmd.PersistentFlags() + pflags.SortFlags = false + + pflags.StringVarP(&ConfigFilePath, "config", "c", csconfig.DefaultConfigPath("config.yaml"), "path to crowdsec config file") + pflags.StringVarP(&cli.outputFormat, "output", "o", "", "Output format: human, json, raw") + pflags.StringVarP(&cli.outputColor, "color", "", "auto", "Output color: yes, no, auto") + pflags.BoolVar(&cli.logDebug, "debug", false, "Set logging to debug") + 
pflags.BoolVar(&cli.logInfo, "info", false, "Set logging to info") + pflags.BoolVar(&cli.logWarn, "warning", false, "Set logging to warning") + pflags.BoolVar(&cli.logErr, "error", false, "Set logging to error") + pflags.BoolVar(&cli.logTrace, "trace", false, "Set logging to trace") + pflags.StringVar(&cli.flagBranch, "branch", "", "Override hub branch on github") + + _ = pflags.MarkHidden("branch") // Look for "-c /path/to/config.yaml" // This duplicates the logic in cobra, but we need to do it before @@ -224,48 +249,59 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall } if err := csconfig.LoadFeatureFlagsFile(ConfigFilePath, log.StandardLogger()); err != nil { - log.Fatal(err) + return nil, err } + cmd.AddCommand(NewCLIDoc().NewCommand(cmd)) + cmd.AddCommand(NewCLIVersion().NewCommand()) + cmd.AddCommand(NewCLIConfig(cli.cfg).NewCommand()) + cmd.AddCommand(clihub.New(cli.cfg).NewCommand()) + cmd.AddCommand(climetrics.New(cli.cfg).NewCommand()) + cmd.AddCommand(NewCLIDashboard(cli.cfg).NewCommand()) + cmd.AddCommand(clidecision.New(cli.cfg).NewCommand()) + cmd.AddCommand(clialert.New(cli.cfg).NewCommand()) + cmd.AddCommand(clisimulation.New(cli.cfg).NewCommand()) + cmd.AddCommand(clibouncer.New(cli.cfg).NewCommand()) + cmd.AddCommand(climachine.New(cli.cfg).NewCommand()) + cmd.AddCommand(clicapi.New(cli.cfg).NewCommand()) + cmd.AddCommand(clilapi.New(cli.cfg).NewCommand()) + cmd.AddCommand(NewCompletionCmd()) + cmd.AddCommand(cliconsole.New(cli.cfg).NewCommand()) + cmd.AddCommand(cliexplain.New(cli.cfg, ConfigFilePath).NewCommand()) + cmd.AddCommand(clihubtest.New(cli.cfg).NewCommand()) + cmd.AddCommand(clinotifications.New(cli.cfg).NewCommand()) + cmd.AddCommand(clisupport.New(cli.cfg).NewCommand()) + cmd.AddCommand(clipapi.New(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewCollection(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewParser(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewScenario(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewPostOverflow(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewContext(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewAppsecConfig(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewAppsecRule(cli.cfg).NewCommand()) + + cli.addSetup(cmd) + if len(os.Args) > 1 { - cobra.OnInitialize(initConfig) + cobra.OnInitialize( + func() { + if err := cli.initialize(); err != nil { + log.Fatal(err) + } + }, + ) } - /*don't sort flags so we can enforce order*/ - rootCmd.Flags().SortFlags = false - rootCmd.PersistentFlags().SortFlags = false - - rootCmd.AddCommand(NewConfigCmd()) - rootCmd.AddCommand(NewHubCmd()) - rootCmd.AddCommand(NewMetricsCmd()) - rootCmd.AddCommand(NewDashboardCmd()) - rootCmd.AddCommand(NewDecisionsCmd()) - rootCmd.AddCommand(NewAlertsCmd()) - rootCmd.AddCommand(NewSimulationCmds()) - rootCmd.AddCommand(NewBouncersCmd()) - rootCmd.AddCommand(NewMachinesCmd()) - rootCmd.AddCommand(NewParsersCmd()) - rootCmd.AddCommand(NewScenariosCmd()) - rootCmd.AddCommand(NewCollectionsCmd()) - rootCmd.AddCommand(NewPostOverflowsCmd()) - rootCmd.AddCommand(NewCapiCmd()) - rootCmd.AddCommand(NewLapiCmd()) - rootCmd.AddCommand(NewCompletionCmd()) - rootCmd.AddCommand(NewConsoleCmd()) - rootCmd.AddCommand(NewExplainCmd()) - rootCmd.AddCommand(NewHubTestCmd()) - rootCmd.AddCommand(NewNotificationsCmd()) - rootCmd.AddCommand(NewSupportCmd()) - - if fflag.CscliSetup.IsEnabled() { - rootCmd.AddCommand(NewSetupCmd()) - } + return cmd, nil +} - if fflag.PapiClient.IsEnabled() { - rootCmd.AddCommand(NewPapiCmd()) 
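// The refactor above defers all configuration reading until Execute():
// cobra.OnInitialize registers a callback that runs after flag parsing but before
// the subcommand's RunE, and subcommands receive a getter rather than the config
// itself. A minimal sketch of that pattern, assuming simplified stand-in types
// (Config, newShowCmd are not the patch's types):
package main

import (
	"fmt"
	"log"

	"github.com/spf13/cobra"
)

type Config struct{ Output string }

type configGetter func() *Config

var cfg *Config // populated by the OnInitialize callback

func newShowCmd(getCfg configGetter) *cobra.Command {
	return &cobra.Command{
		Use: "show",
		RunE: func(_ *cobra.Command, _ []string) error {
			// getCfg() is only called here, long after the command tree was
			// built, so the file has been read exactly once by this point.
			fmt.Println(getCfg().Output)
			return nil
		},
	}
}

func main() {
	root := &cobra.Command{Use: "demo"}
	root.AddCommand(newShowCmd(func() *Config { return cfg }))

	cobra.OnInitialize(func() {
		cfg = &Config{Output: "human"} // stands in for csconfig.NewConfig(...)
	})

	root.SetArgs([]string{"show"})
	if err := root.Execute(); err != nil {
		log.Fatal(err)
	}
}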
+func main() { + cmd, err := newCliRoot().NewCommand() + if err != nil { + log.Fatal(err) } - if err := rootCmd.Execute(); err != nil { + if err := cmd.Execute(); err != nil { log.Fatal(err) } } diff --git a/cmd/crowdsec-cli/messages.go b/cmd/crowdsec-cli/messages.go deleted file mode 100644 index 02f051601e4..00000000000 --- a/cmd/crowdsec-cli/messages.go +++ /dev/null @@ -1,23 +0,0 @@ -package main - -import ( - "fmt" - "runtime" -) - -// ReloadMessage returns a description of the task required to reload -// the crowdsec configuration, according to the operating system. -func ReloadMessage() string { - var msg string - - switch runtime.GOOS { - case "windows": - msg = "Please restart the crowdsec service" - case "freebsd": - msg = `Run 'sudo service crowdsec reload'` - default: - msg = `Run 'sudo systemctl reload crowdsec'` - } - - return fmt.Sprintf("%s for the new configuration to be effective.", msg) -} diff --git a/cmd/crowdsec-cli/metrics.go b/cmd/crowdsec-cli/metrics.go deleted file mode 100644 index 1c506040f6d..00000000000 --- a/cmd/crowdsec-cli/metrics.go +++ /dev/null @@ -1,328 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "time" - - "github.com/fatih/color" - dto "github.com/prometheus/client_model/go" - "github.com/prometheus/prom2json" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v3" - - "github.com/crowdsecurity/go-cs-lib/pkg/trace" -) - -// FormatPrometheusMetrics is a complete rip from prom2json -func FormatPrometheusMetrics(out io.Writer, url string, formatType string) error { - mfChan := make(chan *dto.MetricFamily, 1024) - errChan := make(chan error, 1) - - // Start with the DefaultTransport for sane defaults. - transport := http.DefaultTransport.(*http.Transport).Clone() - // Conservatively disable HTTP keep-alives as this program will only - // ever need a single HTTP request. - transport.DisableKeepAlives = true - // Timeout early if the server doesn't even return the headers. 
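// FormatPrometheusMetrics, continuing below, follows prom2json's streaming
// pattern: a producer goroutine feeds mfChan and reports on errChan, and since
// FetchMetricFamilies closes the channel on return, the consumer can range over
// it and only then read the error. Pared down to just that pattern; promURL is
// a placeholder.
package main

import (
	"fmt"
	"log"
	"net/http"

	dto "github.com/prometheus/client_model/go"
	"github.com/prometheus/prom2json"
)

func fetchFamilies(promURL string) ([]*prom2json.Family, error) {
	mfChan := make(chan *dto.MetricFamily, 1024)
	errChan := make(chan error, 1)

	go func() {
		// FetchMetricFamilies closes mfChan on return, which ends the range below.
		errChan <- prom2json.FetchMetricFamilies(promURL, mfChan, http.DefaultTransport)
	}()

	result := []*prom2json.Family{}
	for mf := range mfChan {
		result = append(result, prom2json.NewFamily(mf))
	}

	// The channel is closed, so the goroutine has returned: safe to read its error.
	if err := <-errChan; err != nil {
		return nil, err
	}

	return result, nil
}

func main() {
	fams, err := fetchFamilies("http://127.0.0.1:6060/metrics")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%d metric families\n", len(fams))
}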
- transport.ResponseHeaderTimeout = time.Minute - go func() { - defer trace.CatchPanic("crowdsec/ShowPrometheus") - err := prom2json.FetchMetricFamilies(url, mfChan, transport) - if err != nil { - errChan <- fmt.Errorf("failed to fetch prometheus metrics: %w", err) - return - } - errChan <- nil - }() - - result := []*prom2json.Family{} - for mf := range mfChan { - result = append(result, prom2json.NewFamily(mf)) - } - - if err := <-errChan; err != nil { - return err - } - - log.Debugf("Finished reading prometheus output, %d entries", len(result)) - /*walk*/ - lapi_decisions_stats := map[string]struct { - NonEmpty int - Empty int - }{} - acquis_stats := map[string]map[string]int{} - parsers_stats := map[string]map[string]int{} - buckets_stats := map[string]map[string]int{} - lapi_stats := map[string]map[string]int{} - lapi_machine_stats := map[string]map[string]map[string]int{} - lapi_bouncer_stats := map[string]map[string]map[string]int{} - decisions_stats := map[string]map[string]map[string]int{} - alerts_stats := map[string]int{} - stash_stats := map[string]struct { - Type string - Count int - }{} - - for idx, fam := range result { - if !strings.HasPrefix(fam.Name, "cs_") { - continue - } - log.Tracef("round %d", idx) - for _, m := range fam.Metrics { - metric, ok := m.(prom2json.Metric) - if !ok { - log.Debugf("failed to convert metric to prom2json.Metric") - continue - } - name, ok := metric.Labels["name"] - if !ok { - log.Debugf("no name in Metric %v", metric.Labels) - } - source, ok := metric.Labels["source"] - if !ok { - log.Debugf("no source in Metric %v for %s", metric.Labels, fam.Name) - } else { - if srctype, ok := metric.Labels["type"]; ok { - source = srctype + ":" + source - } - } - - value := m.(prom2json.Metric).Value - machine := metric.Labels["machine"] - bouncer := metric.Labels["bouncer"] - - route := metric.Labels["route"] - method := metric.Labels["method"] - - reason := metric.Labels["reason"] - origin := metric.Labels["origin"] - action := metric.Labels["action"] - - mtype := metric.Labels["type"] - - fval, err := strconv.ParseFloat(value, 32) - if err != nil { - log.Errorf("Unexpected int value %s : %s", value, err) - } - ival := int(fval) - switch fam.Name { - /*buckets*/ - case "cs_bucket_created_total": - if _, ok := buckets_stats[name]; !ok { - buckets_stats[name] = make(map[string]int) - } - buckets_stats[name]["instantiation"] += ival - case "cs_buckets": - if _, ok := buckets_stats[name]; !ok { - buckets_stats[name] = make(map[string]int) - } - buckets_stats[name]["curr_count"] += ival - case "cs_bucket_overflowed_total": - if _, ok := buckets_stats[name]; !ok { - buckets_stats[name] = make(map[string]int) - } - buckets_stats[name]["overflow"] += ival - case "cs_bucket_poured_total": - if _, ok := buckets_stats[name]; !ok { - buckets_stats[name] = make(map[string]int) - } - if _, ok := acquis_stats[source]; !ok { - acquis_stats[source] = make(map[string]int) - } - buckets_stats[name]["pour"] += ival - acquis_stats[source]["pour"] += ival - case "cs_bucket_underflowed_total": - if _, ok := buckets_stats[name]; !ok { - buckets_stats[name] = make(map[string]int) - } - buckets_stats[name]["underflow"] += ival - /*acquis*/ - case "cs_parser_hits_total": - if _, ok := acquis_stats[source]; !ok { - acquis_stats[source] = make(map[string]int) - } - acquis_stats[source]["reads"] += ival - case "cs_parser_hits_ok_total": - if _, ok := acquis_stats[source]; !ok { - acquis_stats[source] = make(map[string]int) - } - acquis_stats[source]["parsed"] += ival - case 
"cs_parser_hits_ko_total": - if _, ok := acquis_stats[source]; !ok { - acquis_stats[source] = make(map[string]int) - } - acquis_stats[source]["unparsed"] += ival - case "cs_node_hits_total": - if _, ok := parsers_stats[name]; !ok { - parsers_stats[name] = make(map[string]int) - } - parsers_stats[name]["hits"] += ival - case "cs_node_hits_ok_total": - if _, ok := parsers_stats[name]; !ok { - parsers_stats[name] = make(map[string]int) - } - parsers_stats[name]["parsed"] += ival - case "cs_node_hits_ko_total": - if _, ok := parsers_stats[name]; !ok { - parsers_stats[name] = make(map[string]int) - } - parsers_stats[name]["unparsed"] += ival - case "cs_lapi_route_requests_total": - if _, ok := lapi_stats[route]; !ok { - lapi_stats[route] = make(map[string]int) - } - lapi_stats[route][method] += ival - case "cs_lapi_machine_requests_total": - if _, ok := lapi_machine_stats[machine]; !ok { - lapi_machine_stats[machine] = make(map[string]map[string]int) - } - if _, ok := lapi_machine_stats[machine][route]; !ok { - lapi_machine_stats[machine][route] = make(map[string]int) - } - lapi_machine_stats[machine][route][method] += ival - case "cs_lapi_bouncer_requests_total": - if _, ok := lapi_bouncer_stats[bouncer]; !ok { - lapi_bouncer_stats[bouncer] = make(map[string]map[string]int) - } - if _, ok := lapi_bouncer_stats[bouncer][route]; !ok { - lapi_bouncer_stats[bouncer][route] = make(map[string]int) - } - lapi_bouncer_stats[bouncer][route][method] += ival - case "cs_lapi_decisions_ko_total", "cs_lapi_decisions_ok_total": - if _, ok := lapi_decisions_stats[bouncer]; !ok { - lapi_decisions_stats[bouncer] = struct { - NonEmpty int - Empty int - }{} - } - x := lapi_decisions_stats[bouncer] - if fam.Name == "cs_lapi_decisions_ko_total" { - x.Empty += ival - } else if fam.Name == "cs_lapi_decisions_ok_total" { - x.NonEmpty += ival - } - lapi_decisions_stats[bouncer] = x - case "cs_active_decisions": - if _, ok := decisions_stats[reason]; !ok { - decisions_stats[reason] = make(map[string]map[string]int) - } - if _, ok := decisions_stats[reason][origin]; !ok { - decisions_stats[reason][origin] = make(map[string]int) - } - decisions_stats[reason][origin][action] += ival - case "cs_alerts": - /*if _, ok := alerts_stats[scenario]; !ok { - alerts_stats[scenario] = make(map[string]int) - }*/ - alerts_stats[reason] += ival - case "cs_cache_size": - stash_stats[name] = struct { - Type string - Count int - }{Type: mtype, Count: ival} - default: - continue - } - - } - } - - if formatType == "human" { - acquisStatsTable(out, acquis_stats) - bucketStatsTable(out, buckets_stats) - parserStatsTable(out, parsers_stats) - lapiStatsTable(out, lapi_stats) - lapiMachineStatsTable(out, lapi_machine_stats) - lapiBouncerStatsTable(out, lapi_bouncer_stats) - lapiDecisionStatsTable(out, lapi_decisions_stats) - decisionStatsTable(out, decisions_stats) - alertStatsTable(out, alerts_stats) - stashStatsTable(out, stash_stats) - return nil - } - - stats := make(map[string]any) - - stats["acquisition"] = acquis_stats - stats["buckets"] = buckets_stats - stats["parsers"] = parsers_stats - stats["lapi"] = lapi_stats - stats["lapi_machine"] = lapi_machine_stats - stats["lapi_bouncer"] = lapi_bouncer_stats - stats["lapi_decisions"] = lapi_decisions_stats - stats["decisions"] = decisions_stats - stats["alerts"] = alerts_stats - stats["stash"] = stash_stats - - switch formatType { - case "json": - x, err := json.MarshalIndent(stats, "", " ") - if err != nil { - return fmt.Errorf("failed to unmarshal metrics : %v", err) - } - out.Write(x) - 
case "raw": - x, err := yaml.Marshal(stats) - if err != nil { - return fmt.Errorf("failed to unmarshal metrics : %v", err) - } - out.Write(x) - default: - return fmt.Errorf("unknown format type %s", formatType) - } - - return nil -} - -var noUnit bool - - -func runMetrics(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadPrometheus(); err != nil { - return fmt.Errorf("failed to load prometheus config: %w", err) - } - - if csConfig.Prometheus == nil { - return fmt.Errorf("prometheus section missing, can't show metrics") - } - - if !csConfig.Prometheus.Enabled { - return fmt.Errorf("prometheus is not enabled, can't show metrics") - } - - if prometheusURL == "" { - prometheusURL = csConfig.Cscli.PrometheusUrl - } - - if prometheusURL == "" { - return fmt.Errorf("no prometheus url, please specify in %s or via -u", *csConfig.FilePath) - } - - err := FormatPrometheusMetrics(color.Output, prometheusURL+"/metrics", csConfig.Cscli.Output) - if err != nil { - return fmt.Errorf("could not fetch prometheus metrics: %w", err) - } - return nil -} - - -func NewMetricsCmd() *cobra.Command { - cmdMetrics := &cobra.Command{ - Use: "metrics", - Short: "Display crowdsec prometheus metrics.", - Long: `Fetch metrics from the prometheus server and display them in a human-friendly way`, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - RunE: runMetrics, - } - cmdMetrics.PersistentFlags().StringVarP(&prometheusURL, "url", "u", "", "Prometheus url (http://:/metrics)") - cmdMetrics.PersistentFlags().BoolVar(&noUnit, "no-unit", false, "Show the real number instead of formatted with units") - - return cmdMetrics -} diff --git a/cmd/crowdsec-cli/metrics_table.go b/cmd/crowdsec-cli/metrics_table.go deleted file mode 100644 index 69706c7acf2..00000000000 --- a/cmd/crowdsec-cli/metrics_table.go +++ /dev/null @@ -1,307 +0,0 @@ -package main - -import ( - "fmt" - "io" - "sort" - - "github.com/aquasecurity/table" - log "github.com/sirupsen/logrus" -) - -func lapiMetricsToTable(t *table.Table, stats map[string]map[string]map[string]int) int { - // stats: machine -> route -> method -> count - - // sort keys to keep consistent order when printing - machineKeys := []string{} - for k := range stats { - machineKeys = append(machineKeys, k) - } - sort.Strings(machineKeys) - - numRows := 0 - for _, machine := range machineKeys { - // oneRow: route -> method -> count - machineRow := stats[machine] - for routeName, route := range machineRow { - for methodName, count := range route { - row := []string{ - machine, - routeName, - methodName, - } - if count != 0 { - row = append(row, fmt.Sprintf("%d", count)) - } else { - row = append(row, "-") - } - t.AddRow(row...) - numRows++ - } - } - } - return numRows -} - -func metricsToTable(t *table.Table, stats map[string]map[string]int, keys []string) (int, error) { - if t == nil { - return 0, fmt.Errorf("nil table") - } - // sort keys to keep consistent order when printing - sortedKeys := []string{} - for k := range stats { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - - numRows := 0 - for _, alabel := range sortedKeys { - astats, ok := stats[alabel] - if !ok { - continue - } - row := []string{ - alabel, - } - for _, sl := range keys { - if v, ok := astats[sl]; ok && v != 0 { - numberToShow := fmt.Sprintf("%d", v) - if !noUnit { - numberToShow = formatNumber(v) - } - - row = append(row, numberToShow) - } else { - row = append(row, "-") - } - } - t.AddRow(row...) 
- numRows++ - } - return numRows, nil -} - -func bucketStatsTable(out io.Writer, stats map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Bucket", "Current Count", "Overflows", "Instantiated", "Poured", "Expired") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - keys := []string{"curr_count", "overflow", "instantiation", "pour", "underflow"} - - if numRows, err := metricsToTable(t, stats, keys); err != nil { - log.Warningf("while collecting acquis stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, "\nBucket Metrics:") - t.Render() - } -} - -func acquisStatsTable(out io.Writer, stats map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Source", "Lines read", "Lines parsed", "Lines unparsed", "Lines poured to bucket") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - keys := []string{"reads", "parsed", "unparsed", "pour"} - - if numRows, err := metricsToTable(t, stats, keys); err != nil { - log.Warningf("while collecting acquis stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, "\nAcquisition Metrics:") - t.Render() - } -} - -func parserStatsTable(out io.Writer, stats map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Parsers", "Hits", "Parsed", "Unparsed") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - keys := []string{"hits", "parsed", "unparsed"} - - if numRows, err := metricsToTable(t, stats, keys); err != nil { - log.Warningf("while collecting acquis stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, "\nParser Metrics:") - t.Render() - } -} - -func stashStatsTable(out io.Writer, stats map[string]struct { - Type string - Count int -}) { - - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Name", "Type", "Items") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - // unfortunately, we can't reuse metricsToTable as the structure is too different :/ - sortedKeys := []string{} - for k := range stats { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - - numRows := 0 - for _, alabel := range sortedKeys { - astats := stats[alabel] - - row := []string{ - alabel, - astats.Type, - fmt.Sprintf("%d", astats.Count), - } - t.AddRow(row...) - numRows++ - } - if numRows > 0 { - renderTableTitle(out, "\nParser Stash Metrics:") - t.Render() - } -} - -func lapiStatsTable(out io.Writer, stats map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Route", "Method", "Hits") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - // unfortunately, we can't reuse metricsToTable as the structure is too different :/ - sortedKeys := []string{} - for k := range stats { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - - numRows := 0 - for _, alabel := range sortedKeys { - astats := stats[alabel] - - subKeys := []string{} - for skey := range astats { - subKeys = append(subKeys, skey) - } - sort.Strings(subKeys) - - for _, sl := range subKeys { - row := []string{ - alabel, - sl, - fmt.Sprintf("%d", astats[sl]), - } - t.AddRow(row...) 
- numRows++ - } - } - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Metrics:") - t.Render() - } -} - -func lapiMachineStatsTable(out io.Writer, stats map[string]map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Machine", "Route", "Method", "Hits") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - numRows := lapiMetricsToTable(t, stats) - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Machines Metrics:") - t.Render() - } -} - -func lapiBouncerStatsTable(out io.Writer, stats map[string]map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Bouncer", "Route", "Method", "Hits") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - numRows := lapiMetricsToTable(t, stats) - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Bouncers Metrics:") - t.Render() - } -} - -func lapiDecisionStatsTable(out io.Writer, stats map[string]struct { - NonEmpty int - Empty int -}, -) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Bouncer", "Empty answers", "Non-empty answers") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - numRows := 0 - for bouncer, hits := range stats { - t.AddRow( - bouncer, - fmt.Sprintf("%d", hits.Empty), - fmt.Sprintf("%d", hits.NonEmpty), - ) - numRows++ - } - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Bouncers Decisions:") - t.Render() - } -} - -func decisionStatsTable(out io.Writer, stats map[string]map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Reason", "Origin", "Action", "Count") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - numRows := 0 - for reason, origins := range stats { - for origin, actions := range origins { - for action, hits := range actions { - t.AddRow( - reason, - origin, - action, - fmt.Sprintf("%d", hits), - ) - numRows++ - } - } - } - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Decisions:") - t.Render() - } -} - -func alertStatsTable(out io.Writer, stats map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Reason", "Count") - t.SetAlignment(table.AlignLeft, table.AlignLeft) - - numRows := 0 - for scenario, hits := range stats { - t.AddRow( - scenario, - fmt.Sprintf("%d", hits), - ) - numRows++ - } - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Alerts:") - t.Render() - } -} diff --git a/cmd/crowdsec-cli/notifications.go b/cmd/crowdsec-cli/notifications.go deleted file mode 100644 index fe4c14f27c5..00000000000 --- a/cmd/crowdsec-cli/notifications.go +++ /dev/null @@ -1,348 +0,0 @@ -package main - -import ( - "context" - "encoding/csv" - "encoding/json" - "fmt" - "io/fs" - "net/url" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/fatih/color" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/tomb.v2" - - "github.com/crowdsecurity/go-cs-lib/pkg/version" - - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/csplugin" - "github.com/crowdsecurity/crowdsec/pkg/csprofiles" -) - -type NotificationsCfg struct { - Config csplugin.PluginConfig `json:"plugin_config"` - Profiles []*csconfig.ProfileCfg `json:"associated_profiles"` - ids []uint -} - -func NewNotificationsCmd() *cobra.Command { - var cmdNotifications = &cobra.Command{ - Use: "notifications 
[action]", - Short: "Helper for notification plugin configuration", - Long: "To list/inspect/test notification template", - Args: cobra.MinimumNArgs(1), - Aliases: []string{"notifications", "notification"}, - DisableAutoGenTag: true, - PersistentPreRun: func(cmd *cobra.Command, args []string) { - var ( - err error - ) - if err = csConfig.API.Server.LoadProfiles(); err != nil { - log.Fatal(err) - } - if csConfig.ConfigPaths.NotificationDir == "" { - log.Fatalf("config_paths.notification_dir is not set in crowdsec config") - } - }, - } - - cmdNotifications.AddCommand(NewNotificationsListCmd()) - cmdNotifications.AddCommand(NewNotificationsInspectCmd()) - cmdNotifications.AddCommand(NewNotificationsReinjectCmd()) - - return cmdNotifications -} - -func getNotificationsConfiguration() (map[string]NotificationsCfg, error) { - pcfgs := map[string]csplugin.PluginConfig{} - wf := func(path string, info fs.FileInfo, err error) error { - if info == nil { - return fmt.Errorf("error while traversing directory %s: %w", path, err) - } - name := filepath.Join(csConfig.ConfigPaths.NotificationDir, info.Name()) //Avoid calling info.Name() twice - if (strings.HasSuffix(name, "yaml") || strings.HasSuffix(name, "yml")) && !(info.IsDir()) { - ts, err := csplugin.ParsePluginConfigFile(name) - if err != nil { - return fmt.Errorf("loading notifification plugin configuration with %s: %w", name, err) - } - for _, t := range ts { - pcfgs[t.Name] = t - } - } - return nil - } - - if err := filepath.Walk(csConfig.ConfigPaths.NotificationDir, wf); err != nil { - return nil, fmt.Errorf("while loading notifification plugin configuration: %w", err) - } - - // A bit of a tricky stuf now: reconcile profiles and notification plugins - ncfgs := map[string]NotificationsCfg{} - profiles, err := csprofiles.NewProfile(csConfig.API.Server.Profiles) - if err != nil { - return nil, fmt.Errorf("while extracting profiles from configuration: %w", err) - } - for profileID, profile := range profiles { - loop: - for _, notif := range profile.Cfg.Notifications { - for name, pc := range pcfgs { - if notif == name { - if _, ok := ncfgs[pc.Name]; !ok { - ncfgs[pc.Name] = NotificationsCfg{ - Config: pc, - Profiles: []*csconfig.ProfileCfg{profile.Cfg}, - ids: []uint{uint(profileID)}, - } - continue loop - } - tmp := ncfgs[pc.Name] - for _, pr := range tmp.Profiles { - var profiles []*csconfig.ProfileCfg - if pr.Name == profile.Cfg.Name { - continue - } - profiles = append(tmp.Profiles, profile.Cfg) - ids := append(tmp.ids, uint(profileID)) - ncfgs[pc.Name] = NotificationsCfg{ - Config: tmp.Config, - Profiles: profiles, - ids: ids, - } - } - } - } - } - } - return ncfgs, nil -} - -func NewNotificationsListCmd() *cobra.Command { - var cmdNotificationsList = &cobra.Command{ - Use: "list", - Short: "List active notifications plugins", - Long: `List active notifications plugins`, - Example: `cscli notifications list`, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, arg []string) error { - ncfgs, err := getNotificationsConfiguration() - if err != nil { - return fmt.Errorf("can't build profiles configuration: %w", err) - } - - if csConfig.Cscli.Output == "human" { - notificationListTable(color.Output, ncfgs) - } else if csConfig.Cscli.Output == "json" { - x, err := json.MarshalIndent(ncfgs, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal notification configuration: %w", err) - } - fmt.Printf("%s", string(x)) - } else if csConfig.Cscli.Output == "raw" { - csvwriter := csv.NewWriter(os.Stdout) - 
err := csvwriter.Write([]string{"Name", "Type", "Profile name"}) - if err != nil { - return fmt.Errorf("failed to write raw header: %w", err) - } - for _, b := range ncfgs { - profilesList := []string{} - for _, p := range b.Profiles { - profilesList = append(profilesList, p.Name) - } - err := csvwriter.Write([]string{b.Config.Name, b.Config.Type, strings.Join(profilesList, ", ")}) - if err != nil { - return fmt.Errorf("failed to write raw content: %w", err) - } - } - csvwriter.Flush() - } - return nil - }, - } - - return cmdNotificationsList -} - -func NewNotificationsInspectCmd() *cobra.Command { - var cmdNotificationsInspect = &cobra.Command{ - Use: "inspect", - Short: "Inspect active notifications plugin configuration", - Long: `Inspect active notifications plugin and show configuration`, - Example: `cscli notifications inspect `, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, arg []string) error { - var ( - cfg NotificationsCfg - ok bool - ) - - pluginName := arg[0] - - if pluginName == "" { - return fmt.Errorf("please provide a plugin name to inspect") - } - ncfgs, err := getNotificationsConfiguration() - if err != nil { - return fmt.Errorf("can't build profiles configuration: %w", err) - } - if cfg, ok = ncfgs[pluginName]; !ok { - return fmt.Errorf("plugin '%s' does not exist or is not active", pluginName) - } - - if csConfig.Cscli.Output == "human" || csConfig.Cscli.Output == "raw" { - fmt.Printf(" - %15s: %15s\n", "Type", cfg.Config.Type) - fmt.Printf(" - %15s: %15s\n", "Name", cfg.Config.Name) - fmt.Printf(" - %15s: %15s\n", "Timeout", cfg.Config.TimeOut) - fmt.Printf(" - %15s: %15s\n", "Format", cfg.Config.Format) - for k, v := range cfg.Config.Config { - fmt.Printf(" - %15s: %15v\n", k, v) - } - } else if csConfig.Cscli.Output == "json" { - x, err := json.MarshalIndent(cfg, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal notification configuration: %w", err) - } - fmt.Printf("%s", string(x)) - } - return nil - }, - } - - return cmdNotificationsInspect -} - -func NewNotificationsReinjectCmd() *cobra.Command { - var remediation bool - var alertOverride string - - var cmdNotificationsReinject = &cobra.Command{ - Use: "reinject", - Short: "reinject alert into notifications system", - Long: `Reinject alert into notifications system`, - Example: ` -cscli notifications reinject -cscli notifications reinject --remediation -cscli notifications reinject -a '{"remediation": true,"scenario":"notification/test"}' -`, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - var ( - pluginBroker csplugin.PluginBroker - pluginTomb tomb.Tomb - ) - if len(args) != 1 { - printHelp(cmd) - return fmt.Errorf("wrong number of argument: there should be one argument") - } - - //first: get the alert - id, err := strconv.Atoi(args[0]) - if err != nil { - return fmt.Errorf("bad alert id %s", args[0]) - } - if err := csConfig.LoadAPIClient(); err != nil { - return fmt.Errorf("loading api client: %w", err) - } - if csConfig.API.Client == nil { - return fmt.Errorf("missing configuration on 'api_client:'") - } - if csConfig.API.Client.Credentials == nil { - return fmt.Errorf("missing API credentials in '%s'", csConfig.API.Client.CredentialsFilePath) - } - apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL) - if err != nil { - return fmt.Errorf("error parsing the URL of the API: %w", err) - } - client, err := apiclient.NewClient(&apiclient.Config{ - MachineID: 
csConfig.API.Client.Credentials.Login, - Password: strfmt.Password(csConfig.API.Client.Credentials.Password), - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiURL, - VersionPrefix: "v1", - }) - if err != nil { - return fmt.Errorf("error creating the client for the API: %w", err) - } - alert, _, err := client.Alerts.GetByID(context.Background(), id) - if err != nil { - return fmt.Errorf("can't find alert with id %s: %w", args[0], err) - } - - if alertOverride != "" { - if err = json.Unmarshal([]byte(alertOverride), alert); err != nil { - return fmt.Errorf("can't unmarshal data in the alert flag: %w", err) - } - } - if !remediation { - alert.Remediation = true - } - - // second we start plugins - err = pluginBroker.Init(csConfig.PluginConfig, csConfig.API.Server.Profiles, csConfig.ConfigPaths) - if err != nil { - return fmt.Errorf("can't initialize plugins: %w", err) - } - - pluginTomb.Go(func() error { - pluginBroker.Run(&pluginTomb) - return nil - }) - - //third: get the profile(s), and process the whole stuff - - profiles, err := csprofiles.NewProfile(csConfig.API.Server.Profiles) - if err != nil { - return fmt.Errorf("cannot extract profiles from configuration: %w", err) - } - - for id, profile := range profiles { - _, matched, err := profile.EvaluateProfile(alert) - if err != nil { - return fmt.Errorf("can't evaluate profile %s: %w", profile.Cfg.Name, err) - } - if !matched { - log.Infof("The profile %s didn't match", profile.Cfg.Name) - continue - } - log.Infof("The profile %s matched, sending to its configured notification plugins", profile.Cfg.Name) - loop: - for { - select { - case pluginBroker.PluginChannel <- csplugin.ProfileAlert{ - ProfileID: uint(id), - Alert: alert, - }: - break loop - default: - time.Sleep(50 * time.Millisecond) - log.Info("sleeping\n") - - } - } - if profile.Cfg.OnSuccess == "break" { - log.Infof("The profile %s contains a 'on_success: break' so bailing out", profile.Cfg.Name) - break - } - } - - // time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent - pluginTomb.Kill(fmt.Errorf("terminating")) - pluginTomb.Wait() - return nil - }, - } - cmdNotificationsReinject.Flags().BoolVarP(&remediation, "remediation", "r", false, "Set Alert.Remediation to false in the reinjected alert (see your profile filter configuration)") - cmdNotificationsReinject.Flags().StringVarP(&alertOverride, "alert", "a", "", "JSON string used to override alert fields in the reinjected alert (see crowdsec/pkg/models/alert.go in the source tree for the full definition of the object)") - - return cmdNotificationsReinject -} diff --git a/cmd/crowdsec-cli/notifications_table.go b/cmd/crowdsec-cli/notifications_table.go deleted file mode 100644 index 1113bb7c809..00000000000 --- a/cmd/crowdsec-cli/notifications_table.go +++ /dev/null @@ -1,25 +0,0 @@ -package main - -import ( - "io" - "strings" - - "github.com/aquasecurity/table" -) - -func notificationListTable(out io.Writer, ncfgs map[string]NotificationsCfg) { - t := newLightTable(out) - t.SetHeaders("Name", "Type", "Profile name") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - for _, b := range ncfgs { - profilesList := []string{} - for _, p := range b.Profiles { - profilesList = append(profilesList, p.Name) - } - t.AddRow(b.Config.Name, b.Config.Type, strings.Join(profilesList, ", ")) - } - - t.Render() -} diff --git a/cmd/crowdsec-cli/papi.go b/cmd/crowdsec-cli/papi.go deleted file mode 
100644 index d38da0df9b8..00000000000 --- a/cmd/crowdsec-cli/papi.go +++ /dev/null @@ -1,138 +0,0 @@ -package main - -import ( - "fmt" - "time" - - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/tomb.v2" - - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" - - "github.com/crowdsecurity/crowdsec/pkg/apiserver" - "github.com/crowdsecurity/crowdsec/pkg/database" -) - -func NewPapiCmd() *cobra.Command { - var cmdLapi = &cobra.Command{ - Use: "papi [action]", - Short: "Manage interaction with Polling API (PAPI)", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadAPIServer(); err != nil || csConfig.DisableAPI { - return fmt.Errorf("Local API is disabled, please run this command on the local API machine: %w", err) - } - if csConfig.API.Server.OnlineClient == nil { - log.Fatalf("no configuration for Central API in '%s'", *csConfig.FilePath) - } - if csConfig.API.Server.OnlineClient.Credentials.PapiURL == "" { - log.Fatalf("no PAPI URL in configuration") - } - return nil - }, - } - - cmdLapi.AddCommand(NewPapiStatusCmd()) - cmdLapi.AddCommand(NewPapiSyncCmd()) - - return cmdLapi -} - -func NewPapiStatusCmd() *cobra.Command { - cmdCapiStatus := &cobra.Command{ - Use: "status", - Short: "Get status of the Polling API", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - var err error - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - log.Fatalf("unable to initialize database client : %s", err) - } - - apic, err := apiserver.NewAPIC(csConfig.API.Server.OnlineClient, dbClient, csConfig.API.Server.ConsoleConfig, csConfig.API.Server.CapiWhitelists) - - if err != nil { - log.Fatalf("unable to initialize API client : %s", err) - } - - papi, err := apiserver.NewPAPI(apic, dbClient, csConfig.API.Server.ConsoleConfig, log.GetLevel()) - - if err != nil { - log.Fatalf("unable to initialize PAPI client : %s", err) - } - - perms, err := papi.GetPermissions() - - if err != nil { - log.Fatalf("unable to get PAPI permissions: %s", err) - } - var lastTimestampStr *string - lastTimestampStr, err = dbClient.GetConfigItem(apiserver.PapiPullKey) - if err != nil { - lastTimestampStr = ptr.Of("never") - } - log.Infof("You can successfully interact with Polling API (PAPI)") - log.Infof("Console plan: %s", perms.Plan) - log.Infof("Last order received: %s", *lastTimestampStr) - - log.Infof("PAPI subscriptions:") - for _, sub := range perms.Categories { - log.Infof(" - %s", sub) - } - }, - } - - return cmdCapiStatus -} - -func NewPapiSyncCmd() *cobra.Command { - cmdCapiSync := &cobra.Command{ - Use: "sync", - Short: "Sync with the Polling API, pulling all non-expired orders for the instance", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - var err error - t := tomb.Tomb{} - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - log.Fatalf("unable to initialize database client : %s", err) - } - - apic, err := apiserver.NewAPIC(csConfig.API.Server.OnlineClient, dbClient, csConfig.API.Server.ConsoleConfig, csConfig.API.Server.CapiWhitelists) - - if err != nil { - log.Fatalf("unable to initialize API client : %s", err) - } - - t.Go(apic.Push) - - papi, err := apiserver.NewPAPI(apic, dbClient, csConfig.API.Server.ConsoleConfig, log.GetLevel()) - - if err != nil { - log.Fatalf("unable to initialize PAPI client : %s", err) - } - t.Go(papi.SyncDecisions) 
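// The papi sync command above runs its background workers under a tomb so they
// can be stopped and awaited as a group: t.Go starts each worker, Kill signals
// shutdown, Wait blocks until all of them have returned. A minimal sketch of
// that lifecycle; worker is a stand-in for apic.Push / papi.SyncDecisions.
package main

import (
	"fmt"
	"time"

	"gopkg.in/tomb.v2"
)

func worker(t *tomb.Tomb, name string) func() error {
	return func() error {
		for {
			select {
			case <-t.Dying():
				fmt.Println(name, "stopping")
				return nil // returning ends this worker; Wait() unblocks when all do
			case <-time.After(100 * time.Millisecond):
				fmt.Println(name, "tick")
			}
		}
	}
}

func main() {
	var t tomb.Tomb // the zero value is ready to use, as in the code above
	t.Go(worker(&t, "push"))
	t.Go(worker(&t, "sync-decisions"))

	time.Sleep(300 * time.Millisecond) // stands in for the one-shot PullOnce call
	t.Kill(nil)                        // ask every worker to stop
	_ = t.Wait()                       // block until they have all returned
}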
- - err = papi.PullOnce(time.Time{}, true) - - if err != nil { - log.Fatalf("unable to sync decisions: %s", err) - } - - log.Infof("Sending acknowledgements to CAPI") - - apic.Shutdown() - papi.Shutdown() - t.Wait() - time.Sleep(5 * time.Second) //FIXME: the push done by apic.Push is run inside a sub goroutine, sleep to make sure it's done - - }, - } - - return cmdCapiSync -} diff --git a/cmd/crowdsec-cli/parsers.go b/cmd/crowdsec-cli/parsers.go deleted file mode 100644 index 9b810238b76..00000000000 --- a/cmd/crowdsec-cli/parsers.go +++ /dev/null @@ -1,202 +0,0 @@ -package main - -import ( - "fmt" - - "github.com/fatih/color" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/crowdsec/pkg/cwhub" -) - - -func NewParsersCmd() *cobra.Command { - var cmdParsers = &cobra.Command{ - Use: "parsers [action] [config]", - Short: "Install/Remove/Upgrade/Inspect parser(s) from hub", - Example: `cscli parsers install crowdsecurity/sshd-logs -cscli parsers inspect crowdsecurity/sshd-logs -cscli parsers upgrade crowdsecurity/sshd-logs -cscli parsers list -cscli parsers remove crowdsecurity/sshd-logs -`, - Args: cobra.MinimumNArgs(1), - Aliases: []string{"parser"}, - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - if csConfig.Hub == nil { - return fmt.Errorf("you must configure cli before interacting with hub") - } - - if err := cwhub.SetHubBranch(); err != nil { - return fmt.Errorf("error while setting hub branch: %s", err) - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Info("Run 'sudo cscli hub update' to get the hub index") - log.Fatalf("Failed to get Hub index : %v", err) - } - return nil - }, - PersistentPostRun: func(cmd *cobra.Command, args []string) { - if cmd.Name() == "inspect" || cmd.Name() == "list" { - return - } - log.Infof(ReloadMessage()) - }, - } - - cmdParsers.AddCommand(NewParsersInstallCmd()) - cmdParsers.AddCommand(NewParsersRemoveCmd()) - cmdParsers.AddCommand(NewParsersUpgradeCmd()) - cmdParsers.AddCommand(NewParsersInspectCmd()) - cmdParsers.AddCommand(NewParsersListCmd()) - - return cmdParsers -} - - -func NewParsersInstallCmd() *cobra.Command { - var ignoreError bool - - var cmdParsersInstall = &cobra.Command{ - Use: "install [config]", - Short: "Install given parser(s)", - Long: `Fetch and install given parser(s) from hub`, - Example: `cscli parsers install crowdsec/xxx crowdsec/xyz`, - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compAllItems(cwhub.PARSERS, args, toComplete) - }, - Run: func(cmd *cobra.Command, args []string) { - for _, name := range args { - t := cwhub.GetItem(cwhub.PARSERS, name) - if t == nil { - nearestItem, score := GetDistance(cwhub.PARSERS, name) - Suggest(cwhub.PARSERS, name, nearestItem.Name, score, ignoreError) - continue - } - if err := cwhub.InstallItem(csConfig, name, cwhub.PARSERS, forceAction, downloadOnly); err != nil { - if ignoreError { - log.Errorf("Error while installing '%s': %s", name, err) - } else { - log.Fatalf("Error while installing '%s': %s", name, err) - } - } - } - }, - } - cmdParsersInstall.PersistentFlags().BoolVarP(&downloadOnly, "download-only", "d", false, "Only download packages, don't enable") - cmdParsersInstall.PersistentFlags().BoolVar(&forceAction, "force", false, "Force install : Overwrite tainted and outdated 
files") - cmdParsersInstall.PersistentFlags().BoolVar(&ignoreError, "ignore", false, "Ignore errors when installing multiple parsers") - - return cmdParsersInstall -} - - -func NewParsersRemoveCmd() *cobra.Command { - var cmdParsersRemove = &cobra.Command{ - Use: "remove [config]", - Short: "Remove given parser(s)", - Long: `Remove given parse(s) from hub`, - Aliases: []string{"delete"}, - Example: `cscli parsers remove crowdsec/xxx crowdsec/xyz`, - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.PARSERS, args, toComplete) - }, - Run: func(cmd *cobra.Command, args []string) { - if all { - cwhub.RemoveMany(csConfig, cwhub.PARSERS, "", all, purge, forceAction) - return - } - - if len(args) == 0 { - log.Fatalf("Specify at least one parser to remove or '--all' flag.") - } - - for _, name := range args { - cwhub.RemoveMany(csConfig, cwhub.PARSERS, name, all, purge, forceAction) - } - }, - } - cmdParsersRemove.PersistentFlags().BoolVar(&purge, "purge", false, "Delete source file too") - cmdParsersRemove.PersistentFlags().BoolVar(&forceAction, "force", false, "Force remove : Remove tainted and outdated files") - cmdParsersRemove.PersistentFlags().BoolVar(&all, "all", false, "Delete all the parsers") - - return cmdParsersRemove -} - - -func NewParsersUpgradeCmd() *cobra.Command { - var cmdParsersUpgrade = &cobra.Command{ - Use: "upgrade [config]", - Short: "Upgrade given parser(s)", - Long: `Fetch and upgrade given parser(s) from hub`, - Example: `cscli parsers upgrade crowdsec/xxx crowdsec/xyz`, - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.PARSERS, args, toComplete) - }, - Run: func(cmd *cobra.Command, args []string) { - if all { - cwhub.UpgradeConfig(csConfig, cwhub.PARSERS, "", forceAction) - } else { - if len(args) == 0 { - log.Fatalf("no target parser to upgrade") - } - for _, name := range args { - cwhub.UpgradeConfig(csConfig, cwhub.PARSERS, name, forceAction) - } - } - }, - } - cmdParsersUpgrade.PersistentFlags().BoolVar(&all, "all", false, "Upgrade all the parsers") - cmdParsersUpgrade.PersistentFlags().BoolVar(&forceAction, "force", false, "Force upgrade : Overwrite tainted and outdated files") - - return cmdParsersUpgrade -} - - -func NewParsersInspectCmd() *cobra.Command { - var cmdParsersInspect = &cobra.Command{ - Use: "inspect [name]", - Short: "Inspect given parser", - Long: `Inspect given parser`, - Example: `cscli parsers inspect crowdsec/xxx`, - DisableAutoGenTag: true, - Args: cobra.MinimumNArgs(1), - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.PARSERS, args, toComplete) - }, - Run: func(cmd *cobra.Command, args []string) { - InspectItem(args[0], cwhub.PARSERS) - }, - } - cmdParsersInspect.PersistentFlags().StringVarP(&prometheusURL, "url", "u", "", "Prometheus url") - - return cmdParsersInspect -} - - -func NewParsersListCmd() *cobra.Command { - var cmdParsersList = &cobra.Command{ - Use: "list [name]", - Short: "List all parsers or given one", - Long: `List all parsers or given one`, - Example: `cscli parsers list -cscli parser list crowdsecurity/xxx`, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - ListItems(color.Output, []string{cwhub.PARSERS}, args, false, true, all) - }, - } - 
cmdParsersList.PersistentFlags().BoolVarP(&all, "all", "a", false, "List disabled items as well") - - return cmdParsersList -} diff --git a/cmd/crowdsec-cli/postoverflows.go b/cmd/crowdsec-cli/postoverflows.go deleted file mode 100644 index 19cffccd212..00000000000 --- a/cmd/crowdsec-cli/postoverflows.go +++ /dev/null @@ -1,200 +0,0 @@ -package main - -import ( - "fmt" - - "github.com/fatih/color" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/crowdsec/pkg/cwhub" -) - -func NewPostOverflowsInstallCmd() *cobra.Command { - var ignoreError bool - - cmdPostOverflowsInstall := &cobra.Command{ - Use: "install [config]", - Short: "Install given postoverflow(s)", - Long: `Fetch and install given postoverflow(s) from hub`, - Example: `cscli postoverflows install crowdsec/xxx crowdsec/xyz`, - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compAllItems(cwhub.PARSERS_OVFLW, args, toComplete) - }, - Run: func(cmd *cobra.Command, args []string) { - for _, name := range args { - t := cwhub.GetItem(cwhub.PARSERS_OVFLW, name) - if t == nil { - nearestItem, score := GetDistance(cwhub.PARSERS_OVFLW, name) - Suggest(cwhub.PARSERS_OVFLW, name, nearestItem.Name, score, ignoreError) - continue - } - if err := cwhub.InstallItem(csConfig, name, cwhub.PARSERS_OVFLW, forceAction, downloadOnly); err != nil { - if ignoreError { - log.Errorf("Error while installing '%s': %s", name, err) - } else { - log.Fatalf("Error while installing '%s': %s", name, err) - } - } - } - }, - } - - cmdPostOverflowsInstall.PersistentFlags().BoolVarP(&downloadOnly, "download-only", "d", false, "Only download packages, don't enable") - cmdPostOverflowsInstall.PersistentFlags().BoolVar(&forceAction, "force", false, "Force install : Overwrite tainted and outdated files") - cmdPostOverflowsInstall.PersistentFlags().BoolVar(&ignoreError, "ignore", false, "Ignore errors when installing multiple postoverflows") - - return cmdPostOverflowsInstall -} - -func NewPostOverflowsRemoveCmd() *cobra.Command { - cmdPostOverflowsRemove := &cobra.Command{ - Use: "remove [config]", - Short: "Remove given postoverflow(s)", - Long: `remove given postoverflow(s)`, - Example: `cscli postoverflows remove crowdsec/xxx crowdsec/xyz`, - DisableAutoGenTag: true, - Aliases: []string{"delete"}, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.PARSERS_OVFLW, args, toComplete) - }, - Run: func(cmd *cobra.Command, args []string) { - if all { - cwhub.RemoveMany(csConfig, cwhub.PARSERS_OVFLW, "", all, purge, forceAction) - return - } - - if len(args) == 0 { - log.Fatalf("Specify at least one postoverflow to remove or '--all' flag.") - } - - for _, name := range args { - cwhub.RemoveMany(csConfig, cwhub.PARSERS_OVFLW, name, all, purge, forceAction) - } - }, - } - - cmdPostOverflowsRemove.PersistentFlags().BoolVar(&purge, "purge", false, "Delete source file too") - cmdPostOverflowsRemove.PersistentFlags().BoolVar(&forceAction, "force", false, "Force remove : Remove tainted and outdated files") - cmdPostOverflowsRemove.PersistentFlags().BoolVar(&all, "all", false, "Delete all the postoverflows") - - return cmdPostOverflowsRemove -} - -func NewPostOverflowsUpgradeCmd() *cobra.Command { - cmdPostOverflowsUpgrade := &cobra.Command{ - Use: "upgrade [config]", - Short: "Upgrade given postoverflow(s)", 
- Long: `Fetch and Upgrade given postoverflow(s) from hub`, - Example: `cscli postoverflows upgrade crowdsec/xxx crowdsec/xyz`, - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.PARSERS_OVFLW, args, toComplete) - }, - Run: func(cmd *cobra.Command, args []string) { - if all { - cwhub.UpgradeConfig(csConfig, cwhub.PARSERS_OVFLW, "", forceAction) - } else { - if len(args) == 0 { - log.Fatalf("no target postoverflow to upgrade") - } - for _, name := range args { - cwhub.UpgradeConfig(csConfig, cwhub.PARSERS_OVFLW, name, forceAction) - } - } - }, - } - - cmdPostOverflowsUpgrade.PersistentFlags().BoolVarP(&all, "all", "a", false, "Upgrade all the postoverflows") - cmdPostOverflowsUpgrade.PersistentFlags().BoolVar(&forceAction, "force", false, "Force upgrade : Overwrite tainted and outdated files") - - return cmdPostOverflowsUpgrade -} - -func NewPostOverflowsInspectCmd() *cobra.Command { - cmdPostOverflowsInspect := &cobra.Command{ - Use: "inspect [config]", - Short: "Inspect given postoverflow", - Long: `Inspect given postoverflow`, - Example: `cscli postoverflows inspect crowdsec/xxx crowdsec/xyz`, - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.PARSERS_OVFLW, args, toComplete) - }, - Args: cobra.MinimumNArgs(1), - Run: func(cmd *cobra.Command, args []string) { - InspectItem(args[0], cwhub.PARSERS_OVFLW) - }, - } - - return cmdPostOverflowsInspect -} - -func NewPostOverflowsListCmd() *cobra.Command { - cmdPostOverflowsList := &cobra.Command{ - Use: "list [config]", - Short: "List all postoverflows or given one", - Long: `List all postoverflows or given one`, - Example: `cscli postoverflows list -cscli postoverflows list crowdsecurity/xxx`, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - ListItems(color.Output, []string{cwhub.PARSERS_OVFLW}, args, false, true, all) - }, - } - - cmdPostOverflowsList.PersistentFlags().BoolVarP(&all, "all", "a", false, "List disabled items as well") - - return cmdPostOverflowsList -} - - - -func NewPostOverflowsCmd() *cobra.Command { - cmdPostOverflows := &cobra.Command{ - Use: "postoverflows [action] [config]", - Short: "Install/Remove/Upgrade/Inspect postoverflow(s) from hub", - Example: `cscli postoverflows install crowdsecurity/cdn-whitelist - cscli postoverflows inspect crowdsecurity/cdn-whitelist - cscli postoverflows upgrade crowdsecurity/cdn-whitelist - cscli postoverflows list - cscli postoverflows remove crowdsecurity/cdn-whitelist`, - Args: cobra.MinimumNArgs(1), - Aliases: []string{"postoverflow"}, - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - if csConfig.Hub == nil { - return fmt.Errorf("you must configure cli before interacting with hub") - } - - if err := cwhub.SetHubBranch(); err != nil { - return fmt.Errorf("error while setting hub branch: %s", err) - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Info("Run 'sudo cscli hub update' to get the hub index") - log.Fatalf("Failed to get Hub index : %v", err) - } - return nil - }, - PersistentPostRun: func(cmd *cobra.Command, args []string) { - if cmd.Name() == "inspect" || cmd.Name() == "list" { - return - } - log.Infof(ReloadMessage()) - }, - } - - 
cmdPostOverflows.AddCommand(NewPostOverflowsInstallCmd()) - cmdPostOverflows.AddCommand(NewPostOverflowsRemoveCmd()) - cmdPostOverflows.AddCommand(NewPostOverflowsUpgradeCmd()) - cmdPostOverflows.AddCommand(NewPostOverflowsInspectCmd()) - cmdPostOverflows.AddCommand(NewPostOverflowsListCmd()) - - return cmdPostOverflows -} diff --git a/cmd/crowdsec-cli/reload/reload.go b/cmd/crowdsec-cli/reload/reload.go new file mode 100644 index 00000000000..fe03af1ea79 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload.go @@ -0,0 +1,6 @@ +//go:build !windows && !freebsd && !linux + +package reload + +// generic message since we don't know the platform +const Message = "Please reload the crowdsec process for the new configuration to be effective." diff --git a/cmd/crowdsec-cli/reload/reload_freebsd.go b/cmd/crowdsec-cli/reload/reload_freebsd.go new file mode 100644 index 00000000000..0dac99f2315 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload_freebsd.go @@ -0,0 +1,4 @@ +package reload + +// actually sudo is not that popular on freebsd, but this will do +const Message = "Run 'sudo service crowdsec reload' for the new configuration to be effective." diff --git a/cmd/crowdsec-cli/reload/reload_linux.go b/cmd/crowdsec-cli/reload/reload_linux.go new file mode 100644 index 00000000000..fbe16e5f168 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload_linux.go @@ -0,0 +1,4 @@ +package reload + +// assume systemd, although gentoo and others may differ +const Message = "Run 'sudo systemctl reload crowdsec' for the new configuration to be effective." diff --git a/cmd/crowdsec-cli/reload/reload_windows.go b/cmd/crowdsec-cli/reload/reload_windows.go new file mode 100644 index 00000000000..88642425ae2 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload_windows.go @@ -0,0 +1,3 @@ +package reload + +const Message = "Please restart the crowdsec service for the new configuration to be effective." 
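The four reload files added above replace a runtime helper with a package-level constant: the _freebsd, _linux and _windows filename suffixes, plus the negated //go:build line in reload.go, let the Go toolchain compile exactly one variant per target platform. A minimal sketch of a caller, assuming only the import path shown in the diff (the surrounding program is illustrative, not part of the PR):

```go
package main

import (
	"fmt"

	// exactly one reload_*.go variant is compiled in, chosen by GOOS build constraints
	"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload"
)

func main() {
	// after writing a new configuration, tell the user how to apply it:
	// on linux this prints the systemctl hint, on freebsd the service hint,
	// on windows the restart hint, and the generic message everywhere else
	fmt.Println(reload.Message)
}
```

Since the selection happens at compile time, there is no runtime OS switch to maintain, and the messages for other platforms are never linked into the binary.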
diff --git a/cmd/crowdsec-cli/require/branch.go b/cmd/crowdsec-cli/require/branch.go new file mode 100644 index 00000000000..09acc0fef8a --- /dev/null +++ b/cmd/crowdsec-cli/require/branch.go @@ -0,0 +1,108 @@ +package require + +// Set the appropriate hub branch according to config settings and crowdsec version + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/mod/semver" + + "github.com/crowdsecurity/go-cs-lib/version" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwversion" +) + +// lookupLatest returns the latest crowdsec version based on github +func lookupLatest(ctx context.Context) (string, error) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + url := "https://version.crowdsec.net/latest" + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", fmt.Errorf("unable to create request for %s: %w", url, err) + } + + client := &http.Client{} + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("unable to send request to %s: %w", url, err) + } + defer resp.Body.Close() + + latest := make(map[string]any) + + if err := json.NewDecoder(resp.Body).Decode(&latest); err != nil { + return "", fmt.Errorf("unable to decode response from %s: %w", url, err) + } + + if _, ok := latest["name"]; !ok { + return "", fmt.Errorf("unable to find 'name' key in response from %s", url) + } + + name, ok := latest["name"].(string) + if !ok { + return "", fmt.Errorf("unable to convert 'name' key to string in response from %s", url) + } + + return name, nil +} + +func chooseBranch(ctx context.Context, cfg *csconfig.Config) string { + // this was set from config.yaml or flag + if cfg.Cscli.HubBranch != "" { + log.Debugf("Hub override from config: branch '%s'", cfg.Cscli.HubBranch) + return cfg.Cscli.HubBranch + } + + latest, err := lookupLatest(ctx) + if err != nil { + log.Warningf("Unable to retrieve latest crowdsec version: %s, using hub branch 'master'", err) + return "master" + } + + csVersion := cwversion.VersionStrip() + if csVersion == "" { + log.Warning("Crowdsec version is not set, using hub branch 'master'") + return "master" + } + + if csVersion == latest { + log.Debugf("Latest crowdsec version (%s), using hub branch 'master'", version.String()) + return "master" + } + + // if current version is greater than the latest we are in pre-release + if semver.Compare(csVersion, latest) == 1 { + log.Debugf("Your current crowdsec version seems to be a pre-release (%s), using hub branch 'master'", version.String()) + return "master" + } + + log.Warnf("A new CrowdSec release is available (%s). "+ + "Your version is '%s'. 
Please update it to use new parsers/scenarios/collections.", + latest, csVersion) + + return csVersion +} + +// HubBranch sets the branch (in cscli config) and returns its value +// It can be "master", or the branch corresponding to the current crowdsec version, or the value overridden in config/flag +func HubBranch(ctx context.Context, cfg *csconfig.Config) string { + branch := chooseBranch(ctx, cfg) + + cfg.Cscli.HubBranch = branch + + return branch +} + +func HubURLTemplate(cfg *csconfig.Config) string { + return cfg.Cscli.HubURLTemplate +} diff --git a/cmd/crowdsec-cli/require/require.go b/cmd/crowdsec-cli/require/require.go new file mode 100644 index 00000000000..191eee55bc5 --- /dev/null +++ b/cmd/crowdsec-cli/require/require.go @@ -0,0 +1,123 @@ +package require + +import ( + "context" + "errors" + "fmt" + "io" + + "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/database" +) + +func LAPI(c *csconfig.Config) error { + if err := c.LoadAPIServer(true); err != nil { + return fmt.Errorf("failed to load Local API: %w", err) + } + + if c.DisableAPI { + return errors.New("local API is disabled -- this command must be run on the local API machine") + } + + return nil +} + +func CAPI(c *csconfig.Config) error { + if c.API.Server.OnlineClient == nil { + return fmt.Errorf("no configuration for Central API (CAPI) in '%s'", *c.FilePath) + } + + return nil +} + +func PAPI(c *csconfig.Config) error { + if err := CAPI(c); err != nil { + return err + } + + if err := CAPIRegistered(c); err != nil { + return err + } + + if c.API.Server.OnlineClient.Credentials.PapiURL == "" { + return errors.New("no PAPI URL in configuration") + } + + return nil +} + +func CAPIRegistered(c *csconfig.Config) error { + if c.API.Server.OnlineClient.Credentials == nil { + return errors.New("the Central API (CAPI) must be configured with 'cscli capi register'") + } + + return nil +} + +func DBClient(ctx context.Context, dbcfg *csconfig.DatabaseCfg) (*database.Client, error) { + db, err := database.NewClient(ctx, dbcfg) + if err != nil { + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + + return db, nil +} + +func DB(c *csconfig.Config) error { + if err := c.LoadDBConfig(true); err != nil { + return fmt.Errorf("this command requires direct database access (must be run on the local API machine): %w", err) + } + + return nil +} + +func Notifications(c *csconfig.Config) error { + if c.ConfigPaths.NotificationDir == "" { + return errors.New("config_paths.notification_dir is not set in crowdsec config") + } + + return nil +} + +// RemoteHub returns the configuration required to download hub index and items: url, branch, etc. +func RemoteHub(ctx context.Context, c *csconfig.Config) *cwhub.RemoteHubCfg { + // set branch in config, and log if necessary + branch := HubBranch(ctx, c) + urlTemplate := HubURLTemplate(c) + remote := &cwhub.RemoteHubCfg{ + Branch: branch, + URLTemplate: urlTemplate, + IndexPath: ".index.json", + } + + return remote +} + +// Hub initializes the hub. If a remote configuration is provided, it can be used to download the index and items. +// If no remote parameter is provided, the hub can only be used for local operations. 
+func Hub(c *csconfig.Config, remote *cwhub.RemoteHubCfg, logger *logrus.Logger) (*cwhub.Hub, error) { + local := c.Hub + + if local == nil { + return nil, errors.New("you must configure cli before interacting with hub") + } + + if logger == nil { + logger = logrus.New() + logger.SetOutput(io.Discard) + } + + hub, err := cwhub.NewHub(local, remote, logger) + if err != nil { + return nil, err + } + + if err := hub.Load(); err != nil { + return nil, fmt.Errorf("failed to read Hub index: %w. Run 'sudo cscli hub update' to download the index again", err) + } + + return hub, nil +} diff --git a/cmd/crowdsec-cli/scenarios.go b/cmd/crowdsec-cli/scenarios.go deleted file mode 100644 index de52dcb4876..00000000000 --- a/cmd/crowdsec-cli/scenarios.go +++ /dev/null @@ -1,197 +0,0 @@ -package main - -import ( - "fmt" - - "github.com/fatih/color" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/crowdsec/pkg/cwhub" -) - -func NewScenariosCmd() *cobra.Command { - var cmdScenarios = &cobra.Command{ - Use: "scenarios [action] [config]", - Short: "Install/Remove/Upgrade/Inspect scenario(s) from hub", - Example: `cscli scenarios list [-a] -cscli scenarios install crowdsecurity/ssh-bf -cscli scenarios inspect crowdsecurity/ssh-bf -cscli scenarios upgrade crowdsecurity/ssh-bf -cscli scenarios remove crowdsecurity/ssh-bf -`, - Args: cobra.MinimumNArgs(1), - Aliases: []string{"scenario"}, - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - if csConfig.Hub == nil { - return fmt.Errorf("you must configure cli before interacting with hub") - } - - if err := cwhub.SetHubBranch(); err != nil { - return fmt.Errorf("while setting hub branch: %w", err) - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Info("Run 'sudo cscli hub update' to get the hub index") - log.Fatalf("Failed to get Hub index : %v", err) - } - - return nil - }, - PersistentPostRun: func(cmd *cobra.Command, args []string) { - if cmd.Name() == "inspect" || cmd.Name() == "list" { - return - } - log.Infof(ReloadMessage()) - }, - } - - cmdScenarios.AddCommand(NewCmdScenariosInstall()) - cmdScenarios.AddCommand(NewCmdScenariosRemove()) - cmdScenarios.AddCommand(NewCmdScenariosUpgrade()) - cmdScenarios.AddCommand(NewCmdScenariosInspect()) - cmdScenarios.AddCommand(NewCmdScenariosList()) - - return cmdScenarios -} - -func NewCmdScenariosInstall() *cobra.Command { - var ignoreError bool - - var cmdScenariosInstall = &cobra.Command{ - Use: "install [config]", - Short: "Install given scenario(s)", - Long: `Fetch and install given scenario(s) from hub`, - Example: `cscli scenarios install crowdsec/xxx crowdsec/xyz`, - Args: cobra.MinimumNArgs(1), - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compAllItems(cwhub.SCENARIOS, args, toComplete) - }, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - for _, name := range args { - t := cwhub.GetItem(cwhub.SCENARIOS, name) - if t == nil { - nearestItem, score := GetDistance(cwhub.SCENARIOS, name) - Suggest(cwhub.SCENARIOS, name, nearestItem.Name, score, ignoreError) - continue - } - if err := cwhub.InstallItem(csConfig, name, cwhub.SCENARIOS, forceAction, downloadOnly); err != nil { - if ignoreError { - log.Errorf("Error while installing '%s': %s", name, err) - } else { - log.Fatalf("Error while installing '%s': %s", name, err) - } - } - } - }, 
- } - cmdScenariosInstall.PersistentFlags().BoolVarP(&downloadOnly, "download-only", "d", false, "Only download packages, don't enable") - cmdScenariosInstall.PersistentFlags().BoolVar(&forceAction, "force", false, "Force install : Overwrite tainted and outdated files") - cmdScenariosInstall.PersistentFlags().BoolVar(&ignoreError, "ignore", false, "Ignore errors when installing multiple scenarios") - - return cmdScenariosInstall -} - -func NewCmdScenariosRemove() *cobra.Command { - var cmdScenariosRemove = &cobra.Command{ - Use: "remove [config]", - Short: "Remove given scenario(s)", - Long: `remove given scenario(s)`, - Example: `cscli scenarios remove crowdsec/xxx crowdsec/xyz`, - Aliases: []string{"delete"}, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.SCENARIOS, args, toComplete) - }, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if all { - cwhub.RemoveMany(csConfig, cwhub.SCENARIOS, "", all, purge, forceAction) - return - } - - if len(args) == 0 { - log.Fatalf("Specify at least one scenario to remove or '--all' flag.") - } - - for _, name := range args { - cwhub.RemoveMany(csConfig, cwhub.SCENARIOS, name, all, purge, forceAction) - } - }, - } - cmdScenariosRemove.PersistentFlags().BoolVar(&purge, "purge", false, "Delete source file too") - cmdScenariosRemove.PersistentFlags().BoolVar(&forceAction, "force", false, "Force remove : Remove tainted and outdated files") - cmdScenariosRemove.PersistentFlags().BoolVar(&all, "all", false, "Delete all the scenarios") - - return cmdScenariosRemove -} - -func NewCmdScenariosUpgrade() *cobra.Command { - var cmdScenariosUpgrade = &cobra.Command{ - Use: "upgrade [config]", - Short: "Upgrade given scenario(s)", - Long: `Fetch and Upgrade given scenario(s) from hub`, - Example: `cscli scenarios upgrade crowdsec/xxx crowdsec/xyz`, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.SCENARIOS, args, toComplete) - }, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if all { - cwhub.UpgradeConfig(csConfig, cwhub.SCENARIOS, "", forceAction) - } else { - if len(args) == 0 { - log.Fatalf("no target scenario to upgrade") - } - for _, name := range args { - cwhub.UpgradeConfig(csConfig, cwhub.SCENARIOS, name, forceAction) - } - } - }, - } - cmdScenariosUpgrade.PersistentFlags().BoolVarP(&all, "all", "a", false, "Upgrade all the scenarios") - cmdScenariosUpgrade.PersistentFlags().BoolVar(&forceAction, "force", false, "Force upgrade : Overwrite tainted and outdated files") - - return cmdScenariosUpgrade -} - -func NewCmdScenariosInspect() *cobra.Command { - var cmdScenariosInspect = &cobra.Command{ - Use: "inspect [config]", - Short: "Inspect given scenario", - Long: `Inspect given scenario`, - Example: `cscli scenarios inspect crowdsec/xxx`, - Args: cobra.MinimumNArgs(1), - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cwhub.SCENARIOS, args, toComplete) - }, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - InspectItem(args[0], cwhub.SCENARIOS) - }, - } - cmdScenariosInspect.PersistentFlags().StringVarP(&prometheusURL, "url", "u", "", "Prometheus url") - - return cmdScenariosInspect -} - -func NewCmdScenariosList() *cobra.Command { - var cmdScenariosList = &cobra.Command{ 
- Use: "list [config]", - Short: "List all scenario(s) or given one", - Long: `List all scenario(s) or given one`, - Example: `cscli scenarios list -cscli scenarios list crowdsecurity/xxx`, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - ListItems(color.Output, []string{cwhub.SCENARIOS}, args, false, true, all) - }, - } - cmdScenariosList.PersistentFlags().BoolVarP(&all, "all", "a", false, "List disabled items as well") - - return cmdScenariosList -} diff --git a/cmd/crowdsec-cli/setup.go b/cmd/crowdsec-cli/setup.go index 7f1da4c4456..66c0d71e777 100644 --- a/cmd/crowdsec-cli/setup.go +++ b/cmd/crowdsec-cli/setup.go @@ -1,312 +1,18 @@ +//go:build !no_cscli_setup package main import ( - "bytes" - "fmt" - "os" - "os/exec" - - log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/yaml.v3" - goccyyaml "github.com/goccy/go-yaml" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/setup" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clisetup" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/component" + "github.com/crowdsecurity/crowdsec/pkg/fflag" ) -// NewSetupCmd defines the "cscli setup" command. -func NewSetupCmd() *cobra.Command { - cmdSetup := &cobra.Command{ - Use: "setup", - Short: "Tools to configure crowdsec", - Long: "Manage hub configuration and service detection", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - } - - // - // cscli setup detect - // - { - cmdSetupDetect := &cobra.Command{ - Use: "detect", - Short: "detect running services, generate a setup file", - DisableAutoGenTag: true, - RunE: runSetupDetect, - } - - defaultServiceDetect := csconfig.DefaultConfigPath("hub", "detect.yaml") - - flags := cmdSetupDetect.Flags() - flags.String("detect-config", defaultServiceDetect, "path to service detection configuration") - flags.Bool("list-supported-services", false, "do not detect; only print supported services") - flags.StringSlice("force-unit", nil, "force detection of a systemd unit (can be repeated)") - flags.StringSlice("force-process", nil, "force detection of a running process (can be repeated)") - flags.StringSlice("skip-service", nil, "ignore a service, don't recommend hub/datasources (can be repeated)") - flags.String("force-os-family", "", "override OS.Family: one of linux, freebsd, windows or darwin") - flags.String("force-os-id", "", "override OS.ID=[debian | ubuntu | , redhat...]") - flags.String("force-os-version", "", "override OS.RawVersion (of OS or Linux distribution)") - flags.Bool("snub-systemd", false, "don't use systemd, even if available") - flags.Bool("yaml", false, "output yaml, not json") - cmdSetup.AddCommand(cmdSetupDetect) - } - - // - // cscli setup install-hub - // - { - cmdSetupInstallHub := &cobra.Command{ - Use: "install-hub [setup_file] [flags]", - Short: "install items from a setup file", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: runSetupInstallHub, - } - - flags := cmdSetupInstallHub.Flags() - flags.Bool("dry-run", false, "don't install anything; print out what would have been") - cmdSetup.AddCommand(cmdSetupInstallHub) - } - - // - // cscli setup datasources - // - { - cmdSetupDataSources := &cobra.Command{ - Use: "datasources [setup_file] [flags]", - Short: "generate datasource (acquisition) configuration from a setup file", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: runSetupDataSources, - } - - flags := cmdSetupDataSources.Flags() - flags.String("to-dir", "", "write the configuration to a 
directory, in multiple files") - cmdSetup.AddCommand(cmdSetupDataSources) - } - - // - // cscli setup validate - // - { - cmdSetupValidate := &cobra.Command{ - Use: "validate [setup_file]", - Short: "validate a setup file", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: runSetupValidate, - } - - cmdSetup.AddCommand(cmdSetupValidate) - } - - return cmdSetup -} - -func runSetupDetect(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - detectConfigFile, err := flags.GetString("detect-config") - if err != nil { - return err - } - - listSupportedServices, err := flags.GetBool("list-supported-services") - if err != nil { - return err - } - - forcedUnits, err := flags.GetStringSlice("force-unit") - if err != nil { - return err - } - - forcedProcesses, err := flags.GetStringSlice("force-process") - if err != nil { - return err - } - - forcedOSFamily, err := flags.GetString("force-os-family") - if err != nil { - return err - } - - forcedOSID, err := flags.GetString("force-os-id") - if err != nil { - return err - } - - forcedOSVersion, err := flags.GetString("force-os-version") - if err != nil { - return err - } - - skipServices, err := flags.GetStringSlice("skip-service") - if err != nil { - return err - } - - snubSystemd, err := flags.GetBool("snub-systemd") - if err != nil { - return err - } - - if !snubSystemd { - _, err := exec.LookPath("systemctl") - if err != nil { - log.Debug("systemctl not available: snubbing systemd") - snubSystemd = true - } - } - - outYaml, err := flags.GetBool("yaml") - if err != nil { - return err - } - - if forcedOSFamily == "" && forcedOSID != "" { - log.Debug("force-os-id is set: force-os-family defaults to 'linux'") - forcedOSFamily = "linux" - } - - if listSupportedServices { - supported, err := setup.ListSupported(detectConfigFile) - if err != nil { - return err - } - - for _, svc := range supported { - fmt.Println(svc) - } - - return nil - } - - opts := setup.DetectOptions{ - ForcedUnits: forcedUnits, - ForcedProcesses: forcedProcesses, - ForcedOS: setup.ExprOS{ - Family: forcedOSFamily, - ID: forcedOSID, - RawVersion: forcedOSVersion, - }, - SkipServices: skipServices, - SnubSystemd: snubSystemd, - } - - hubSetup, err := setup.Detect(detectConfigFile, opts) - if err != nil { - return fmt.Errorf("detecting services: %w", err) - } - - setup, err := setupAsString(hubSetup, outYaml) - if err != nil { - return err - } - fmt.Println(setup) - - return nil -} - -func setupAsString(cs setup.Setup, outYaml bool) (string, error) { - var ( - ret []byte - err error - ) - - wrap := func(err error) error { - return fmt.Errorf("while marshaling setup: %w", err) - } - - indentLevel := 2 - buf := &bytes.Buffer{} - enc := yaml.NewEncoder(buf) - enc.SetIndent(indentLevel) - - if err = enc.Encode(cs); err != nil { - return "", wrap(err) - } - - if err = enc.Close(); err != nil { - return "", wrap(err) - } - - ret = buf.Bytes() - - if !outYaml { - // take a general approach to output json, so we avoid the - // double tags in the structures and can use go-yaml features - // missing from the json package - ret, err = goccyyaml.YAMLToJSON(ret) - if err != nil { - return "", wrap(err) - } - } - - return string(ret), nil -} - -func runSetupDataSources(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - fromFile := args[0] - - toDir, err := flags.GetString("to-dir") - if err != nil { - return err - } - - input, err := os.ReadFile(fromFile) - if err != nil { - return fmt.Errorf("while reading setup file: %w", err) - } - - output, err := 
setup.DataSources(input, toDir) - if err != nil { - return err - } - - if toDir == "" { - fmt.Println(output) - } - - return nil -} - -func runSetupInstallHub(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - fromFile := args[0] - - dryRun, err := flags.GetBool("dry-run") - if err != nil { - return err - } - - input, err := os.ReadFile(fromFile) - if err != nil { - return fmt.Errorf("while reading file %s: %w", fromFile, err) - } - - if err = setup.InstallHubItems(csConfig, input, dryRun); err != nil { - return err - } - - return nil -} - -func runSetupValidate(cmd *cobra.Command, args []string) error { - fromFile := args[0] - input, err := os.ReadFile(fromFile) - if err != nil { - return fmt.Errorf("while reading setup file: %w", err) - } - - if err = setup.Validate(input); err != nil { - fmt.Printf("%v\n", err) - return fmt.Errorf("invalid setup file") +func (cli *cliRoot) addSetup(cmd *cobra.Command) { + if fflag.CscliSetup.IsEnabled() { + cmd.AddCommand(clisetup.New(cli.cfg).NewCommand()) } - return nil + component.Register("cscli_setup") } diff --git a/cmd/crowdsec-cli/setup_stub.go b/cmd/crowdsec-cli/setup_stub.go new file mode 100644 index 00000000000..e001f93c797 --- /dev/null +++ b/cmd/crowdsec-cli/setup_stub.go @@ -0,0 +1,9 @@ +//go:build no_cscli_setup +package main + +import ( + "github.com/spf13/cobra" +) + +func (cli *cliRoot) addSetup(_ *cobra.Command) { +} diff --git a/cmd/crowdsec-cli/simulation.go b/cmd/crowdsec-cli/simulation.go deleted file mode 100644 index db499e38092..00000000000 --- a/cmd/crowdsec-cli/simulation.go +++ /dev/null @@ -1,268 +0,0 @@ -package main - -import ( - "fmt" - "os" - - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "golang.org/x/exp/slices" - "gopkg.in/yaml.v2" - - "github.com/crowdsecurity/crowdsec/pkg/cwhub" -) - -func addToExclusion(name string) error { - csConfig.Cscli.SimulationConfig.Exclusions = append(csConfig.Cscli.SimulationConfig.Exclusions, name) - return nil -} - -func removeFromExclusion(name string) error { - index := indexOf(name, csConfig.Cscli.SimulationConfig.Exclusions) - - // Remove element from the slice: swap it with the last element, then truncate (order is not preserved) - csConfig.Cscli.SimulationConfig.Exclusions[index] = csConfig.Cscli.SimulationConfig.Exclusions[len(csConfig.Cscli.SimulationConfig.Exclusions)-1] - csConfig.Cscli.SimulationConfig.Exclusions[len(csConfig.Cscli.SimulationConfig.Exclusions)-1] = "" - csConfig.Cscli.SimulationConfig.Exclusions = csConfig.Cscli.SimulationConfig.Exclusions[:len(csConfig.Cscli.SimulationConfig.Exclusions)-1] - - return nil -} - -func enableGlobalSimulation() error { - csConfig.Cscli.SimulationConfig.Simulation = new(bool) - *csConfig.Cscli.SimulationConfig.Simulation = true - csConfig.Cscli.SimulationConfig.Exclusions = []string{} - - if err := dumpSimulationFile(); err != nil { - log.Fatalf("unable to dump simulation file: %s", err) - } - - log.Printf("global simulation: enabled") - - return nil -} - -func dumpSimulationFile() error { - newConfigSim, err := yaml.Marshal(csConfig.Cscli.SimulationConfig) - if err != nil { - return fmt.Errorf("unable to marshal simulation configuration: %s", err) - } - err = os.WriteFile(csConfig.ConfigPaths.SimulationFilePath, newConfigSim, 0644) - if err != nil { - return fmt.Errorf("failed to write simulation config in '%s': %s", csConfig.ConfigPaths.SimulationFilePath, err) - } - log.Debugf("updated simulation file %s", csConfig.ConfigPaths.SimulationFilePath) - - return nil -} - -func disableGlobalSimulation() error { - csConfig.Cscli.SimulationConfig.Simulation = new(bool) - 
*csConfig.Cscli.SimulationConfig.Simulation = false - - csConfig.Cscli.SimulationConfig.Exclusions = []string{} - newConfigSim, err := yaml.Marshal(csConfig.Cscli.SimulationConfig) - if err != nil { - return fmt.Errorf("unable to marshal new simulation configuration: %s", err) - } - err = os.WriteFile(csConfig.ConfigPaths.SimulationFilePath, newConfigSim, 0644) - if err != nil { - return fmt.Errorf("unable to write new simulation config in '%s': %s", csConfig.ConfigPaths.SimulationFilePath, err) - } - - log.Printf("global simulation: disabled") - return nil -} - -func simulationStatus() error { - if csConfig.Cscli.SimulationConfig == nil { - log.Printf("global simulation: disabled (configuration file is missing)") - return nil - } - if *csConfig.Cscli.SimulationConfig.Simulation { - log.Println("global simulation: enabled") - if len(csConfig.Cscli.SimulationConfig.Exclusions) > 0 { - log.Println("Scenarios not in simulation mode:") - for _, scenario := range csConfig.Cscli.SimulationConfig.Exclusions { - log.Printf(" - %s", scenario) - } - } - } else { - log.Println("global simulation: disabled") - if len(csConfig.Cscli.SimulationConfig.Exclusions) > 0 { - log.Println("Scenarios in simulation mode:") - for _, scenario := range csConfig.Cscli.SimulationConfig.Exclusions { - log.Printf(" - %s", scenario) - } - } - } - return nil -} - -func NewSimulationCmds() *cobra.Command { - var cmdSimulation = &cobra.Command{ - Use: "simulation [command]", - Short: "Manage simulation status of scenarios", - Example: `cscli simulation status -cscli simulation enable crowdsecurity/ssh-bf -cscli simulation disable crowdsecurity/ssh-bf`, - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadSimulation(); err != nil { - log.Fatal(err) - } - if csConfig.Cscli == nil { - return fmt.Errorf("you must configure cli before using simulation") - } - if csConfig.Cscli.SimulationConfig == nil { - return fmt.Errorf("no simulation configured") - } - return nil - }, - PersistentPostRun: func(cmd *cobra.Command, args []string) { - if cmd.Name() != "status" { - log.Infof(ReloadMessage()) - } - }, - } - cmdSimulation.Flags().SortFlags = false - cmdSimulation.PersistentFlags().SortFlags = false - - cmdSimulation.AddCommand(NewSimulationEnableCmd()) - cmdSimulation.AddCommand(NewSimulationDisableCmd()) - cmdSimulation.AddCommand(NewSimulationStatusCmd()) - - return cmdSimulation -} - -func NewSimulationEnableCmd() *cobra.Command { - var forceGlobalSimulation bool - - var cmdSimulationEnable = &cobra.Command{ - Use: "enable [scenario] [--global]", - Short: "Enable the simulation, globally or on specified scenarios", - Example: `cscli simulation enable`, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - log.Info("Run 'sudo cscli hub update' to get the hub index") - log.Fatalf("Failed to get Hub index: %v", err) - } - - if len(args) > 0 { - for _, scenario := range args { - var item = cwhub.GetItem(cwhub.SCENARIOS, scenario) - if item == nil { - log.Errorf("'%s' doesn't exist or is not a scenario", scenario) - continue - } - if !item.Installed { - log.Warningf("'%s' isn't enabled", scenario) - } - isExcluded := slices.Contains(csConfig.Cscli.SimulationConfig.Exclusions, scenario) - if *csConfig.Cscli.SimulationConfig.Simulation && !isExcluded { - log.Warning("global simulation is already enabled") - continue - } - 
if !*csConfig.Cscli.SimulationConfig.Simulation && isExcluded { - log.Warningf("simulation for '%s' already enabled", scenario) - continue - } - if *csConfig.Cscli.SimulationConfig.Simulation && isExcluded { - if err := removeFromExclusion(scenario); err != nil { - log.Fatal(err) - } - log.Printf("simulation enabled for '%s'", scenario) - continue - } - if err := addToExclusion(scenario); err != nil { - log.Fatal(err) - } - log.Printf("simulation mode for '%s' enabled", scenario) - } - if err := dumpSimulationFile(); err != nil { - log.Fatalf("simulation enable: %s", err) - } - } else if forceGlobalSimulation { - if err := enableGlobalSimulation(); err != nil { - log.Fatalf("unable to enable global simulation mode: %s", err) - } - } else { - printHelp(cmd) - } - }, - } - cmdSimulationEnable.Flags().BoolVarP(&forceGlobalSimulation, "global", "g", false, "Enable global simulation (reverse mode)") - - return cmdSimulationEnable -} - -func NewSimulationDisableCmd() *cobra.Command { - var forceGlobalSimulation bool - - var cmdSimulationDisable = &cobra.Command{ - Use: "disable [scenario]", - Short: "Disable the simulation mode. Disable only specified scenarios", - Example: `cscli simulation disable`, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if len(args) > 0 { - for _, scenario := range args { - isExcluded := slices.Contains(csConfig.Cscli.SimulationConfig.Exclusions, scenario) - if !*csConfig.Cscli.SimulationConfig.Simulation && !isExcluded { - log.Warningf("%s isn't in simulation mode", scenario) - continue - } - if !*csConfig.Cscli.SimulationConfig.Simulation && isExcluded { - if err := removeFromExclusion(scenario); err != nil { - log.Fatal(err) - } - log.Printf("simulation mode for '%s' disabled", scenario) - continue - } - if isExcluded { - log.Warningf("simulation mode is enabled but is already disabled for '%s'", scenario) - continue - } - if err := addToExclusion(scenario); err != nil { - log.Fatal(err) - } - log.Printf("simulation mode for '%s' disabled", scenario) - } - if err := dumpSimulationFile(); err != nil { - log.Fatalf("simulation disable: %s", err) - } - } else if forceGlobalSimulation { - if err := disableGlobalSimulation(); err != nil { - log.Fatalf("unable to disable global simulation mode: %s", err) - } - } else { - printHelp(cmd) - } - }, - } - cmdSimulationDisable.Flags().BoolVarP(&forceGlobalSimulation, "global", "g", false, "Disable global simulation (reverse mode)") - - return cmdSimulationDisable -} - -func NewSimulationStatusCmd() *cobra.Command { - var cmdSimulationStatus = &cobra.Command{ - Use: "status", - Short: "Show simulation mode status", - Example: `cscli simulation status`, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if err := simulationStatus(); err != nil { - log.Fatal(err) - } - }, - PersistentPostRun: func(cmd *cobra.Command, args []string) { - }, - } - - return cmdSimulationStatus -} diff --git a/cmd/crowdsec-cli/support.go b/cmd/crowdsec-cli/support.go deleted file mode 100644 index 66c1493a4b6..00000000000 --- a/cmd/crowdsec-cli/support.go +++ /dev/null @@ -1,430 +0,0 @@ -package main - -import ( - "archive/zip" - "bytes" - "context" - "fmt" - "io" - "net/http" - "net/url" - "os" - "path/filepath" - "regexp" - "strings" - - "github.com/blackfireio/osinfo" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/go-cs-lib/pkg/version" - - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - 
"github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" - "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/fflag" - "github.com/crowdsecurity/crowdsec/pkg/models" -) - -const ( - SUPPORT_METRICS_HUMAN_PATH = "metrics/metrics.human" - SUPPORT_METRICS_PROMETHEUS_PATH = "metrics/metrics.prometheus" - SUPPORT_VERSION_PATH = "version.txt" - SUPPORT_FEATURES_PATH = "features.txt" - SUPPORT_OS_INFO_PATH = "osinfo.txt" - SUPPORT_PARSERS_PATH = "hub/parsers.txt" - SUPPORT_SCENARIOS_PATH = "hub/scenarios.txt" - SUPPORT_COLLECTIONS_PATH = "hub/collections.txt" - SUPPORT_POSTOVERFLOWS_PATH = "hub/postoverflows.txt" - SUPPORT_BOUNCERS_PATH = "lapi/bouncers.txt" - SUPPORT_AGENTS_PATH = "lapi/agents.txt" - SUPPORT_CROWDSEC_CONFIG_PATH = "config/crowdsec.yaml" - SUPPORT_LAPI_STATUS_PATH = "lapi_status.txt" - SUPPORT_CAPI_STATUS_PATH = "capi_status.txt" - SUPPORT_ACQUISITION_CONFIG_BASE_PATH = "config/acquis/" - SUPPORT_CROWDSEC_PROFILE_PATH = "config/profiles.yaml" -) - -// from https://github.com/acarl005/stripansi -var reStripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") - -func stripAnsiString(str string) string { - // the byte version doesn't strip correctly - return reStripAnsi.ReplaceAllString(str, "") -} - -func collectMetrics() ([]byte, []byte, error) { - log.Info("Collecting prometheus metrics") - err := csConfig.LoadPrometheus() - if err != nil { - return nil, nil, err - } - - if csConfig.Cscli.PrometheusUrl == "" { - log.Warn("No Prometheus URL configured, metrics will not be collected") - return nil, nil, fmt.Errorf("prometheus_uri is not set") - } - - humanMetrics := bytes.NewBuffer(nil) - err = FormatPrometheusMetrics(humanMetrics, csConfig.Cscli.PrometheusUrl+"/metrics", "human") - - if err != nil { - return nil, nil, fmt.Errorf("could not fetch promtheus metrics: %s", err) - } - - req, err := http.NewRequest(http.MethodGet, csConfig.Cscli.PrometheusUrl+"/metrics", nil) - if err != nil { - return nil, nil, fmt.Errorf("could not create requests to prometheus endpoint: %s", err) - } - client := &http.Client{} - resp, err := client.Do(req) - - if err != nil { - return nil, nil, fmt.Errorf("could not get metrics from prometheus endpoint: %s", err) - } - - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("could not read metrics from prometheus endpoint: %s", err) - } - - return humanMetrics.Bytes(), body, nil -} - -func collectVersion() []byte { - log.Info("Collecting version") - return []byte(cwversion.ShowStr()) -} - -func collectFeatures() []byte { - log.Info("Collecting feature flags") - enabledFeatures := fflag.Crowdsec.GetEnabledFeatures() - - w := bytes.NewBuffer(nil) - for _, k := range enabledFeatures { - fmt.Fprintf(w, "%s\n", k) - } - return w.Bytes() -} - -func collectOSInfo() ([]byte, error) { - log.Info("Collecting OS info") - info, err := osinfo.GetOSInfo() - - if err != nil { - return nil, err - } - - w := bytes.NewBuffer(nil) - w.WriteString(fmt.Sprintf("Architecture: %s\n", info.Architecture)) - w.WriteString(fmt.Sprintf("Family: %s\n", info.Family)) - w.WriteString(fmt.Sprintf("ID: %s\n", info.ID)) - w.WriteString(fmt.Sprintf("Name: %s\n", info.Name)) - w.WriteString(fmt.Sprintf("Codename: %s\n", info.Codename)) - w.WriteString(fmt.Sprintf("Version: %s\n", info.Version)) - w.WriteString(fmt.Sprintf("Build: %s\n", 
info.Build)) - - return w.Bytes(), nil -} - -func initHub() error { - if err := csConfig.LoadHub(); err != nil { - return fmt.Errorf("cannot load hub: %s", err) - } - if csConfig.Hub == nil { - return fmt.Errorf("hub not configured") - } - - if err := cwhub.SetHubBranch(); err != nil { - return fmt.Errorf("cannot set hub branch: %s", err) - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - return fmt.Errorf("no hub index found: %s", err) - } - return nil -} - -func collectHubItems(itemType string) []byte { - out := bytes.NewBuffer(nil) - log.Infof("Collecting %s list", itemType) - ListItems(out, []string{itemType}, []string{}, false, true, all) - return out.Bytes() -} - -func collectBouncers(dbClient *database.Client) ([]byte, error) { - out := bytes.NewBuffer(nil) - err := getBouncers(out, dbClient) - if err != nil { - return nil, err - } - return out.Bytes(), nil -} - -func collectAgents(dbClient *database.Client) ([]byte, error) { - out := bytes.NewBuffer(nil) - err := getAgents(out, dbClient) - if err != nil { - return nil, err - } - return out.Bytes(), nil -} - -func collectAPIStatus(login string, password string, endpoint string, prefix string) []byte { - if csConfig.API.Client == nil || csConfig.API.Client.Credentials == nil { - return []byte("No agent credentials found, are we LAPI?") - } - pwd := strfmt.Password(password) - apiurl, err := url.Parse(endpoint) - - if err != nil { - return []byte(fmt.Sprintf("cannot parse API URL: %s", err)) - } - scenarios, err := cwhub.GetInstalledScenariosAsString() - if err != nil { - return []byte(fmt.Sprintf("could not collect scenarios: %s", err)) - } - - Client, err = apiclient.NewDefaultClient(apiurl, - prefix, - fmt.Sprintf("crowdsec/%s", version.String()), - nil) - if err != nil { - return []byte(fmt.Sprintf("could not init client: %s", err)) - } - t := models.WatcherAuthRequest{ - MachineID: &login, - Password: &pwd, - Scenarios: scenarios, - } - - _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) - if err != nil { - return []byte(fmt.Sprintf("Could not authenticate to API: %s", err)) - } else { - return []byte("Successfully authenticated to LAPI") - } -} - -func collectCrowdsecConfig() []byte { - log.Info("Collecting crowdsec config") - config, err := os.ReadFile(*csConfig.FilePath) - if err != nil { - return []byte(fmt.Sprintf("could not read config file: %s", err)) - } - - r := regexp.MustCompile(`(\s+password:|\s+user:|\s+host:)\s+.*`) - - return r.ReplaceAll(config, []byte("$1 ****REDACTED****")) -} - -func collectCrowdsecProfile() []byte { - log.Info("Collecting crowdsec profile") - config, err := os.ReadFile(csConfig.API.Server.ProfilesPath) - if err != nil { - return []byte(fmt.Sprintf("could not read profile file: %s", err)) - } - return config -} - -func collectAcquisitionConfig() map[string][]byte { - log.Info("Collecting acquisition config") - ret := make(map[string][]byte) - - for _, filename := range csConfig.Crowdsec.AcquisitionFiles { - fileContent, err := os.ReadFile(filename) - if err != nil { - ret[filename] = []byte(fmt.Sprintf("could not read file: %s", err)) - } else { - ret[filename] = fileContent - } - } - - return ret -} - -func NewSupportCmd() *cobra.Command { - var cmdSupport = &cobra.Command{ - Use: "support [action]", - Short: "Provide commands to help during support", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - return nil - }, - } - - var outFile string - - cmdDump := &cobra.Command{ - Use: 
"dump", - Short: "Dump all your configuration to a zip file for easier support", - Long: `Dump the following informations: -- Crowdsec version -- OS version -- Installed collections list -- Installed parsers list -- Installed scenarios list -- Installed postoverflows list -- Bouncers list -- Machines list -- CAPI status -- LAPI status -- Crowdsec config (sensitive information like username and password are redacted) -- Crowdsec metrics`, - Example: `cscli support dump -cscli support dump -f /tmp/crowdsec-support.zip -`, - Args: cobra.NoArgs, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - var err error - var skipHub, skipDB, skipCAPI, skipLAPI, skipAgent bool - infos := map[string][]byte{ - SUPPORT_VERSION_PATH: collectVersion(), - SUPPORT_FEATURES_PATH: collectFeatures(), - } - - if outFile == "" { - outFile = "/tmp/crowdsec-support.zip" - } - - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - log.Warnf("Could not connect to database: %s", err) - skipDB = true - infos[SUPPORT_BOUNCERS_PATH] = []byte(err.Error()) - infos[SUPPORT_AGENTS_PATH] = []byte(err.Error()) - } - - if err := csConfig.LoadAPIServer(); err != nil { - log.Warnf("could not load LAPI, skipping CAPI check") - skipLAPI = true - infos[SUPPORT_CAPI_STATUS_PATH] = []byte(err.Error()) - } - - if err := csConfig.LoadCrowdsec(); err != nil { - log.Warnf("could not load agent config, skipping crowdsec config check") - skipAgent = true - } - - err = initHub() - if err != nil { - log.Warn("Could not init hub, running on LAPI ? Hub related information will not be collected") - skipHub = true - infos[SUPPORT_PARSERS_PATH] = []byte(err.Error()) - infos[SUPPORT_SCENARIOS_PATH] = []byte(err.Error()) - infos[SUPPORT_POSTOVERFLOWS_PATH] = []byte(err.Error()) - infos[SUPPORT_COLLECTIONS_PATH] = []byte(err.Error()) - } - - if csConfig.API.Client == nil || csConfig.API.Client.Credentials == nil { - log.Warn("no agent credentials found, skipping LAPI connectivity check") - if _, ok := infos[SUPPORT_LAPI_STATUS_PATH]; ok { - infos[SUPPORT_LAPI_STATUS_PATH] = append(infos[SUPPORT_LAPI_STATUS_PATH], []byte("\nNo LAPI credentials found")...) 
- } - skipLAPI = true - } - - if csConfig.API.Server == nil || csConfig.API.Server.OnlineClient == nil || csConfig.API.Server.OnlineClient.Credentials == nil { - log.Warn("no CAPI credentials found, skipping CAPI connectivity check") - skipCAPI = true - } - - infos[SUPPORT_METRICS_HUMAN_PATH], infos[SUPPORT_METRICS_PROMETHEUS_PATH], err = collectMetrics() - if err != nil { - log.Warnf("could not collect prometheus metrics information: %s", err) - infos[SUPPORT_METRICS_HUMAN_PATH] = []byte(err.Error()) - infos[SUPPORT_METRICS_PROMETHEUS_PATH] = []byte(err.Error()) - } - - infos[SUPPORT_OS_INFO_PATH], err = collectOSInfo() - if err != nil { - log.Warnf("could not collect OS information: %s", err) - infos[SUPPORT_OS_INFO_PATH] = []byte(err.Error()) - } - - infos[SUPPORT_CROWDSEC_CONFIG_PATH] = collectCrowdsecConfig() - - if !skipHub { - infos[SUPPORT_PARSERS_PATH] = collectHubItems(cwhub.PARSERS) - infos[SUPPORT_SCENARIOS_PATH] = collectHubItems(cwhub.SCENARIOS) - infos[SUPPORT_POSTOVERFLOWS_PATH] = collectHubItems(cwhub.PARSERS_OVFLW) - infos[SUPPORT_COLLECTIONS_PATH] = collectHubItems(cwhub.COLLECTIONS) - } - - if !skipDB { - infos[SUPPORT_BOUNCERS_PATH], err = collectBouncers(dbClient) - if err != nil { - log.Warnf("could not collect bouncers information: %s", err) - infos[SUPPORT_BOUNCERS_PATH] = []byte(err.Error()) - } - - infos[SUPPORT_AGENTS_PATH], err = collectAgents(dbClient) - if err != nil { - log.Warnf("could not collect agents information: %s", err) - infos[SUPPORT_AGENTS_PATH] = []byte(err.Error()) - } - } - - if !skipCAPI { - log.Info("Collecting CAPI status") - infos[SUPPORT_CAPI_STATUS_PATH] = collectAPIStatus(csConfig.API.Server.OnlineClient.Credentials.Login, - csConfig.API.Server.OnlineClient.Credentials.Password, - csConfig.API.Server.OnlineClient.Credentials.URL, - CAPIURLPrefix) - } - - if !skipLAPI { - log.Info("Collecting LAPI status") - infos[SUPPORT_LAPI_STATUS_PATH] = collectAPIStatus(csConfig.API.Client.Credentials.Login, - csConfig.API.Client.Credentials.Password, - csConfig.API.Client.Credentials.URL, - LAPIURLPrefix) - infos[SUPPORT_CROWDSEC_PROFILE_PATH] = collectCrowdsecProfile() - } - - if !skipAgent { - - acquis := collectAcquisitionConfig() - - for filename, content := range acquis { - fname := strings.ReplaceAll(filename, string(filepath.Separator), "___") - infos[SUPPORT_ACQUISITION_CONFIG_BASE_PATH+fname] = content - } - } - - w := bytes.NewBuffer(nil) - zipWriter := zip.NewWriter(w) - - for filename, data := range infos { - fw, err := zipWriter.Create(filename) - if err != nil { - log.Errorf("Could not add zip entry for %s: %s", filename, err) - continue - } - fw.Write([]byte(stripAnsiString(string(data)))) - } - - err = zipWriter.Close() - if err != nil { - log.Fatalf("could not finalize zip file: %s", err) - } - - err = os.WriteFile(outFile, w.Bytes(), 0600) - if err != nil { - log.Fatalf("could not write zip file to %s: %s", outFile, err) - } - - log.Infof("Wrote zip file to %s", outFile) - }, - } - cmdDump.Flags().StringVarP(&outFile, "outFile", "f", "", "File to dump the information to") - cmdSupport.AddCommand(cmdDump) - - return cmdSupport -} diff --git a/cmd/crowdsec-cli/tables.go b/cmd/crowdsec-cli/tables.go deleted file mode 100644 index 2c3173d0b0b..00000000000 --- a/cmd/crowdsec-cli/tables.go +++ /dev/null @@ -1,95 +0,0 @@ -package main - -import ( - "fmt" - "io" - "os" - - "github.com/aquasecurity/table" - isatty "github.com/mattn/go-isatty" -) - -func shouldWeColorize() bool { - if csConfig.Cscli.Color == "yes" { - return true - } - if 
csConfig.Cscli.Color == "no" { - return false - } - return isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) -} - -func newTable(out io.Writer) *table.Table { - if out == nil { - panic("newTable: out is nil") - } - t := table.New(out) - if shouldWeColorize() { - t.SetLineStyle(table.StyleBrightBlack) - t.SetHeaderStyle(table.StyleItalic) - t.SetDividers(table.UnicodeRoundedDividers) - } else { - t.SetDividers(table.ASCIIDividers) - } - - return t -} - -func newLightTable(out io.Writer) *table.Table { - if out == nil { - panic("newLightTable: out is nil") - } - t := newTable(out) - t.SetRowLines(false) - t.SetBorderLeft(false) - t.SetBorderRight(false) - // This leaves three spaces between columns: - // left padding, invisible border, right padding - // There is no way to make two spaces without - // a SetColumnLines() method, but it's close enough. - t.SetPadding(1) - - if shouldWeColorize() { - t.SetDividers(table.Dividers{ - ALL: "─", - NES: "─", - NSW: "─", - NEW: "─", - ESW: "─", - NE: "─", - NW: "─", - SW: "─", - ES: "─", - EW: "─", - NS: " ", - }) - } else { - t.SetDividers(table.Dividers{ - ALL: "-", - NES: "-", - NSW: "-", - NEW: "-", - ESW: "-", - NE: "-", - NW: "-", - SW: "-", - ES: "-", - EW: "-", - NS: " ", - }) - } - return t -} - -func renderTableTitle(out io.Writer, title string) { - if out == nil { - panic("renderTableTitle: out is nil") - } - if title == "" { - return - } - fmt.Fprintln(out, title) -} diff --git a/cmd/crowdsec-cli/utils.go b/cmd/crowdsec-cli/utils.go deleted file mode 100644 index a3c42f5d18e..00000000000 --- a/cmd/crowdsec-cli/utils.go +++ /dev/null @@ -1,748 +0,0 @@ -package main - -import ( - "encoding/csv" - "encoding/json" - "fmt" - "io" - "math" - "net" - "net/http" - "os" - "strconv" - "strings" - "time" - - "github.com/fatih/color" - dto "github.com/prometheus/client_model/go" - "github.com/prometheus/prom2json" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "github.com/agext/levenshtein" - "golang.org/x/exp/slices" - "gopkg.in/yaml.v2" - - "github.com/crowdsecurity/go-cs-lib/pkg/trace" - - "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -const MaxDistance = 7 - -func printHelp(cmd *cobra.Command) { - err := cmd.Help() - if err != nil { - log.Fatalf("unable to print help(): %s", err) - } -} - -func indexOf(s string, slice []string) int { - for i, elem := range slice { - if s == elem { - return i - } - } - return -1 -} - -func LoadHub() error { - if err := csConfig.LoadHub(); err != nil { - log.Fatal(err) - } - if csConfig.Hub == nil { - return fmt.Errorf("unable to load hub") - } - - if err := cwhub.SetHubBranch(); err != nil { - log.Warningf("unable to set hub branch (%s), defaulting to master", err) - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - return fmt.Errorf("failed to get Hub index: %w. 
Run 'sudo cscli hub update' to get the hub index", err) - } - - return nil -} - -func Suggest(itemType string, baseItem string, suggestItem string, score int, ignoreErr bool) { - errMsg := "" - if score < MaxDistance { - errMsg = fmt.Sprintf("unable to find %s '%s', did you mean %s?", itemType, baseItem, suggestItem) - } else { - errMsg = fmt.Sprintf("unable to find %s '%s'", itemType, baseItem) - } - if ignoreErr { - log.Error(errMsg) - } else { - log.Fatal(errMsg) - } -} - -func GetDistance(itemType string, itemName string) (*cwhub.Item, int) { - allItems := make([]string, 0) - nearestScore := 100 - nearestItem := &cwhub.Item{} - hubItems := cwhub.GetHubStatusForItemType(itemType, "", true) - for _, item := range hubItems { - allItems = append(allItems, item.Name) - } - - for _, s := range allItems { - d := levenshtein.Distance(itemName, s, nil) - if d < nearestScore { - nearestScore = d - nearestItem = cwhub.GetItem(itemType, s) - } - } - return nearestItem, nearestScore -} - -func compAllItems(itemType string, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - if err := LoadHub(); err != nil { - return nil, cobra.ShellCompDirectiveDefault - } - - comp := make([]string, 0) - hubItems := cwhub.GetHubStatusForItemType(itemType, "", true) - for _, item := range hubItems { - if !slices.Contains(args, item.Name) && strings.Contains(item.Name, toComplete) { - comp = append(comp, item.Name) - } - } - cobra.CompDebugln(fmt.Sprintf("%s: %+v", itemType, comp), true) - return comp, cobra.ShellCompDirectiveNoFileComp -} - -func compInstalledItems(itemType string, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - if err := LoadHub(); err != nil { - return nil, cobra.ShellCompDirectiveDefault - } - - var items []string - var err error - switch itemType { - case cwhub.PARSERS: - items, err = cwhub.GetInstalledParsersAsString() - case cwhub.SCENARIOS: - items, err = cwhub.GetInstalledScenariosAsString() - case cwhub.PARSERS_OVFLW: - items, err = cwhub.GetInstalledPostOverflowsAsString() - case cwhub.COLLECTIONS: - items, err = cwhub.GetInstalledCollectionsAsString() - default: - return nil, cobra.ShellCompDirectiveDefault - } - - if err != nil { - cobra.CompDebugln(fmt.Sprintf("list installed %s err: %s", itemType, err), true) - return nil, cobra.ShellCompDirectiveDefault - } - comp := make([]string, 0) - - if toComplete != "" { - for _, item := range items { - if strings.Contains(item, toComplete) { - comp = append(comp, item) - } - } - } else { - comp = items - } - - cobra.CompDebugln(fmt.Sprintf("%s: %+v", itemType, comp), true) - - return comp, cobra.ShellCompDirectiveNoFileComp -} - -func ListItems(out io.Writer, itemTypes []string, args []string, showType bool, showHeader bool, all bool) { - var hubStatusByItemType = make(map[string][]cwhub.ItemHubStatus) - - for _, itemType := range itemTypes { - itemName := "" - if len(args) == 1 { - itemName = args[0] - } - hubStatusByItemType[itemType] = cwhub.GetHubStatusForItemType(itemType, itemName, all) - } - - if csConfig.Cscli.Output == "human" { - for _, itemType := range itemTypes { - var statuses []cwhub.ItemHubStatus - var ok bool - if statuses, ok = hubStatusByItemType[itemType]; !ok { - log.Errorf("unknown item type: %s", itemType) - continue - } - listHubItemTable(out, "\n"+strings.ToUpper(itemType), statuses) - } - } else if csConfig.Cscli.Output == "json" { - x, err := json.MarshalIndent(hubStatusByItemType, "", " ") - if err != nil { - log.Fatalf("failed to marshal hub status: %s", err) - } - out.Write(x) - } 
else if csConfig.Cscli.Output == "raw" { - csvwriter := csv.NewWriter(out) - if showHeader { - header := []string{"name", "status", "version", "description"} - if showType { - header = append(header, "type") - } - err := csvwriter.Write(header) - if err != nil { - log.Fatalf("failed to write header: %s", err) - } - - } - for _, itemType := range itemTypes { - var statuses []cwhub.ItemHubStatus - var ok bool - if statuses, ok = hubStatusByItemType[itemType]; !ok { - log.Errorf("unknown item type: %s", itemType) - continue - } - for _, status := range statuses { - if status.LocalVersion == "" { - status.LocalVersion = "n/a" - } - row := []string{ - status.Name, - status.Status, - status.LocalVersion, - status.Description, - } - if showType { - row = append(row, itemType) - } - err := csvwriter.Write(row) - if err != nil { - log.Fatalf("failed to write raw output : %s", err) - } - } - } - csvwriter.Flush() - } -} - -func InspectItem(name string, objecitemType string) { - - hubItem := cwhub.GetItem(objecitemType, name) - if hubItem == nil { - log.Fatalf("unable to retrieve item.") - } - var b []byte - var err error - switch csConfig.Cscli.Output { - case "human", "raw": - b, err = yaml.Marshal(*hubItem) - if err != nil { - log.Fatalf("unable to marshal item : %s", err) - } - case "json": - b, err = json.MarshalIndent(*hubItem, "", " ") - if err != nil { - log.Fatalf("unable to marshal item : %s", err) - } - } - fmt.Printf("%s", string(b)) - if csConfig.Cscli.Output == "json" || csConfig.Cscli.Output == "raw" { - return - } - - if prometheusURL == "" { - //This is technically wrong to do this, as the prometheus section contains a listen address, not an URL to query prometheus - //But for ease of use, we will use the listen address as the prometheus URL because it will be 127.0.0.1 in the default case - listenAddr := csConfig.Prometheus.ListenAddr - if listenAddr == "" { - listenAddr = "127.0.0.1" - } - listenPort := csConfig.Prometheus.ListenPort - if listenPort == 0 { - listenPort = 6060 - } - prometheusURL = fmt.Sprintf("http://%s:%d/metrics", listenAddr, listenPort) - log.Debugf("No prometheus URL provided using: %s", prometheusURL) - } - - fmt.Printf("\nCurrent metrics : \n") - ShowMetrics(hubItem) -} - -func manageCliDecisionAlerts(ip *string, ipRange *string, scope *string, value *string) error { - - /*if a range is provided, change the scope*/ - if *ipRange != "" { - _, _, err := net.ParseCIDR(*ipRange) - if err != nil { - return fmt.Errorf("%s isn't a valid range", *ipRange) - } - } - if *ip != "" { - ipRepr := net.ParseIP(*ip) - if ipRepr == nil { - return fmt.Errorf("%s isn't a valid ip", *ip) - } - } - - //avoid confusion on scope (ip vs Ip and range vs Range) - switch strings.ToLower(*scope) { - case "ip": - *scope = types.Ip - case "range": - *scope = types.Range - case "country": - *scope = types.Country - case "as": - *scope = types.AS - } - return nil -} - -func ShowMetrics(hubItem *cwhub.Item) { - switch hubItem.Type { - case cwhub.PARSERS: - metrics := GetParserMetric(prometheusURL, hubItem.Name) - parserMetricsTable(color.Output, hubItem.Name, metrics) - case cwhub.SCENARIOS: - metrics := GetScenarioMetric(prometheusURL, hubItem.Name) - scenarioMetricsTable(color.Output, hubItem.Name, metrics) - case cwhub.COLLECTIONS: - for _, item := range hubItem.Parsers { - metrics := GetParserMetric(prometheusURL, item) - parserMetricsTable(color.Output, item, metrics) - } - for _, item := range hubItem.Scenarios { - metrics := GetScenarioMetric(prometheusURL, item) - 
scenarioMetricsTable(color.Output, item, metrics) - } - for _, item := range hubItem.Collections { - hubItem = cwhub.GetItem(cwhub.COLLECTIONS, item) - if hubItem == nil { - log.Fatalf("unable to retrieve item '%s' from collection '%s'", item, hubItem.Name) - } - ShowMetrics(hubItem) - } - default: - log.Errorf("item of type '%s' is unknown", hubItem.Type) - } -} - -// GetParserMetric is a complete rip from prom2json -func GetParserMetric(url string, itemName string) map[string]map[string]int { - stats := make(map[string]map[string]int) - - result := GetPrometheusMetric(url) - for idx, fam := range result { - if !strings.HasPrefix(fam.Name, "cs_") { - continue - } - log.Tracef("round %d", idx) - for _, m := range fam.Metrics { - metric, ok := m.(prom2json.Metric) - if !ok { - log.Debugf("failed to convert metric to prom2json.Metric") - continue - } - name, ok := metric.Labels["name"] - if !ok { - log.Debugf("no name in Metric %v", metric.Labels) - } - if name != itemName { - continue - } - source, ok := metric.Labels["source"] - if !ok { - log.Debugf("no source in Metric %v", metric.Labels) - } else { - if srctype, ok := metric.Labels["type"]; ok { - source = srctype + ":" + source - } - } - value := m.(prom2json.Metric).Value - fval, err := strconv.ParseFloat(value, 32) - if err != nil { - log.Errorf("Unexpected int value %s : %s", value, err) - continue - } - ival := int(fval) - - switch fam.Name { - case "cs_reader_hits_total": - if _, ok := stats[source]; !ok { - stats[source] = make(map[string]int) - stats[source]["parsed"] = 0 - stats[source]["reads"] = 0 - stats[source]["unparsed"] = 0 - stats[source]["hits"] = 0 - } - stats[source]["reads"] += ival - case "cs_parser_hits_ok_total": - if _, ok := stats[source]; !ok { - stats[source] = make(map[string]int) - } - stats[source]["parsed"] += ival - case "cs_parser_hits_ko_total": - if _, ok := stats[source]; !ok { - stats[source] = make(map[string]int) - } - stats[source]["unparsed"] += ival - case "cs_node_hits_total": - if _, ok := stats[source]; !ok { - stats[source] = make(map[string]int) - } - stats[source]["hits"] += ival - case "cs_node_hits_ok_total": - if _, ok := stats[source]; !ok { - stats[source] = make(map[string]int) - } - stats[source]["parsed"] += ival - case "cs_node_hits_ko_total": - if _, ok := stats[source]; !ok { - stats[source] = make(map[string]int) - } - stats[source]["unparsed"] += ival - default: - continue - } - } - } - return stats -} - -func GetScenarioMetric(url string, itemName string) map[string]int { - stats := make(map[string]int) - - stats["instantiation"] = 0 - stats["curr_count"] = 0 - stats["overflow"] = 0 - stats["pour"] = 0 - stats["underflow"] = 0 - - result := GetPrometheusMetric(url) - for idx, fam := range result { - if !strings.HasPrefix(fam.Name, "cs_") { - continue - } - log.Tracef("round %d", idx) - for _, m := range fam.Metrics { - metric, ok := m.(prom2json.Metric) - if !ok { - log.Debugf("failed to convert metric to prom2json.Metric") - continue - } - name, ok := metric.Labels["name"] - if !ok { - log.Debugf("no name in Metric %v", metric.Labels) - } - if name != itemName { - continue - } - value := m.(prom2json.Metric).Value - fval, err := strconv.ParseFloat(value, 32) - if err != nil { - log.Errorf("Unexpected int value %s : %s", value, err) - continue - } - ival := int(fval) - - switch fam.Name { - case "cs_bucket_created_total": - stats["instantiation"] += ival - case "cs_buckets": - stats["curr_count"] += ival - case "cs_bucket_overflowed_total": - stats["overflow"] += ival - case 
"cs_bucket_poured_total": - stats["pour"] += ival - case "cs_bucket_underflowed_total": - stats["underflow"] += ival - default: - continue - } - } - } - return stats -} - -// it's a rip of the cli version, but in silent-mode -func silenceInstallItem(name string, obtype string) (string, error) { - var item = cwhub.GetItem(obtype, name) - if item == nil { - return "", fmt.Errorf("error retrieving item") - } - it := *item - if downloadOnly && it.Downloaded && it.UpToDate { - return fmt.Sprintf("%s is already downloaded and up-to-date", it.Name), nil - } - it, err := cwhub.DownloadLatest(csConfig.Hub, it, forceAction, false) - if err != nil { - return "", fmt.Errorf("error while downloading %s : %v", it.Name, err) - } - if err := cwhub.AddItem(obtype, it); err != nil { - return "", err - } - - if downloadOnly { - return fmt.Sprintf("Downloaded %s to %s", it.Name, csConfig.Cscli.HubDir+"/"+it.RemotePath), nil - } - it, err = cwhub.EnableItem(csConfig.Hub, it) - if err != nil { - return "", fmt.Errorf("error while enabling %s : %v", it.Name, err) - } - if err := cwhub.AddItem(obtype, it); err != nil { - return "", err - } - return fmt.Sprintf("Enabled %s", it.Name), nil -} - -func GetPrometheusMetric(url string) []*prom2json.Family { - mfChan := make(chan *dto.MetricFamily, 1024) - - // Start with the DefaultTransport for sane defaults. - transport := http.DefaultTransport.(*http.Transport).Clone() - // Conservatively disable HTTP keep-alives as this program will only - // ever need a single HTTP request. - transport.DisableKeepAlives = true - // Timeout early if the server doesn't even return the headers. - transport.ResponseHeaderTimeout = time.Minute - - go func() { - defer trace.CatchPanic("crowdsec/GetPrometheusMetric") - err := prom2json.FetchMetricFamilies(url, mfChan, transport) - if err != nil { - log.Fatalf("failed to fetch prometheus metrics : %v", err) - } - }() - - result := []*prom2json.Family{} - for mf := range mfChan { - result = append(result, prom2json.NewFamily(mf)) - } - log.Debugf("Finished reading prometheus output, %d entries", len(result)) - - return result -} - -func RestoreHub(dirPath string) error { - var err error - - if err := csConfig.LoadHub(); err != nil { - return err - } - if err := cwhub.SetHubBranch(); err != nil { - return fmt.Errorf("error while setting hub branch: %s", err) - } - - for _, itype := range cwhub.ItemTypes { - itemDirectory := fmt.Sprintf("%s/%s/", dirPath, itype) - if _, err = os.Stat(itemDirectory); err != nil { - log.Infof("no %s in backup", itype) - continue - } - /*restore the upstream items*/ - upstreamListFN := fmt.Sprintf("%s/upstream-%s.json", itemDirectory, itype) - file, err := os.ReadFile(upstreamListFN) - if err != nil { - return fmt.Errorf("error while opening %s : %s", upstreamListFN, err) - } - var upstreamList []string - err = json.Unmarshal(file, &upstreamList) - if err != nil { - return fmt.Errorf("error unmarshaling %s : %s", upstreamListFN, err) - } - for _, toinstall := range upstreamList { - label, err := silenceInstallItem(toinstall, itype) - if err != nil { - log.Errorf("Error while installing %s : %s", toinstall, err) - } else if label != "" { - log.Infof("Installed %s : %s", toinstall, label) - } else { - log.Printf("Installed %s : ok", toinstall) - } - } - - /*restore the local and tainted items*/ - files, err := os.ReadDir(itemDirectory) - if err != nil { - return fmt.Errorf("failed enumerating files of %s : %s", itemDirectory, err) - } - for _, file := range files { - //this was the upstream data - if file.Name() 
== fmt.Sprintf("upstream-%s.json", itype) { - continue - } - if itype == cwhub.PARSERS || itype == cwhub.PARSERS_OVFLW { - //we expect a stage here - if !file.IsDir() { - continue - } - stage := file.Name() - stagedir := fmt.Sprintf("%s/%s/%s/", csConfig.ConfigPaths.ConfigDir, itype, stage) - log.Debugf("Found stage %s in %s, target directory : %s", stage, itype, stagedir) - if err = os.MkdirAll(stagedir, os.ModePerm); err != nil { - return fmt.Errorf("error while creating stage directory %s : %s", stagedir, err) - } - /*find items*/ - ifiles, err := os.ReadDir(itemDirectory + "/" + stage + "/") - if err != nil { - return fmt.Errorf("failed enumerating files of %s : %s", itemDirectory+"/"+stage, err) - } - //finally copy item - for _, tfile := range ifiles { - log.Infof("Going to restore local/tainted [%s]", tfile.Name()) - sourceFile := fmt.Sprintf("%s/%s/%s", itemDirectory, stage, tfile.Name()) - destinationFile := fmt.Sprintf("%s%s", stagedir, tfile.Name()) - if err = CopyFile(sourceFile, destinationFile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itype, sourceFile, destinationFile, err) - } - log.Infof("restored %s to %s", sourceFile, destinationFile) - } - } else { - log.Infof("Going to restore local/tainted [%s]", file.Name()) - sourceFile := fmt.Sprintf("%s/%s", itemDirectory, file.Name()) - destinationFile := fmt.Sprintf("%s/%s/%s", csConfig.ConfigPaths.ConfigDir, itype, file.Name()) - if err = CopyFile(sourceFile, destinationFile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itype, sourceFile, destinationFile, err) - } - log.Infof("restored %s to %s", sourceFile, destinationFile) - } - - } - } - return nil -} - -func BackupHub(dirPath string) error { - var err error - var itemDirectory string - var upstreamParsers []string - - for _, itemType := range cwhub.ItemTypes { - clog := log.WithFields(log.Fields{ - "type": itemType, - }) - itemMap := cwhub.GetItemMap(itemType) - if itemMap == nil { - clog.Infof("No %s to backup.", itemType) - continue - } - itemDirectory = fmt.Sprintf("%s/%s/", dirPath, itemType) - if err := os.MkdirAll(itemDirectory, os.ModePerm); err != nil { - return fmt.Errorf("error while creating %s : %s", itemDirectory, err) - } - upstreamParsers = []string{} - for k, v := range itemMap { - clog = clog.WithFields(log.Fields{ - "file": v.Name, - }) - if !v.Installed { //only backup installed ones - clog.Debugf("[%s] : not installed", k) - continue - } - - //for the local/tainted ones, we backup the full file - if v.Tainted || v.Local || !v.UpToDate { - //we need to backup stages for parsers - if itemType == cwhub.PARSERS || itemType == cwhub.PARSERS_OVFLW { - fstagedir := fmt.Sprintf("%s%s", itemDirectory, v.Stage) - if err := os.MkdirAll(fstagedir, os.ModePerm); err != nil { - return fmt.Errorf("error while creating stage dir %s : %s", fstagedir, err) - } - } - clog.Debugf("[%s] : backuping file (tainted:%t local:%t up-to-date:%t)", k, v.Tainted, v.Local, v.UpToDate) - tfile := fmt.Sprintf("%s%s/%s", itemDirectory, v.Stage, v.FileName) - if err = CopyFile(v.LocalPath, tfile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itemType, v.LocalPath, tfile, err) - } - clog.Infof("local/tainted saved %s to %s", v.LocalPath, tfile) - continue - } - clog.Debugf("[%s] : from hub, just backup name (up-to-date:%t)", k, v.UpToDate) - clog.Infof("saving, version:%s, up-to-date:%t", v.Version, v.UpToDate) - upstreamParsers = append(upstreamParsers, v.Name) - } - //write the upstream items - upstreamParsersFname := 
fmt.Sprintf("%s/upstream-%s.json", itemDirectory, itemType) - upstreamParsersContent, err := json.MarshalIndent(upstreamParsers, "", " ") - if err != nil { - return fmt.Errorf("failed marshaling upstream parsers : %s", err) - } - err = os.WriteFile(upstreamParsersFname, upstreamParsersContent, 0644) - if err != nil { - return fmt.Errorf("unable to write to %s %s : %s", itemType, upstreamParsersFname, err) - } - clog.Infof("Wrote %d entries for %s to %s", len(upstreamParsers), itemType, upstreamParsersFname) - } - - return nil -} - -type unit struct { - value int64 - symbol string -} - -var ranges = []unit{ - {value: 1e18, symbol: "E"}, - {value: 1e15, symbol: "P"}, - {value: 1e12, symbol: "T"}, - {value: 1e9, symbol: "G"}, - {value: 1e6, symbol: "M"}, - {value: 1e3, symbol: "k"}, - {value: 1, symbol: ""}, -} - -func formatNumber(num int) string { - goodUnit := unit{} - for _, u := range ranges { - if int64(num) >= u.value { - goodUnit = u - break - } - } - - if goodUnit.value == 1 { - return fmt.Sprintf("%d%s", num, goodUnit.symbol) - } - - res := math.Round(float64(num)/float64(goodUnit.value)*100) / 100 - return fmt.Sprintf("%.2f%s", res, goodUnit.symbol) -} - -func getDBClient() (*database.Client, error) { - var err error - if err := csConfig.LoadAPIServer(); err != nil || csConfig.DisableAPI { - return nil, err - } - ret, err := database.NewClient(csConfig.DbConfig) - if err != nil { - return nil, err - } - return ret, nil -} - -func removeFromSlice(val string, slice []string) []string { - var i int - var value string - - valueFound := false - - // get the index - for i, value = range slice { - if value == val { - valueFound = true - break - } - } - - if valueFound { - slice[i] = slice[len(slice)-1] - slice[len(slice)-1] = "" - slice = slice[:len(slice)-1] - } - - return slice - -} diff --git a/cmd/crowdsec-cli/utils_table.go b/cmd/crowdsec-cli/utils_table.go deleted file mode 100644 index aef1e94f7d1..00000000000 --- a/cmd/crowdsec-cli/utils_table.go +++ /dev/null @@ -1,66 +0,0 @@ -package main - -import ( - "fmt" - "io" - - "github.com/aquasecurity/table" - "github.com/enescakir/emoji" - - "github.com/crowdsecurity/crowdsec/pkg/cwhub" -) - -func listHubItemTable(out io.Writer, title string, statuses []cwhub.ItemHubStatus) { - t := newLightTable(out) - t.SetHeaders("Name", fmt.Sprintf("%v Status", emoji.Package), "Version", "Local Path") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - for _, status := range statuses { - t.AddRow(status.Name, status.UTF8_Status, status.LocalVersion, status.LocalPath) - } - renderTableTitle(out, title) - t.Render() -} - -func scenarioMetricsTable(out io.Writer, itemName string, metrics map[string]int) { - if metrics["instantiation"] == 0 { - return - } - t := newTable(out) - t.SetHeaders("Current Count", "Overflows", "Instantiated", "Poured", "Expired") - - t.AddRow( - fmt.Sprintf("%d", metrics["curr_count"]), - fmt.Sprintf("%d", metrics["overflow"]), - fmt.Sprintf("%d", metrics["instantiation"]), - fmt.Sprintf("%d", metrics["pour"]), - fmt.Sprintf("%d", metrics["underflow"]), - ) - - renderTableTitle(out, fmt.Sprintf("\n - (Scenario) %s:", itemName)) - t.Render() -} - -func parserMetricsTable(out io.Writer, itemName string, metrics map[string]map[string]int) { - skip := true - t := newTable(out) - t.SetHeaders("Parsers", "Hits", "Parsed", "Unparsed") - - for source, stats := range metrics { - if stats["hits"] > 0 { - 
t.AddRow( - source, - fmt.Sprintf("%d", stats["hits"]), - fmt.Sprintf("%d", stats["parsed"]), - fmt.Sprintf("%d", stats["unparsed"]), - ) - skip = false - } - } - - if !skip { - renderTableTitle(out, fmt.Sprintf("\n - (Parser) %s:", itemName)) - t.Render() - } -} diff --git a/cmd/crowdsec-cli/version.go b/cmd/crowdsec-cli/version.go new file mode 100644 index 00000000000..7ec5c459968 --- /dev/null +++ b/cmd/crowdsec-cli/version.go @@ -0,0 +1,29 @@ +package main + +import ( + "os" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/cwversion" +) + +type cliVersion struct{} + +func NewCLIVersion() *cliVersion { + return &cliVersion{} +} + +func (cliVersion) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "version", + Short: "Display version", + Args: cobra.NoArgs, + DisableAutoGenTag: true, + Run: func(_ *cobra.Command, _ []string) { + _, _ = os.Stdout.WriteString(cwversion.FullString()) + }, + } + + return cmd +} diff --git a/cmd/crowdsec/Makefile b/cmd/crowdsec/Makefile index 8242f1b491f..39f807cab88 100644 --- a/cmd/crowdsec/Makefile +++ b/cmd/crowdsec/Makefile @@ -4,67 +4,21 @@ ifeq ($(OS), Windows_NT) EXT = .exe endif -# Go parameters -GOCMD = go -GOBUILD = $(GOCMD) build -GOTEST = $(GOCMD) test +GO = go +GOBUILD = $(GO) build +GOTEST = $(GO) test CROWDSEC_BIN = crowdsec$(EXT) # names longer than 15 chars break 'pgrep' -PREFIX ?= "/" -CFG_PREFIX = $(PREFIX)"/etc/crowdsec/config/" -BIN_PREFIX = $(PREFIX)"/usr/local/bin/" -DATA_PREFIX = $(PREFIX)"/var/run/crowdsec/" -PID_DIR = $(PREFIX)"/var/run/" - -SYSTEMD_PATH_FILE = "/etc/systemd/system/crowdsec.service" .PHONY: all all: clean test build build: clean - $(GOBUILD) $(LD_OPTS) $(BUILD_VENDOR_FLAGS) -o $(CROWDSEC_BIN) + $(GOBUILD) $(LD_OPTS) -o $(CROWDSEC_BIN) test: $(GOTEST) $(LD_OPTS) -v ./... clean: @$(RM) $(CROWDSEC_BIN) $(WIN_IGNORE_ERR) - -.PHONY: install -install: install-conf install-bin - -.PHONY: install-conf -install-conf: - mkdir -p $(DATA_PREFIX) || exit - (cd ../.. 
/ && find ./data -type f -exec install -Dm 755 "{}" "$(DATA_PREFIX){}" \; && cd ./cmd/crowdsec) || exit - (cd ../../config && find ./patterns -type f -exec install -Dm 755 "{}" "$(CFG_PREFIX){}" \; && cd ../cmd/crowdsec) || exit - mkdir -p "$(CFG_PREFIX)" || exit - mkdir -p "$(CFG_PREFIX)/parsers" || exit - mkdir -p "$(CFG_PREFIX)/scenarios" || exit - mkdir -p "$(CFG_PREFIX)/postoverflows" || exit - mkdir -p "$(CFG_PREFIX)/collections" || exit - mkdir -p "$(CFG_PREFIX)/patterns" || exit - install -v -m 755 -D ../../config/prod.yaml "$(CFG_PREFIX)" || exit - install -v -m 755 -D ../../config/dev.yaml "$(CFG_PREFIX)" || exit - install -v -m 755 -D ../../config/acquis.yaml "$(CFG_PREFIX)" || exit - install -v -m 755 -D ../../config/profiles.yaml "$(CFG_PREFIX)" || exit - install -v -m 755 -D ../../config/api.yaml "$(CFG_PREFIX)" || exit - mkdir -p $(PID_DIR) || exit - PID=$(PID_DIR) DATA=$(DATA_PREFIX)"/data/" CFG=$(CFG_PREFIX) envsubst < ../../config/prod.yaml > $(CFG_PREFIX)"/default.yaml" - -.PHONY: install-bin -install-bin: - install -v -m 755 -D "$(CROWDSEC_BIN)" "$(BIN_PREFIX)/$(CROWDSEC_BIN)" || exit - -.PHONY: systemd -systemd: install - CFG=$(CFG_PREFIX) PID=$(PID_DIR) BIN=$(BIN_PREFIX)"/"$(CROWDSEC_BIN) envsubst < ../../config/crowdsec.service > "$(SYSTEMD_PATH_FILE)" - systemctl daemon-reload - -.PHONY: uninstall -uninstall: - $(RM) $(CFG_PREFIX) $(WIN_IGNORE_ERR) - $(RM) $(DATA_PREFIX) $(WIN_IGNORE_ERR) - $(RM) "$(BIN_PREFIX)/$(CROWDSEC_BIN)" $(WIN_IGNORE_ERR) - $(RM) "$(SYSTEMD_PATH_FILE)" $(WIN_IGNORE_ERR) diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index fd2e2ce088c..ccb0acf0209 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -1,25 +1,26 @@ package main import ( + "context" + "errors" "fmt" "runtime" "time" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/apiserver" "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) -func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { +func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.APIServer, error) { if cConfig.API.Server.OnlineClient == nil || cConfig.API.Server.OnlineClient.Credentials == nil { log.Info("push and pull to Central API disabled") } - apiServer, err := apiserver.NewServer(cConfig.API.Server) + apiServer, err := apiserver.NewServer(ctx, cConfig.API.Server) if err != nil { return nil, fmt.Errorf("unable to run local API: %w", err) } @@ -39,7 +40,7 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { return nil, errors.New("plugins are enabled, but config_paths.plugin_dir is not defined") } - err = pluginBroker.Init(cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths) + err = pluginBroker.Init(ctx, cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths) if err != nil { return nil, fmt.Errorf("unable to run plugin broker: %w", err) } @@ -56,12 +57,16 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { return apiServer, nil } -func serveAPIServer(apiServer *apiserver.APIServer, apiReady chan bool) { +func serveAPIServer(apiServer *apiserver.APIServer) { + apiReady := make(chan bool, 1) + apiTomb.Go(func() error { defer trace.CatchPanic("crowdsec/serveAPIServer") + go func() { defer trace.CatchPanic("crowdsec/runAPIServer") log.Debugf("serving API after %s ms", time.Since(crowdsecT0)) + if err := apiServer.Run(apiReady); err 
!= nil { log.Fatal(err) } @@ -75,11 +80,10 @@ func serveAPIServer(apiServer *apiserver.APIServer, apiReady chan bool) { <-apiTomb.Dying() // lock until go routine is dying pluginTomb.Kill(nil) log.Infof("serve: shutting down api server") - if err := apiServer.Shutdown(); err != nil { - return err - } - return nil + + return apiServer.Shutdown() }) + <-apiReady } func hasPlugins(profiles []*csconfig.ProfileCfg) bool { @@ -88,5 +92,6 @@ func hasPlugins(profiles []*csconfig.ProfileCfg) bool { return true } } + return false } diff --git a/cmd/crowdsec/appsec.go b/cmd/crowdsec/appsec.go new file mode 100644 index 00000000000..cb02b137dcd --- /dev/null +++ b/cmd/crowdsec/appsec.go @@ -0,0 +1,18 @@ +// +build !no_datasource_appsec + +package main + +import ( + "fmt" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func LoadAppsecRules(hub *cwhub.Hub) error { + if err := appsec.LoadAppsecRules(hub); err != nil { + return fmt.Errorf("while loading appsec rules: %w", err) + } + + return nil +} diff --git a/cmd/crowdsec/appsec_stub.go b/cmd/crowdsec/appsec_stub.go new file mode 100644 index 00000000000..4a65b32a9ad --- /dev/null +++ b/cmd/crowdsec/appsec_stub.go @@ -0,0 +1,11 @@ +//go:build no_datasource_appsec + +package main + +import ( + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func LoadAppsecRules(hub *cwhub.Hub) error { + return nil +} diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go index 68a7c6180da..c44d71d2093 100644 --- a/cmd/crowdsec/crowdsec.go +++ b/cmd/crowdsec/crowdsec.go @@ -1,235 +1,228 @@ package main import ( + "context" "fmt" "os" - "path/filepath" "sync" "time" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/alertcontext" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" "github.com/crowdsecurity/crowdsec/pkg/parser" "github.com/crowdsecurity/crowdsec/pkg/types" ) -func initCrowdsec(cConfig *csconfig.Config) (*parser.Parsers, error) { +// initCrowdsec prepares the log processor service +func initCrowdsec(cConfig *csconfig.Config, hub *cwhub.Hub) (*parser.Parsers, []acquisition.DataSource, error) { var err error - // Populate cwhub package tools - if err = cwhub.GetHubIdx(cConfig.Hub); err != nil { - return nil, fmt.Errorf("while loading hub index: %w", err) + if err = alertcontext.LoadConsoleContext(cConfig, hub); err != nil { + return nil, nil, fmt.Errorf("while loading context: %w", err) + } + + err = exprhelpers.GeoIPInit(hub.GetDataDir()) + if err != nil { + // GeoIP databases are not mandatory, do not make crowdsec fail if they are not present + log.Warnf("unable to initialize GeoIP: %s", err) } // Start loading configs - csParsers := parser.NewParsers() + csParsers := parser.NewParsers(hub) if csParsers, err = parser.LoadParsers(cConfig, csParsers); err != nil { - return nil, fmt.Errorf("while loading parsers: %w", err) + return nil, nil, fmt.Errorf("while loading parsers: %w", err) + } + + if err = LoadBuckets(cConfig, hub); err != nil { + return nil, nil, fmt.Errorf("while loading scenarios: %w", err) } - if err := LoadBuckets(cConfig); err != nil { - return nil, fmt.Errorf("while 
loading scenarios: %w", err) + // can be nerfed by a build flag + if err = LoadAppsecRules(hub); err != nil { + return nil, nil, err } - if err := LoadAcquisition(cConfig); err != nil { - return nil, fmt.Errorf("while loading acquisition config: %w", err) + datasources, err := LoadAcquisition(cConfig) + if err != nil { + return nil, nil, fmt.Errorf("while loading acquisition config: %w", err) } - return csParsers, nil + + return csParsers, datasources, nil } -func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers) error { +// runCrowdsec starts the log processor service +func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.Hub, datasources []acquisition.DataSource) error { inputEventChan = make(chan types.Event) inputLineChan = make(chan types.Event) - //start go-routines for parsing, buckets pour and outputs. + // start go-routines for parsing, buckets pour and outputs. parserWg := &sync.WaitGroup{} + parsersTomb.Go(func() error { parserWg.Add(1) - for i := 0; i < cConfig.Crowdsec.ParserRoutinesCount; i++ { + + for range cConfig.Crowdsec.ParserRoutinesCount { parsersTomb.Go(func() error { defer trace.CatchPanic("crowdsec/runParse") - if err := runParse(inputLineChan, inputEventChan, *parsers.Ctx, parsers.Nodes); err != nil { //this error will never happen as parser.Parse is not able to return errors - log.Fatalf("starting parse error : %s", err) + + if err := runParse(inputLineChan, inputEventChan, *parsers.Ctx, parsers.Nodes); err != nil { + // this error will never happen as parser.Parse is not able to return errors return err } + return nil }) } + parserWg.Done() + return nil }) parserWg.Wait() bucketWg := &sync.WaitGroup{} + bucketsTomb.Go(func() error { bucketWg.Add(1) - /*restore previous state as well if present*/ + // restore previous state as well if present if cConfig.Crowdsec.BucketStateFile != "" { log.Warningf("Restoring buckets state from %s", cConfig.Crowdsec.BucketStateFile) + if err := leaky.LoadBucketsState(cConfig.Crowdsec.BucketStateFile, buckets, holders); err != nil { - return fmt.Errorf("unable to restore buckets : %s", err) + return fmt.Errorf("unable to restore buckets: %w", err) } } - for i := 0; i < cConfig.Crowdsec.BucketsRoutinesCount; i++ { + for range cConfig.Crowdsec.BucketsRoutinesCount { bucketsTomb.Go(func() error { defer trace.CatchPanic("crowdsec/runPour") - if err := runPour(inputEventChan, holders, buckets, cConfig); err != nil { - log.Fatalf("starting pour error : %s", err) - return err - } - return nil + + return runPour(inputEventChan, holders, buckets, cConfig) }) } + bucketWg.Done() + return nil }) bucketWg.Wait() + apiClient, err := AuthenticatedLAPIClient(*cConfig.API.Client.Credentials, hub) + if err != nil { + return err + } + + log.Debugf("Starting HeartBeat service") + apiClient.HeartBeat.StartHeartBeat(context.Background(), &outputsTomb) + outputWg := &sync.WaitGroup{} + outputsTomb.Go(func() error { outputWg.Add(1) - for i := 0; i < cConfig.Crowdsec.OutputRoutinesCount; i++ { + + for range cConfig.Crowdsec.OutputRoutinesCount { outputsTomb.Go(func() error { defer trace.CatchPanic("crowdsec/runOutput") - if err := runOutput(inputEventChan, outputEventChan, buckets, *parsers.Povfwctx, parsers.Povfwnodes, *cConfig.API.Client.Credentials); err != nil { - log.Fatalf("starting outputs error : %s", err) - return err - } - return nil + + return runOutput(inputEventChan, outputEventChan, buckets, *parsers.Povfwctx, parsers.Povfwnodes, apiClient) }) } + outputWg.Done() + return nil }) outputWg.Wait() + mp := 
NewMetricsProvider( + apiClient, + lpMetricsDefaultInterval, + log.WithField("service", "lpmetrics"), + []string{}, + datasources, + hub, + ) + + lpMetricsTomb.Go(func() error { + return mp.Run(context.Background(), &lpMetricsTomb) + }) + if cConfig.Prometheus != nil && cConfig.Prometheus.Enabled { aggregated := false - if cConfig.Prometheus.Level == "aggregated" { + if cConfig.Prometheus.Level == configuration.CFG_METRICS_AGGREGATE { aggregated = true } + if err := acquisition.GetMetrics(dataSources, aggregated); err != nil { return fmt.Errorf("while fetching prometheus metrics for datasources: %w", err) } - } + log.Info("Starting processing data") - if err := acquisition.StartAcquisition(dataSources, inputLineChan, &acquisTomb); err != nil { - log.Fatalf("starting acquisition error : %s", err) - return err + if err := acquisition.StartAcquisition(context.TODO(), dataSources, inputLineChan, &acquisTomb); err != nil { + return fmt.Errorf("starting acquisition error: %w", err) } return nil } -func serveCrowdsec(parsers *parser.Parsers, cConfig *csconfig.Config, agentReady chan bool) { +// serveCrowdsec wraps the log processor service +func serveCrowdsec(parsers *parser.Parsers, cConfig *csconfig.Config, hub *cwhub.Hub, datasources []acquisition.DataSource, agentReady chan bool) { crowdsecTomb.Go(func() error { defer trace.CatchPanic("crowdsec/serveCrowdsec") + go func() { defer trace.CatchPanic("crowdsec/runCrowdsec") // this logs every time, even at config reload log.Debugf("running agent after %s ms", time.Since(crowdsecT0)) agentReady <- true - if err := runCrowdsec(cConfig, parsers); err != nil { + + if err := runCrowdsec(cConfig, parsers, hub, datasources); err != nil { log.Fatalf("unable to start crowdsec routines: %s", err) } }() - /*we should stop in two cases : + /* we should stop in two cases : - crowdsecTomb has been Killed() : it might be shutdown or reload, so stop - acquisTomb is dead, it means that we were in "cat" mode and files are done reading, quit */ waitOnTomb() log.Debugf("Shutting down crowdsec routines") + if err := ShutdownCrowdsecRoutines(); err != nil { - log.Fatalf("unable to shutdown crowdsec routines: %s", err) + return fmt.Errorf("unable to shutdown crowdsec routines: %w", err) } + log.Debugf("everything is dead, return crowdsecTomb") + if dumpStates { - dumpParserState() - dumpOverflowState() - dumpBucketsPour() + if err := dumpAllStates(); err != nil { + log.Fatal(err) + } os.Exit(0) } + return nil }) } -func dumpBucketsPour() { - fd, err := os.OpenFile(filepath.Join(parser.DumpFolder, "bucketpour-dump.yaml"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) - if err != nil { - log.Fatalf("open: %s", err) - } - out, err := yaml.Marshal(leaky.BucketPourCache) - if err != nil { - log.Fatalf("marshal: %s", err) - } - b, err := fd.Write(out) - if err != nil { - log.Fatalf("write: %s", err) - } - log.Tracef("wrote %d bytes", b) - if err := fd.Close(); err != nil { - log.Fatalf(" close: %s", err) - } -} - -func dumpParserState() { - - fd, err := os.OpenFile(filepath.Join(parser.DumpFolder, "parser-dump.yaml"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) - if err != nil { - log.Fatalf("open: %s", err) - } - out, err := yaml.Marshal(parser.StageParseCache) - if err != nil { - log.Fatalf("marshal: %s", err) - } - b, err := fd.Write(out) - if err != nil { - log.Fatalf("write: %s", err) - } - log.Tracef("wrote %d bytes", b) - if err := fd.Close(); err != nil { - log.Fatalf(" close: %s", err) - } -} - -func dumpOverflowState() { - - fd, err := 
os.OpenFile(filepath.Join(parser.DumpFolder, "bucket-dump.yaml"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) - if err != nil { - log.Fatalf("open: %s", err) - } - out, err := yaml.Marshal(bucketOverflows) - if err != nil { - log.Fatalf("marshal: %s", err) - } - b, err := fd.Write(out) - if err != nil { - log.Fatalf("write: %s", err) - } - log.Tracef("wrote %d bytes", b) - if err := fd.Close(); err != nil { - log.Fatalf(" close: %s", err) - } -} - func waitOnTomb() { for { select { case <-acquisTomb.Dead(): - /*if it's acquisition dying it means that we were in "cat" mode. + /* if it's acquisition dying it means that we were in "cat" mode. while shutting down, we need to give time for all buckets to process in flight data*/ - log.Warning("Acquisition is finished, shutting down") + log.Info("Acquisition is finished, shutting down") /* While it might make sense to want to shut-down parser/buckets/etc. as soon as acquisition is finished, we might have some pending buckets: buckets that overflowed, but whose LeakRoutine are still alive because they diff --git a/cmd/crowdsec/dump.go b/cmd/crowdsec/dump.go new file mode 100644 index 00000000000..33c65878b11 --- /dev/null +++ b/cmd/crowdsec/dump.go @@ -0,0 +1,56 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" + + leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" + "github.com/crowdsecurity/crowdsec/pkg/parser" +) + +func dumpAllStates() error { + log.Debugf("Dumping parser+bucket states to %s", parser.DumpFolder) + + if err := dumpState( + filepath.Join(parser.DumpFolder, "parser-dump.yaml"), + parser.StageParseCache, + ); err != nil { + return fmt.Errorf("while dumping parser state: %w", err) + } + + if err := dumpState( + filepath.Join(parser.DumpFolder, "bucket-dump.yaml"), + bucketOverflows, + ); err != nil { + return fmt.Errorf("while dumping bucket overflow state: %w", err) + } + + if err := dumpState( + filepath.Join(parser.DumpFolder, "bucketpour-dump.yaml"), + leaky.BucketPourCache, + ); err != nil { + return fmt.Errorf("while dumping bucket pour state: %w", err) + } + + return nil +} + +func dumpState(destPath string, obj any) error { + dir := filepath.Dir(destPath) + + err := os.MkdirAll(dir, 0o755) + if err != nil { + return err + } + + out, err := yaml.Marshal(obj) + if err != nil { + return err + } + + return os.WriteFile(destPath, out, 0o666) +} diff --git a/cmd/crowdsec/fatalhook.go b/cmd/crowdsec/fatalhook.go new file mode 100644 index 00000000000..84a57406a21 --- /dev/null +++ b/cmd/crowdsec/fatalhook.go @@ -0,0 +1,28 @@ +package main + +import ( + "io" + + log "github.com/sirupsen/logrus" +) + +// FatalHook is used to log fatal messages to stderr when the rest goes to a file +type FatalHook struct { + Writer io.Writer + LogLevels []log.Level +} + +func (hook *FatalHook) Fire(entry *log.Entry) error { + line, err := entry.String() + if err != nil { + return err + } + + _, err = hook.Writer.Write([]byte(line)) + + return err +} + +func (hook *FatalHook) Levels() []log.Level { + return hook.LogLevels +} diff --git a/cmd/crowdsec/lapiclient.go b/cmd/crowdsec/lapiclient.go new file mode 100644 index 00000000000..eed517f9df9 --- /dev/null +++ b/cmd/crowdsec/lapiclient.go @@ -0,0 +1,65 @@ +package main + +import ( + "context" + "fmt" + "net/url" + "time" + + "github.com/go-openapi/strfmt" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + 
"github.com/crowdsecurity/crowdsec/pkg/models" +) + +func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) { + apiURL, err := url.Parse(credentials.URL) + if err != nil { + return nil, fmt.Errorf("parsing api url ('%s'): %w", credentials.URL, err) + } + + papiURL, err := url.Parse(credentials.PapiURL) + if err != nil { + return nil, fmt.Errorf("parsing polling api url ('%s'): %w", credentials.PapiURL, err) + } + + password := strfmt.Password(credentials.Password) + + itemsForAPI := hub.GetInstalledListForAPI() + + client, err := apiclient.NewClient(&apiclient.Config{ + MachineID: credentials.Login, + Password: password, + Scenarios: itemsForAPI, + URL: apiURL, + PapiURL: papiURL, + VersionPrefix: "v1", + UpdateScenario: func(_ context.Context) ([]string, error) { + return itemsForAPI, nil + }, + }) + if err != nil { + return nil, fmt.Errorf("new client api: %w", err) + } + + authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + MachineID: &credentials.Login, + Password: &password, + Scenarios: itemsForAPI, + }) + if err != nil { + return nil, fmt.Errorf("authenticate watcher (%s): %w", credentials.Login, err) + } + + var expiration time.Time + if err := expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { + return nil, fmt.Errorf("unable to parse jwt expiration: %w", err) + } + + client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token + client.GetClient().Transport.(*apiclient.JWTTransport).Expiration = expiration + + return client, nil +} diff --git a/cmd/crowdsec/lpmetrics.go b/cmd/crowdsec/lpmetrics.go new file mode 100644 index 00000000000..24842851294 --- /dev/null +++ b/cmd/crowdsec/lpmetrics.go @@ -0,0 +1,180 @@ +package main + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/sirupsen/logrus" + "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/go-cs-lib/trace" + "github.com/crowdsecurity/go-cs-lib/version" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/fflag" + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +const lpMetricsDefaultInterval = 30 * time.Minute + +// MetricsProvider collects metrics from the LP and sends them to the LAPI +type MetricsProvider struct { + apic *apiclient.ApiClient + interval time.Duration + static staticMetrics + logger *logrus.Entry +} + +type staticMetrics struct { + osName string + osVersion string + startupTS int64 + featureFlags []string + consoleOptions []string + datasourceMap map[string]int64 + hubState models.HubItems +} + +func getHubState(hub *cwhub.Hub) models.HubItems { + ret := models.HubItems{} + + for _, itemType := range cwhub.ItemTypes { + ret[itemType] = []models.HubItem{} + + for _, item := range hub.GetInstalledByType(itemType, true) { + status := "official" + if item.State.IsLocal() { + status = "custom" + } + if item.State.Tainted { + status = "tainted" + } + ret[itemType] = append(ret[itemType], models.HubItem{ + Name: item.Name, + Status: status, + Version: item.Version, + }) + } + } + + return ret +} + +// newStaticMetrics is called when the process starts, or reloads the configuration +func newStaticMetrics(consoleOptions []string, datasources []acquisition.DataSource, hub *cwhub.Hub) staticMetrics { + datasourceMap := map[string]int64{} + + for _, ds := range datasources { 
+ datasourceMap[ds.GetName()] += 1 + } + + osName, osVersion := version.DetectOS() + + return staticMetrics{ + osName: osName, + osVersion: osVersion, + startupTS: time.Now().UTC().Unix(), + featureFlags: fflag.Crowdsec.GetEnabledFeatures(), + consoleOptions: consoleOptions, + datasourceMap: datasourceMap, + hubState: getHubState(hub), + } +} + +func NewMetricsProvider(apic *apiclient.ApiClient, interval time.Duration, logger *logrus.Entry, + consoleOptions []string, datasources []acquisition.DataSource, hub *cwhub.Hub, +) *MetricsProvider { + return &MetricsProvider{ + apic: apic, + interval: interval, + logger: logger, + static: newStaticMetrics(consoleOptions, datasources, hub), + } +} + +func (m *MetricsProvider) metricsPayload() *models.AllMetrics { + os := &models.OSversion{ + Name: ptr.Of(m.static.osName), + Version: ptr.Of(m.static.osVersion), + } + + base := models.BaseMetrics{ + UtcStartupTimestamp: ptr.Of(m.static.startupTS), + Os: os, + Version: ptr.Of(version.String()), + FeatureFlags: m.static.featureFlags, + Metrics: make([]*models.DetailedMetrics, 0), + } + + met := &models.LogProcessorsMetrics{ + BaseMetrics: base, + Datasources: m.static.datasourceMap, + HubItems: m.static.hubState, + } + + met.Metrics = append(met.Metrics, &models.DetailedMetrics{ + Meta: &models.MetricsMeta{ + UtcNowTimestamp: ptr.Of(time.Now().Unix()), + WindowSizeSeconds: ptr.Of(int64(m.interval.Seconds())), + }, + Items: make([]*models.MetricsDetailItem, 0), + }) + + return &models.AllMetrics{ + LogProcessors: []*models.LogProcessorsMetrics{met}, + } +} + +func (m *MetricsProvider) Run(ctx context.Context, myTomb *tomb.Tomb) error { + defer trace.CatchPanic("crowdsec/MetricsProvider.Run") + + if m.interval == time.Duration(0) { + return nil + } + + met := m.metricsPayload() + + ticker := time.NewTicker(1) // Send on start + + for { + select { + case <-ticker.C: + ctxTime, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + _, resp, err := m.apic.UsageMetrics.Add(ctxTime, met) + switch { + case errors.Is(err, context.DeadlineExceeded): + m.logger.Warnf("timeout sending lp metrics") + ticker.Reset(m.interval) + continue + case err != nil && resp != nil && resp.Response.StatusCode == http.StatusNotFound: + m.logger.Warnf("metrics endpoint not found, older LAPI?") + ticker.Reset(m.interval) + continue + case err != nil: + m.logger.Warnf("failed to send lp metrics: %s", err) + ticker.Reset(m.interval) + continue + } + + if resp.Response.StatusCode != http.StatusCreated { + m.logger.Warnf("failed to send lp metrics: %s", resp.Response.Status) + ticker.Reset(m.interval) + continue + } + + ticker.Reset(m.interval) + + m.logger.Tracef("lp usage metrics sent") + case <-myTomb.Dying(): + ticker.Stop() + return nil + } + } +} diff --git a/cmd/crowdsec/main.go b/cmd/crowdsec/main.go index 5c50884b5d9..6d8ca24c335 100644 --- a/cmd/crowdsec/main.go +++ b/cmd/crowdsec/main.go @@ -1,18 +1,22 @@ package main import ( + "errors" "flag" "fmt" _ "net/http/pprof" "os" + "path/filepath" "runtime" + "runtime/pprof" "strings" "time" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" + "github.com/crowdsecurity/go-cs-lib/trace" + "github.com/crowdsecurity/crowdsec/pkg/acquisition" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" @@ -25,28 +29,29 @@ import ( ) var ( - /*tombs for the parser, buckets and outputs.*/ - acquisTomb tomb.Tomb - parsersTomb tomb.Tomb - bucketsTomb tomb.Tomb - outputsTomb tomb.Tomb - apiTomb tomb.Tomb - 
crowdsecTomb tomb.Tomb - pluginTomb tomb.Tomb + // tombs for the parser, buckets and outputs. + acquisTomb tomb.Tomb + parsersTomb tomb.Tomb + bucketsTomb tomb.Tomb + outputsTomb tomb.Tomb + apiTomb tomb.Tomb + crowdsecTomb tomb.Tomb + pluginTomb tomb.Tomb + lpMetricsTomb tomb.Tomb flags *Flags - /*the state of acquisition*/ + // the state of acquisition dataSources []acquisition.DataSource - /*the state of the buckets*/ + // the state of the buckets holders []leakybucket.BucketFactory buckets *leakybucket.Buckets inputLineChan chan types.Event inputEventChan chan types.Event outputEventChan chan types.Event // the buckets init returns its own chan that is used for multiplexing - /*settings*/ - lastProcessedItem time.Time /*keep track of last item timestamp in time-machine. it is used to GC buckets when we dump them.*/ + // settings + lastProcessedItem time.Time // keep track of last item timestamp in time-machine. it is used to GC buckets when we dump them. pluginBroker csplugin.PluginBroker ) @@ -71,27 +76,32 @@ type Flags struct { DisableCAPI bool Transform string OrderEvent bool + CPUProfile string +} + +func (f *Flags) haveTimeMachine() bool { + return f.OneShotDSN != "" } type labelsMap map[string]string -func LoadBuckets(cConfig *csconfig.Config) error { +func LoadBuckets(cConfig *csconfig.Config, hub *cwhub.Hub) error { var ( err error files []string ) - for _, hubScenarioItem := range cwhub.GetItemMap(cwhub.SCENARIOS) { - if hubScenarioItem.Installed { - files = append(files, hubScenarioItem.LocalPath) - } + + for _, hubScenarioItem := range hub.GetInstalledByType(cwhub.SCENARIOS, false) { + files = append(files, hubScenarioItem.State.LocalPath) } + buckets = leakybucket.NewBuckets() log.Infof("Loading %d scenario files", len(files)) - holders, outputEventChan, err = leakybucket.LoadBuckets(cConfig.Crowdsec, files, &bucketsTomb, buckets, flags.OrderEvent) + holders, outputEventChan, err = leakybucket.LoadBuckets(cConfig.Crowdsec, hub, files, &bucketsTomb, buckets, flags.OrderEvent) if err != nil { - return fmt.Errorf("scenario loading failed: %v", err) + return fmt.Errorf("scenario loading failed: %w", err) } if cConfig.Prometheus != nil && cConfig.Prometheus.Enabled { @@ -99,10 +109,11 @@ func LoadBuckets(cConfig *csconfig.Config) error { holders[holderIndex].Profiling = true } } + return nil } -func LoadAcquisition(cConfig *csconfig.Config) error { +func LoadAcquisition(cConfig *csconfig.Config) ([]acquisition.DataSource, error) { var err error if flags.SingleFileType != "" && flags.OneShotDSN != "" { @@ -111,20 +122,20 @@ func LoadAcquisition(cConfig *csconfig.Config) error { dataSources, err = acquisition.LoadAcquisitionFromDSN(flags.OneShotDSN, flags.Labels, flags.Transform) if err != nil { - return errors.Wrapf(err, "failed to configure datasource for %s", flags.OneShotDSN) + return nil, fmt.Errorf("failed to configure datasource for %s: %w", flags.OneShotDSN, err) } } else { - dataSources, err = acquisition.LoadAcquisitionFromFile(cConfig.Crowdsec) + dataSources, err = acquisition.LoadAcquisitionFromFile(cConfig.Crowdsec, cConfig.Prometheus) if err != nil { - return err + return nil, err } } if len(dataSources) == 0 { - return fmt.Errorf("no datasource enabled") + return nil, errors.New("no datasource enabled") } - return nil + return dataSources, nil } var ( @@ -138,11 +149,15 @@ func (l *labelsMap) String() string { } func (l labelsMap) Set(label string) error { - split := strings.Split(label, ":") - if len(split) != 2 { - return errors.Wrapf(errors.New("Bad Format"), 
"for Label '%s'", label) + for _, pair := range strings.Split(label, ",") { + split := strings.Split(pair, ":") + if len(split) != 2 { + return fmt.Errorf("invalid format for label '%s', must be key:value", pair) + } + + l[split[0]] = split[1] } - l[split[0]] = split[1] + return nil } @@ -166,10 +181,13 @@ func (f *Flags) Parse() { flag.BoolVar(&f.DisableAPI, "no-api", false, "disable local API") flag.BoolVar(&f.DisableCAPI, "no-capi", false, "disable communication with Central API") flag.BoolVar(&f.OrderEvent, "order-event", false, "enforce event ordering with significant performance cost") + if runtime.GOOS == "windows" { flag.StringVar(&f.WinSvc, "winsvc", "", "Windows service Action: Install, Remove etc..") } + flag.StringVar(&dumpFolder, "dump-data", "", "dump parsers/buckets raw outputs") + flag.StringVar(&f.CPUProfile, "cpu-profile", "", "write cpu profile to file") flag.Parse() } @@ -203,6 +221,7 @@ func newLogLevel(curLevelPtr *log.Level, f *Flags) *log.Level { // avoid returning a new ptr to the same value return curLevelPtr } + return &ret } @@ -210,11 +229,11 @@ func newLogLevel(curLevelPtr *log.Level, f *Flags) *log.Level { func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet bool) (*csconfig.Config, error) { cConfig, _, err := csconfig.NewConfig(configFile, disableAgent, disableAPI, quiet) if err != nil { - return nil, err + return nil, fmt.Errorf("while loading configuration file: %w", err) } - if (cConfig.Common == nil || *cConfig.Common == csconfig.CommonCfg{}) { - return nil, fmt.Errorf("unable to load configuration: common section is empty") + if err := trace.Init(filepath.Join(cConfig.ConfigPaths.DataDir, "trace")); err != nil { + return nil, fmt.Errorf("while setting up trace directory: %w", err) } cConfig.Common.LogLevel = newLogLevel(cConfig.Common.LogLevel, flags) @@ -226,11 +245,6 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo dumpStates = true } - // Configuration paths are dependency to load crowdsec configuration - if err := cConfig.LoadConfigurationPaths(); err != nil { - return nil, err - } - if flags.SingleFileType != "" && flags.OneShotDSN != "" { // if we're in time-machine mode, we don't want to log to file cConfig.Common.LogMedia = "stdout" @@ -245,18 +259,25 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo return nil, err } + if cConfig.Common.LogMedia != "stdout" { + log.AddHook(&FatalHook{ + Writer: os.Stderr, + LogLevels: []log.Level{log.FatalLevel, log.PanicLevel}, + }) + } + if err := csconfig.LoadFeatureFlagsFile(configFile, log.StandardLogger()); err != nil { return nil, err } - if !flags.DisableAgent { + if !cConfig.DisableAgent { if err := cConfig.LoadCrowdsec(); err != nil { return nil, err } } - if !flags.DisableAPI { - if err := cConfig.LoadAPIServer(); err != nil { + if !cConfig.DisableAPI { + if err := cConfig.LoadAPIServer(false); err != nil { return nil, err } } @@ -266,11 +287,7 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo } if cConfig.DisableAPI && cConfig.DisableAgent { - return nil, errors.New("You must run at least the API Server or crowdsec") - } - - if flags.TestMode && !cConfig.DisableAgent { - cConfig.Crowdsec.LintOnly = true + return nil, errors.New("you must run at least the API Server or crowdsec") } if flags.OneShotDSN != "" && flags.SingleFileType == "" { @@ -289,10 +306,11 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo if cConfig.API != nil && 
cConfig.API.Server != nil { cConfig.API.Server.OnlineClient = nil } - /*if the api is disabled as well, just read file and exit, don't daemonize*/ - if flags.DisableAPI { + // if the api is disabled as well, just read file and exit, don't daemonize + if cConfig.DisableAPI { cConfig.Common.Daemonize = false } + log.Infof("single file mode : log_media=%s daemonize=%t", cConfig.Common.LogMedia, cConfig.Common.Daemonize) } @@ -302,6 +320,7 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo if cConfig.Common.Daemonize && runtime.GOOS == "windows" { log.Debug("Daemonization is not supported on Windows, disabling") + cConfig.Common.Daemonize = false } @@ -319,6 +338,10 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo var crowdsecT0 time.Time func main() { + // The initial log level is INFO, even if the user provided an -error or -warning flag + // because we need feature flags before parsing cli flags + log.SetFormatter(&log.TextFormatter{TimestampFormat: time.RFC3339, FullTimestamp: true}) + if err := fflag.RegisterAllFeatures(); err != nil { log.Fatalf("failed to register features: %s", err) } @@ -345,13 +368,32 @@ func main() { } if flags.PrintVersion { - cwversion.Show() + os.Stdout.WriteString(cwversion.FullString()) os.Exit(0) } + if flags.CPUProfile != "" { + f, err := os.Create(flags.CPUProfile) + if err != nil { + log.Fatalf("could not create CPU profile: %s", err) + } + + log.Infof("CPU profile will be written to %s", flags.CPUProfile) + + if err := pprof.StartCPUProfile(f); err != nil { + f.Close() + log.Fatalf("could not start CPU profile: %s", err) + } + + defer f.Close() + defer pprof.StopCPUProfile() + } + err := StartRunSvc() if err != nil { - log.Fatal(err) + pprof.StopCPUProfile() + log.Fatal(err) //nolint:gocritic // Disable warning for the defer pprof.StopCPUProfile() call } + os.Exit(0) } diff --git a/cmd/crowdsec/metrics.go b/cmd/crowdsec/metrics.go index 8e87eecd037..ff280fc3512 100644 --- a/cmd/crowdsec/metrics.go +++ b/cmd/crowdsec/metrics.go @@ -3,16 +3,16 @@ package main import ( "fmt" "net/http" - "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" - "github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/crowdsecurity/go-cs-lib/trace" + "github.com/crowdsecurity/go-cs-lib/version" - v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers/v1" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers/v1" "github.com/crowdsecurity/crowdsec/pkg/cache" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" @@ -21,7 +21,8 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/parser" ) -/*prometheus*/ +// Prometheus + var globalParserHits = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "cs_parser_hits_total", @@ -29,6 +30,7 @@ var globalParserHits = prometheus.NewCounterVec( }, []string{"source", "type"}, ) + var globalParserHitsOk = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "cs_parser_hits_ok_total", @@ -36,6 +38,7 @@ var globalParserHitsOk = prometheus.NewCounterVec( }, []string{"source", "type"}, ) + var globalParserHitsKo = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "cs_parser_hits_ko_total", @@ -102,25 +105,31 @@ var globalPourHistogram = prometheus.NewHistogramVec( func computeDynamicMetrics(next 
http.Handler, dbClient *database.Client) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - //update cache metrics (stash) + // catch panics here because they are not handled by servePrometheus + defer trace.CatchPanic("crowdsec/computeDynamicMetrics") + // update cache metrics (stash) cache.UpdateCacheMetrics() - //update cache metrics (regexp) + // update cache metrics (regexp) exprhelpers.UpdateRegexpCacheMetrics() - //decision metrics are only relevant for LAPI + // decision metrics are only relevant for LAPI if dbClient == nil { next.ServeHTTP(w, r) return } - decisionsFilters := make(map[string][]string, 0) - decisions, err := dbClient.QueryDecisionCountByScenario(decisionsFilters) + ctx := r.Context() + + decisions, err := dbClient.QueryDecisionCountByScenario(ctx) if err != nil { log.Errorf("Error querying decisions for metrics: %v", err) next.ServeHTTP(w, r) + return } + globalActiveDecisions.Reset() + for _, d := range decisions { globalActiveDecisions.With(prometheus.Labels{"reason": d.Scenario, "origin": d.Origin, "action": d.Type}).Set(float64(d.Count)) } @@ -131,11 +140,11 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha "include_capi": {"false"}, } - alerts, err := dbClient.AlertsCountPerScenario(alertsFilter) - + alerts, err := dbClient.AlertsCountPerScenario(ctx, alertsFilter) if err != nil { log.Errorf("Error querying alerts for metrics: %v", err) next.ServeHTTP(w, r) + return } @@ -151,25 +160,18 @@ func registerPrometheus(config *csconfig.PrometheusCfg) { if !config.Enabled { return } - if config.ListenAddr == "" { - log.Warning("prometheus is enabled, but the listen address is empty, using '127.0.0.1'") - config.ListenAddr = "127.0.0.1" - } - if config.ListenPort == 0 { - log.Warning("prometheus is enabled, but the listen port is empty, using '6060'") - config.ListenPort = 6060 - } // Registering prometheus // If in aggregated mode, do not register events associated with a source, to keep the cardinality low - if config.Level == "aggregated" { + if config.Level == configuration.CFG_METRICS_AGGREGATE { log.Infof("Loading aggregated prometheus collectors") prometheus.MustRegister(globalParserHits, globalParserHitsOk, globalParserHitsKo, globalCsInfo, globalParsingHistogram, globalPourHistogram, leaky.BucketsUnderflow, leaky.BucketsCanceled, leaky.BucketsInstantiation, leaky.BucketsOverflow, v1.LapiRouteHits, leaky.BucketsCurrentCount, - cache.CacheMetrics, exprhelpers.RegexpCacheMetrics) + cache.CacheMetrics, exprhelpers.RegexpCacheMetrics, parser.NodesWlHitsOk, parser.NodesWlHits, + ) } else { log.Infof("Loading prometheus collectors") prometheus.MustRegister(globalParserHits, globalParserHitsOk, globalParserHitsKo, @@ -177,13 +179,15 @@ func registerPrometheus(config *csconfig.PrometheusCfg) { globalCsInfo, globalParsingHistogram, globalPourHistogram, v1.LapiRouteHits, v1.LapiMachineHits, v1.LapiBouncerHits, v1.LapiNilDecisions, v1.LapiNonNilDecisions, v1.LapiResponseTime, leaky.BucketsPour, leaky.BucketsUnderflow, leaky.BucketsCanceled, leaky.BucketsInstantiation, leaky.BucketsOverflow, leaky.BucketsCurrentCount, - globalActiveDecisions, globalAlerts, - cache.CacheMetrics, exprhelpers.RegexpCacheMetrics) - + globalActiveDecisions, globalAlerts, parser.NodesWlHitsOk, parser.NodesWlHits, + cache.CacheMetrics, exprhelpers.RegexpCacheMetrics, + ) } } -func servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client, apiReady chan bool, agentReady chan bool) { +func 
servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client, agentReady chan bool) { + <-agentReady + if !config.Enabled { return } @@ -191,10 +195,11 @@ func servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client, defer trace.CatchPanic("crowdsec/servePrometheus") http.Handle("/metrics", computeDynamicMetrics(promhttp.Handler(), dbClient)) - <-apiReady - <-agentReady - log.Debugf("serving metrics after %s ms", time.Since(crowdsecT0)) + if err := http.ListenAndServe(fmt.Sprintf("%s:%d", config.ListenAddr, config.ListenPort), nil); err != nil { - log.Warningf("prometheus: %s", err) + // in time machine, we most likely have the LAPI using the port + if !flags.haveTimeMachine() { + log.Warningf("prometheus: %s", err) + } } } diff --git a/cmd/crowdsec/output.go b/cmd/crowdsec/output.go index 933706a211a..6f507fdcd6f 100644 --- a/cmd/crowdsec/output.go +++ b/cmd/crowdsec/output.go @@ -3,18 +3,12 @@ package main import ( "context" "fmt" - "net/url" "sync" "time" - "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/go-cs-lib/pkg/version" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/parser" @@ -22,7 +16,6 @@ import ( ) func dedupAlerts(alerts []types.RuntimeAlert) ([]*models.Alert, error) { - var dedupCache []*models.Alert for idx, alert := range alerts { @@ -32,90 +25,51 @@ func dedupAlerts(alerts []types.RuntimeAlert) ([]*models.Alert, error) { dedupCache = append(dedupCache, alert.Alert) continue } - for k, src := range alert.Sources { - refsrc := *alert.Alert //copy + + for k := range alert.Sources { + refsrc := *alert.Alert // copy + log.Tracef("source[%s]", k) + + src := alert.Sources[k] refsrc.Source = &src dedupCache = append(dedupCache, &refsrc) } } + if len(dedupCache) != len(alerts) { log.Tracef("went from %d to %d alerts", len(alerts), len(dedupCache)) } + return dedupCache, nil } func PushAlerts(alerts []types.RuntimeAlert, client *apiclient.ApiClient) error { ctx := context.Background() - alertsToPush, err := dedupAlerts(alerts) + alertsToPush, err := dedupAlerts(alerts) if err != nil { return fmt.Errorf("failed to transform alerts for api: %w", err) } + _, _, err = client.Alerts.Add(ctx, alertsToPush) if err != nil { return fmt.Errorf("failed sending alert to LAPI: %w", err) } + return nil } var bucketOverflows []types.Event -func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky.Buckets, - postOverflowCTX parser.UnixParserCtx, postOverflowNodes []parser.Node, apiConfig csconfig.ApiCredentialsCfg) error { +func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky.Buckets, postOverflowCTX parser.UnixParserCtx, + postOverflowNodes []parser.Node, client *apiclient.ApiClient) error { + var ( + cache []types.RuntimeAlert + cacheMutex sync.Mutex + ) - var err error ticker := time.NewTicker(1 * time.Second) - - var cache []types.RuntimeAlert - var cacheMutex sync.Mutex - - scenarios, err := cwhub.GetInstalledScenariosAsString() - if err != nil { - return fmt.Errorf("loading list of installed hub scenarios: %w", err) - } - - apiURL, err := url.Parse(apiConfig.URL) - if err != nil { - return fmt.Errorf("parsing api url ('%s'): %w", apiConfig.URL, err) - } - papiURL, err := url.Parse(apiConfig.PapiURL) - if err != nil { - 
return fmt.Errorf("parsing polling api url ('%s'): %w", apiConfig.PapiURL, err) - } - password := strfmt.Password(apiConfig.Password) - - Client, err := apiclient.NewClient(&apiclient.Config{ - MachineID: apiConfig.Login, - Password: password, - Scenarios: scenarios, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiURL, - PapiURL: papiURL, - VersionPrefix: "v1", - UpdateScenario: cwhub.GetInstalledScenariosAsString, - }) - if err != nil { - return fmt.Errorf("new client api: %w", err) - } - authResp, _, err := Client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ - MachineID: &apiConfig.Login, - Password: &password, - Scenarios: scenarios, - }) - if err != nil { - return fmt.Errorf("authenticate watcher (%s): %w", apiConfig.Login, err) - } - - if err := Client.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { - return fmt.Errorf("unable to parse jwt expiration: %w", err) - } - - Client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token - - //start the heartbeat service - log.Debugf("Starting HeartBeat service") - Client.HeartBeat.StartHeartBeat(context.Background(), &outputsTomb) LOOP: for { select { @@ -126,9 +80,9 @@ LOOP: newcache := make([]types.RuntimeAlert, 0) cache = newcache cacheMutex.Unlock() - if err := PushAlerts(cachecopy, Client); err != nil { + if err := PushAlerts(cachecopy, client); err != nil { log.Errorf("while pushing to api : %s", err) - //just push back the events to the queue + // just push back the events to the queue cacheMutex.Lock() cache = append(cache, cachecopy...) cacheMutex.Unlock() @@ -139,19 +93,13 @@ LOOP: cacheMutex.Lock() cachecopy := cache cacheMutex.Unlock() - if err := PushAlerts(cachecopy, Client); err != nil { + if err := PushAlerts(cachecopy, client); err != nil { log.Errorf("while pushing leftovers to api : %s", err) } } + break LOOP case event := <-overflow: - //if the Alert is nil, it's to signal bucket is ready for GC, don't track this - if dumpStates && event.Overflow.Alert != nil { - if bucketOverflows == nil { - bucketOverflows = make([]types.Event, 0) - } - bucketOverflows = append(bucketOverflows, event) - } /*if alert is empty and mapKey is present, the overflow is just to cleanup bucket*/ if event.Overflow.Alert == nil && event.Overflow.Mapkey != "" { buckets.Bucket_map.Delete(event.Overflow.Mapkey) @@ -160,9 +108,17 @@ LOOP: /* process post overflow parser nodes */ event, err := parser.Parse(postOverflowCTX, event, postOverflowNodes) if err != nil { - return fmt.Errorf("postoverflow failed : %s", err) + return fmt.Errorf("postoverflow failed: %w", err) } log.Printf("%s", *event.Overflow.Alert.Message) + // if the Alert is nil, it's to signal bucket is ready for GC, don't track this + // dump after postoveflow processing to avoid missing whitelist info + if dumpStates && event.Overflow.Alert != nil { + if bucketOverflows == nil { + bucketOverflows = make([]types.Event, 0) + } + bucketOverflows = append(bucketOverflows, event) + } if event.Overflow.Whitelisted { log.Printf("[%s] is whitelisted, skip.", *event.Overflow.Alert.Message) continue @@ -182,6 +138,6 @@ LOOP: } ticker.Stop() - return nil + return nil } diff --git a/cmd/crowdsec/parse.go b/cmd/crowdsec/parse.go index aa93b6ec7f9..26eae66be2b 100644 --- a/cmd/crowdsec/parse.go +++ b/cmd/crowdsec/parse.go @@ -11,17 +11,22 @@ import ( ) func runParse(input chan types.Event, output chan types.Event, parserCTX parser.UnixParserCtx, nodes []parser.Node) error 
{ - -LOOP: for { select { case <-parsersTomb.Dying(): log.Infof("Killing parser routines") - break LOOP + return nil case event := <-input: if !event.Process { continue } + /*Application security engine is going to generate 2 events: + - one that is treated as a log and can go to scenarios + - another one that will go directly to LAPI*/ + if event.Type == types.APPSEC { + outputEventChan <- event + continue + } if event.Line.Module == "" { log.Errorf("empty event.Line.Module field, the acquisition module must set it ! : %+v", event.Line) continue @@ -32,7 +37,7 @@ LOOP: /* parse the log using magic */ parsed, err := parser.Parse(parserCTX, event, nodes) if err != nil { - log.Errorf("failed parsing : %v\n", err) + log.Errorf("failed parsing: %v", err) } elapsed := time.Since(startParsing) globalParsingHistogram.With(prometheus.Labels{"source": event.Line.Src, "type": event.Line.Module}).Observe(elapsed.Seconds()) @@ -49,5 +54,4 @@ LOOP: output <- parsed } } - return nil } diff --git a/cmd/crowdsec/pour.go b/cmd/crowdsec/pour.go index 3f717e3975d..2fc7d7e42c9 100644 --- a/cmd/crowdsec/pour.go +++ b/cmd/crowdsec/pour.go @@ -4,57 +4,64 @@ import ( "fmt" "time" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" ) func runPour(input chan types.Event, holders []leaky.BucketFactory, buckets *leaky.Buckets, cConfig *csconfig.Config) error { count := 0 + for { - //bucket is now ready + // bucket is now ready select { case <-bucketsTomb.Dying(): log.Infof("Bucket routine exiting") return nil case parsed := <-input: startTime := time.Now() + count++ if count%5000 == 0 { log.Infof("%d existing buckets", leaky.LeakyRoutineCount) - //when in forensics mode, garbage collect buckets + // when in forensics mode, garbage collect buckets if cConfig.Crowdsec.BucketsGCEnabled { if parsed.MarshaledTime != "" { z := &time.Time{} if err := z.UnmarshalText([]byte(parsed.MarshaledTime)); err != nil { - log.Warningf("Failed to unmarshal time from event '%s' : %s", parsed.MarshaledTime, err) + log.Warningf("Failed to parse time from event '%s' : %s", parsed.MarshaledTime, err) } else { log.Warning("Starting buckets garbage collection ...") + if err = leaky.GarbageCollectBuckets(*z, buckets); err != nil { - return fmt.Errorf("failed to start bucket GC : %s", err) + return fmt.Errorf("failed to start bucket GC : %w", err) } } } } } - //here we can bucketify with parsed + // here we can bucketify with parsed poured, err := leaky.PourItemToHolders(parsed, holders, buckets) if err != nil { log.Errorf("bucketify failed for: %v", parsed) continue } + elapsed := time.Since(startTime) globalPourHistogram.With(prometheus.Labels{"type": parsed.Line.Module, "source": parsed.Line.Src}).Observe(elapsed.Seconds()) + if poured { globalBucketPourOk.Inc() } else { globalBucketPourKo.Inc() } - if len(parsed.MarshaledTime) != 0 { + + if parsed.MarshaledTime != "" { if err := lastProcessedItem.UnmarshalText([]byte(parsed.MarshaledTime)); err != nil { - log.Warningf("failed to unmarshal time from event : %s", err) + log.Warningf("failed to parse time from event : %s", err) } } } diff --git a/cmd/crowdsec/run_in_svc.go b/cmd/crowdsec/run_in_svc.go index a5ab996b5dc..288b565e890 100644 --- a/cmd/crowdsec/run_in_svc.go +++ b/cmd/crowdsec/run_in_svc.go @@ -1,17 +1,16 @@ 
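Context for the pour loop above: event timestamps travel as strings in parsed.MarshaledTime, and time.Time.UnmarshalText parses them as RFC 3339. A minimal standalone sketch of that step (the sample value and variable names are illustrative, not taken from the patch):

package main

import (
	"fmt"
	"time"
)

func main() {
	marshaled := "2024-01-02T15:04:05Z" // stand-in for parsed.MarshaledTime

	z := &time.Time{}
	if err := z.UnmarshalText([]byte(marshaled)); err != nil {
		// same failure mode the pour loop logs as a warning
		fmt.Printf("Failed to parse time from event '%s' : %s\n", marshaled, err)
		return
	}

	fmt.Println("event time:", z.Format(time.RFC3339))
}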
-//go:build linux || freebsd || netbsd || openbsd || solaris || !windows -// +build linux freebsd netbsd openbsd solaris !windows +//go:build !windows package main import ( + "context" "fmt" - "os" + "runtime/pprof" log "github.com/sirupsen/logrus" - "github.com/sirupsen/logrus/hooks/writer" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" - "github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/crowdsecurity/go-cs-lib/trace" + "github.com/crowdsecurity/go-cs-lib/version" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" @@ -25,15 +24,9 @@ func StartRunSvc() error { defer trace.CatchPanic("crowdsec/StartRunSvc") - // Set a default logger with level=fatal on stderr, - // in addition to the one we configure afterwards - log.AddHook(&writer.Hook{ - Writer: os.Stderr, - LogLevels: []log.Level{ - log.PanicLevel, - log.FatalLevel, - }, - }) + // Always try to stop CPU profiling to avoid passing flags around + // It's a noop if profiling is not enabled + defer pprof.StopCPUProfile() if cConfig, err = LoadConfig(flags.ConfigFile, flags.DisableAgent, flags.DisableAPI, false); err != nil { return err @@ -41,23 +34,32 @@ func StartRunSvc() error { log.Infof("Crowdsec %s", version.String()) - apiReady := make(chan bool, 1) agentReady := make(chan bool, 1) // Enable profiling early if cConfig.Prometheus != nil { var dbClient *database.Client + var err error - if cConfig.DbConfig != nil { - dbClient, err = database.NewClient(cConfig.DbConfig) + ctx := context.TODO() + if cConfig.DbConfig != nil { + dbClient, err = database.NewClient(ctx, cConfig.DbConfig) if err != nil { - return fmt.Errorf("unable to create database client: %s", err) + return fmt.Errorf("unable to create database client: %w", err) } } + registerPrometheus(cConfig.Prometheus) - go servePrometheus(cConfig.Prometheus, dbClient, apiReady, agentReady) + + go servePrometheus(cConfig.Prometheus, dbClient, agentReady) + } else { + // avoid leaking the channel + go func() { + <-agentReady + }() } - return Serve(cConfig, apiReady, agentReady) + + return Serve(cConfig, agentReady) } diff --git a/cmd/crowdsec/run_in_svc_windows.go b/cmd/crowdsec/run_in_svc_windows.go index d63a587ac16..a2a2dd8c47a 100644 --- a/cmd/crowdsec/run_in_svc_windows.go +++ b/cmd/crowdsec/run_in_svc_windows.go @@ -1,13 +1,15 @@ package main import ( + "context" "fmt" + "runtime/pprof" log "github.com/sirupsen/logrus" "golang.org/x/sys/windows/svc" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" - "github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/crowdsecurity/go-cs-lib/trace" + "github.com/crowdsecurity/go-cs-lib/version" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" @@ -19,6 +21,10 @@ func StartRunSvc() error { defer trace.CatchPanic("crowdsec/StartRunSvc") + // Always try to stop CPU profiling to avoid passing flags around + // It's a noop if profiling is not enabled + defer pprof.StopCPUProfile() + isRunninginService, err := svc.IsWindowsService() if err != nil { return fmt.Errorf("failed to determine if we are running in windows service mode: %w", err) @@ -68,7 +74,6 @@ func WindowsRun() error { log.Infof("Crowdsec %s", version.String()) - apiReady := make(chan bool, 1) agentReady := make(chan bool, 1) // Enable profiling early @@ -76,15 +81,17 @@ func WindowsRun() error { var dbClient *database.Client var err error + ctx := context.TODO() + if cConfig.DbConfig != nil { - dbClient, err = database.NewClient(cConfig.DbConfig) + dbClient, err = 
database.NewClient(ctx, cConfig.DbConfig) if err != nil { - return fmt.Errorf("unable to create database client: %s", err) + return fmt.Errorf("unable to create database client: %w", err) } } registerPrometheus(cConfig.Prometheus) - go servePrometheus(cConfig.Prometheus, dbClient, apiReady, agentReady) + go servePrometheus(cConfig.Prometheus, dbClient, agentReady) } - return Serve(cConfig, apiReady, agentReady) + return Serve(cConfig, agentReady) } diff --git a/cmd/crowdsec/serve.go b/cmd/crowdsec/serve.go index 2efb14613fb..14602c425fe 100644 --- a/cmd/crowdsec/serve.go +++ b/cmd/crowdsec/serve.go @@ -1,19 +1,22 @@ package main import ( + "context" "fmt" "os" "os/signal" + "runtime/pprof" "syscall" "time" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/csdaemon" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/csdaemon" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" @@ -40,13 +43,17 @@ func debugHandler(sig os.Signal, cConfig *csconfig.Config) error { if err := leaky.ShutdownAllBuckets(buckets); err != nil { log.Warningf("Failed to shut down routines : %s", err) } + log.Printf("Shutdown is finished, buckets are in %s", tmpFile) + return nil } func reloadHandler(sig os.Signal) (*csconfig.Config, error) { var tmpFile string + ctx := context.TODO() + // re-initialize tombs acquisTomb = tomb.Tomb{} parsersTomb = tomb.Tomb{} @@ -55,6 +62,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { apiTomb = tomb.Tomb{} crowdsecTomb = tomb.Tomb{} pluginTomb = tomb.Tomb{} + lpMetricsTomb = tomb.Tomb{} cConfig, err := LoadConfig(flags.ConfigFile, flags.DisableAgent, flags.DisableAPI, false) if err != nil { @@ -64,19 +72,29 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { if !cConfig.DisableAPI { if flags.DisableCAPI { log.Warningf("Communication with CrowdSec Central API disabled from args") + cConfig.API.Server.OnlineClient = nil } - apiServer, err := initAPIServer(cConfig) + + apiServer, err := initAPIServer(ctx, cConfig) if err != nil { return nil, fmt.Errorf("unable to init api server: %w", err) } - apiReady := make(chan bool, 1) - serveAPIServer(apiServer, apiReady) + serveAPIServer(apiServer) } if !cConfig.DisableAgent { - csParsers, err := initCrowdsec(cConfig) + hub, err := cwhub.NewHub(cConfig.Hub, nil, log.StandardLogger()) + if err != nil { + return nil, err + } + + if err = hub.Load(); err != nil { + return nil, err + } + + csParsers, datasources, err := initCrowdsec(cConfig, hub) if err != nil { return nil, fmt.Errorf("unable to init crowdsec: %w", err) } @@ -93,7 +111,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { } agentReady := make(chan bool, 1) - serveCrowdsec(csParsers, cConfig, agentReady) + serveCrowdsec(csParsers, cConfig, hub, datasources, agentReady) } log.Printf("Reload is finished") @@ -103,6 +121,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { log.Warningf("Failed to delete temp file (%s) : %s", tmpFile, err) } } + return cConfig, nil } @@ -110,10 +129,12 @@ func ShutdownCrowdsecRoutines() error { var reterr error log.Debugf("Shutting down crowdsec sub-routines") + if len(dataSources) > 0 { acquisTomb.Kill(nil) log.Debugf("waiting for acquisition to finish") 
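The hunk just below replaces a blocking outputsTomb.Wait() with a wait that is bounded by a timeout, so outputs stuck in an HTTP retry loop cannot stall shutdown. A minimal sketch of that wait-with-timeout pattern using gopkg.in/tomb.v2, as the surrounding code does (the worker body and durations are illustrative):

package main

import (
	"errors"
	"fmt"
	"time"

	"gopkg.in/tomb.v2"
)

func main() {
	var t tomb.Tomb

	t.Go(func() error {
		<-t.Dying()
		time.Sleep(5 * time.Second) // simulate outputs that are slow to stop
		return errors.New("slow exit")
	})

	t.Kill(nil)

	// run Wait() in a goroutine and race it against a timer
	done := make(chan error, 1)
	go func() { done <- t.Wait() }()

	select {
	case err := <-done:
		fmt.Println("outputs finished:", err)
	case <-time.After(3 * time.Second):
		fmt.Println("outputs didn't finish in time, giving up")
	}
}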
drainChan(inputLineChan) + if err := acquisTomb.Wait(); err != nil { log.Warningf("Acquisition returned error : %s", err) reterr = err @@ -123,6 +144,7 @@ log.Debugf("acquisition is finished, wait for parser/bucket/ouputs.") parsersTomb.Kill(nil) drainChan(inputEventChan) + if err := parsersTomb.Wait(); err != nil { log.Warningf("Parsers returned error : %s", err) reterr = err @@ -141,15 +163,40 @@ time.Sleep(1 * time.Second) // ugly workaround for now outputsTomb.Kill(nil) - if err := outputsTomb.Wait(); err != nil { - log.Warningf("Ouputs returned error : %s", err) + done := make(chan error, 1) + go func() { + done <- outputsTomb.Wait() + }() + + // wait for outputs to finish, max 3 seconds + select { + case err := <-done: + if err != nil { + log.Warningf("Outputs returned error : %s", err) + reterr = err + } + + log.Debugf("outputs are done") + case <-time.After(3 * time.Second): + // this can happen if outputs are stuck in a http retry loop + log.Warningf("Outputs didn't finish in time, some events may not have been flushed") + } + + lpMetricsTomb.Kill(nil) + + if err := lpMetricsTomb.Wait(); err != nil { + log.Warningf("Metrics returned error : %s", err) reterr = err } - log.Debugf("outputs are done") + log.Debugf("metrics are done") + // He's dead, Jim. crowdsecTomb.Kill(nil) + // close the potential GeoIP readers we have, to avoid leaking resources on reload + exprhelpers.GeoIPClose() + return reterr } @@ -162,6 +209,7 @@ func shutdownAPI() error { } log.Debugf("done") + return nil } @@ -174,6 +222,7 @@ func shutdownCrowdsec() error { } log.Debugf("done") + return nil } @@ -201,7 +250,7 @@ func drainChan(c chan types.Event) { for { select { case _, ok := <-c: - if !ok { //closed + if !ok { // closed return } default: @@ -227,6 +276,10 @@ func HandleSignals(cConfig *csconfig.Config) error { exitChan := make(chan error) + // Always try to stop CPU profiling to avoid passing flags around + // It's a noop if profiling is not enabled + defer pprof.StopCPUProfile() + go func() { defer trace.CatchPanic("crowdsec/HandleSignals") Loop: @@ -269,10 +322,11 @@ func HandleSignals(cConfig *csconfig.Config) error { if err == nil { log.Warning("Crowdsec service shutting down") } + return err } -func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) error { +func Serve(cConfig *csconfig.Config, agentReady chan bool) error { acquisTomb = tomb.Tomb{} parsersTomb = tomb.Tomb{} bucketsTomb = tomb.Tomb{} @@ -280,9 +334,12 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e apiTomb = tomb.Tomb{} crowdsecTomb = tomb.Tomb{} pluginTomb = tomb.Tomb{} + lpMetricsTomb = tomb.Tomb{} + + ctx := context.TODO() if cConfig.API.Server != nil && cConfig.API.Server.DbConfig != nil { - dbClient, err := database.NewClient(cConfig.API.Server.DbConfig) + dbClient, err := database.NewClient(ctx, cConfig.API.Server.DbConfig) if err != nil { return fmt.Errorf("failed to get database client: %w", err) } @@ -300,8 +357,9 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e log.Warningln("Exprhelpers loaded without database client.") } - if cConfig.API.CTI != nil && *cConfig.API.CTI.Enabled { + if cConfig.API.CTI != nil && cConfig.API.CTI.Enabled != nil && *cConfig.API.CTI.Enabled { log.Infof("Crowdsec CTI helper enabled") + if err := exprhelpers.InitCrowdsecCTI(cConfig.API.CTI.Key, cConfig.API.CTI.CacheTimeout, cConfig.API.CTI.CacheSize, cConfig.API.CTI.LogLevel); 
err != nil { return fmt.Errorf("failed to init crowdsec cti: %w", err) } @@ -314,30 +372,40 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e if flags.DisableCAPI { log.Warningf("Communication with CrowdSec Central API disabled from args") + cConfig.API.Server.OnlineClient = nil } - apiServer, err := initAPIServer(cConfig) + apiServer, err := initAPIServer(ctx, cConfig) if err != nil { return fmt.Errorf("api server init: %w", err) } if !flags.TestMode { - serveAPIServer(apiServer, apiReady) + serveAPIServer(apiServer) } - } else { - apiReady <- true } if !cConfig.DisableAgent { - csParsers, err := initCrowdsec(cConfig) + hub, err := cwhub.NewHub(cConfig.Hub, nil, log.StandardLogger()) + if err != nil { + return err + } + + if err = hub.Load(); err != nil { + return err + } + + csParsers, datasources, err := initCrowdsec(cConfig, hub) if err != nil { return fmt.Errorf("crowdsec init: %w", err) } // if it's just linting, we're done if !flags.TestMode { - serveCrowdsec(csParsers, cConfig, agentReady) + serveCrowdsec(csParsers, cConfig, hub, datasources, agentReady) + } else { + agentReady <- true } } else { agentReady <- true @@ -346,11 +414,12 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e if flags.TestMode { log.Infof("Configuration test done") pluginBroker.Kill() - os.Exit(0) + + return nil } if cConfig.Common != nil && cConfig.Common.Daemonize { - csdaemon.NotifySystemd(log.StandardLogger()) + csdaemon.Notify(csdaemon.Ready, log.StandardLogger()) // wait for signals return HandleSignals(cConfig) } @@ -367,6 +436,7 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e for _, ch := range waitChans { <-ch + switch ch { case apiTomb.Dead(): log.Infof("api shutdown") @@ -374,5 +444,6 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e log.Infof("crowdsec shutdown") } } + return nil } diff --git a/cmd/crowdsec/win_service.go b/cmd/crowdsec/win_service.go index ab9ecc8151f..6aa363ca3a7 100644 --- a/cmd/crowdsec/win_service.go +++ b/cmd/crowdsec/win_service.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build windows -// +build windows package main @@ -24,7 +23,7 @@ type crowdsec_winservice struct { config *csconfig.Config } -func (m *crowdsec_winservice) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) { +func (m *crowdsec_winservice) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) { const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown changes <- svc.Status{State: svc.StartPending} tick := time.Tick(500 * time.Millisecond) @@ -60,7 +59,8 @@ func (m *crowdsec_winservice) Execute(args []string, r <-chan svc.ChangeRequest, if err != nil { log.Fatal(err) } - return + + return false, 0 } func runService(name string) error { diff --git a/cmd/crowdsec/win_service_install.go b/cmd/crowdsec/win_service_install.go index b497bc93182..85b35420264 100644 --- a/cmd/crowdsec/win_service_install.go +++ b/cmd/crowdsec/win_service_install.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. 
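The Windows service files in this change drop the legacy "// +build" constraint lines and keep only "//go:build", which has been the canonical form since Go 1.17 (gofmt keeps the two variants in sync when both are present). For reference, the pre-1.17 dual form next to the single line kept by this change:

// Before Go 1.17, both lines were needed at the top of the file:
//
//	//go:build windows
//	// +build windows

//go:build windows

package main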
//go:build windows -// +build windows package main diff --git a/cmd/crowdsec/win_service_manage.go b/cmd/crowdsec/win_service_manage.go index 4aa11527350..4e31dc019af 100644 --- a/cmd/crowdsec/win_service_manage.go +++ b/cmd/crowdsec/win_service_manage.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build windows +//go:build windows package main diff --git a/plugins/notifications/http/Makefile b/cmd/notification-dummy/Makefile similarity index 53% rename from plugins/notifications/http/Makefile rename to cmd/notification-dummy/Makefile index 56f49077262..251abe19df0 100644 --- a/plugins/notifications/http/Makefile +++ b/cmd/notification-dummy/Makefile @@ -4,14 +4,13 @@ ifeq ($(OS), Windows_NT) EXT = .exe endif -PLUGIN=http -BINARY_NAME = notification-$(PLUGIN)$(EXT) +GO = go +GOBUILD = $(GO) build -GOCMD = go -GOBUILD = $(GOCMD) build +BINARY_NAME = notification-dummy$(EXT) build: clean - $(GOBUILD) $(LD_OPTS) $(BUILD_VENDOR_FLAGS) -o $(BINARY_NAME) + $(GOBUILD) $(LD_OPTS) -o $(BINARY_NAME) .PHONY: clean clean: diff --git a/plugins/notifications/dummy/dummy.yaml b/cmd/notification-dummy/dummy.yaml similarity index 100% rename from plugins/notifications/dummy/dummy.yaml rename to cmd/notification-dummy/dummy.yaml diff --git a/plugins/notifications/dummy/main.go b/cmd/notification-dummy/main.go similarity index 91% rename from plugins/notifications/dummy/main.go rename to cmd/notification-dummy/main.go index ef8d29ffa44..7fbb10d4fca 100644 --- a/plugins/notifications/dummy/main.go +++ b/cmd/notification-dummy/main.go @@ -5,10 +5,12 @@ import ( "fmt" "os" - "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) type PluginConfig struct { @@ -18,6 +20,7 @@ type PluginConfig struct { } type DummyPlugin struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig } @@ -32,6 +35,7 @@ func (s *DummyPlugin) Notify(ctx context.Context, notification *protobufs.Notifi if _, ok := s.PluginConfigByName[notification.Name]; !ok { return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) } + cfg := s.PluginConfigByName[notification.Name] if cfg.LogLevel != nil && *cfg.LogLevel != "" { @@ -42,19 +46,22 @@ func (s *DummyPlugin) Notify(ctx context.Context, notification *protobufs.Notifi logger.Debug(notification.Text) if cfg.OutputFile != nil && *cfg.OutputFile != "" { - f, err := os.OpenFile(*cfg.OutputFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + f, err := os.OpenFile(*cfg.OutputFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { logger.Error(fmt.Sprintf("Cannot open notification file: %s", err)) } + if _, err := f.WriteString(notification.Text + "\n"); err != nil { f.Close() logger.Error(fmt.Sprintf("Cannot write notification to file: %s", err)) } + err = f.Close() if err != nil { logger.Error(fmt.Sprintf("Cannot close notification file: %s", err)) } } + fmt.Println(notification.Text) return &protobufs.Empty{}, nil @@ -64,11 +71,12 @@ func (s *DummyPlugin) Configure(ctx context.Context, config *protobufs.Config) ( d := PluginConfig{} err := yaml.Unmarshal(config.Config, &d) s.PluginConfigByName[d.Name] = d + return &protobufs.Empty{}, err } func main() { - var handshake = plugin.HandshakeConfig{ + handshake := plugin.HandshakeConfig{ ProtocolVersion: 
1, MagicCookieKey: "CROWDSEC_PLUGIN_KEY", MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), @@ -78,7 +86,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "dummy": &protobufs.NotifierPlugin{ + "dummy": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/plugins/notifications/slack/Makefile b/cmd/notification-email/Makefile similarity index 53% rename from plugins/notifications/slack/Makefile rename to cmd/notification-email/Makefile index f43303eb882..7a782cc9db1 100644 --- a/plugins/notifications/slack/Makefile +++ b/cmd/notification-email/Makefile @@ -4,14 +4,13 @@ ifeq ($(OS), Windows_NT) EXT = .exe endif -PLUGIN=slack -BINARY_NAME = notification-$(PLUGIN)$(EXT) +GO = go +GOBUILD = $(GO) build -GOCMD = go -GOBUILD = $(GOCMD) build +BINARY_NAME = notification-email$(EXT) build: clean - $(GOBUILD) $(LD_OPTS) $(BUILD_VENDOR_FLAGS) -o $(BINARY_NAME) + $(GOBUILD) $(LD_OPTS) -o $(BINARY_NAME) .PHONY: clean clean: diff --git a/plugins/notifications/email/email.yaml b/cmd/notification-email/email.yaml similarity index 72% rename from plugins/notifications/email/email.yaml rename to cmd/notification-email/email.yaml index 37ce2a98209..512633c6380 100644 --- a/plugins/notifications/email/email.yaml +++ b/cmd/notification-email/email.yaml @@ -15,12 +15,14 @@ timeout: 20s # Time to wait for response from the plugin before conside # The following template receives a list of models.Alert objects # The output goes in the email message body format: | + {{range . -}} {{$alert := . -}} {{range .Decisions -}} -

{{.Value}} will get {{.Type}} for next {{.Duration}} for triggering {{.Scenario}} on machine {{$alert.MachineID}}.
CrowdSec CTI
+
{{.Value}} will get {{.Type}} for next {{.Duration}} for triggering {{.Scenario}} on machine {{$alert.MachineID}}.
CrowdSec CTI
{{end -}} {{end -}} + smtp_host: # example: smtp.gmail.com smtp_username: # Replace with your actual username @@ -35,7 +37,15 @@ receiver_emails: # - email2@gmail.com # One of "ssltls", "starttls", "none" -encryption_type: ssltls +encryption_type: "ssltls" + +# If you need to set the HELO hostname: +# helo_host: "localhost" + +# If the email server is hitting the default timeouts (10 seconds), you can increase them here +# +# connect_timeout: 10s +# send_timeout: 10s --- diff --git a/plugins/notifications/email/main.go b/cmd/notification-email/main.go similarity index 78% rename from plugins/notifications/email/main.go rename to cmd/notification-email/main.go index ac09c1eefef..5fc02cdd1d7 100644 --- a/plugins/notifications/email/main.go +++ b/cmd/notification-email/main.go @@ -2,14 +2,18 @@ package main import ( "context" + "errors" "fmt" "os" + "time" - "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" mail "github.com/xhit/go-simple-mail/v2" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) var baseLogger hclog.Logger = hclog.New(&hclog.LoggerOptions{ @@ -47,9 +51,12 @@ type PluginConfig struct { EncryptionType string `yaml:"encryption_type"` AuthType string `yaml:"auth_type"` HeloHost string `yaml:"helo_host"` + ConnectTimeout string `yaml:"connect_timeout"` + SendTimeout string `yaml:"send_timeout"` } type EmailPlugin struct { + protobufs.UnimplementedNotifierServer ConfigByName map[string]PluginConfig } @@ -69,19 +76,20 @@ func (n *EmailPlugin) Configure(ctx context.Context, config *protobufs.Config) ( } if d.Name == "" { - return nil, fmt.Errorf("name is required") + return nil, errors.New("name is required") } if d.SMTPHost == "" { - return nil, fmt.Errorf("SMTP host is not set") + return nil, errors.New("SMTP host is not set") } - if d.ReceiverEmails == nil || len(d.ReceiverEmails) == 0 { - return nil, fmt.Errorf("Receiver emails are not set") + if len(d.ReceiverEmails) == 0 { + return nil, errors.New("receiver emails are not set") } n.ConfigByName[d.Name] = d baseLogger.Debug(fmt.Sprintf("Email plugin '%s' use SMTP host '%s:%d'", d.Name, d.SMTPHost, d.SMTPPort)) + return &protobufs.Empty{}, nil } @@ -89,6 +97,7 @@ func (n *EmailPlugin) Notify(ctx context.Context, notification *protobufs.Notifi if _, ok := n.ConfigByName[notification.Name]; !ok { return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) } + cfg := n.ConfigByName[notification.Name] logger := baseLogger.Named(cfg.Name) @@ -108,11 +117,33 @@ func (n *EmailPlugin) Notify(ctx context.Context, notification *protobufs.Notifi server.Authentication = AuthStringToType[cfg.AuthType] server.Helo = cfg.HeloHost + var err error + + if cfg.ConnectTimeout != "" { + server.ConnectTimeout, err = time.ParseDuration(cfg.ConnectTimeout) + if err != nil { + logger.Warn(fmt.Sprintf("invalid connect timeout '%s', using default '10s'", cfg.ConnectTimeout)) + + server.ConnectTimeout = 10 * time.Second + } + } + + if cfg.SendTimeout != "" { + server.SendTimeout, err = time.ParseDuration(cfg.SendTimeout) + if err != nil { + logger.Warn(fmt.Sprintf("invalid send timeout '%s', using default '10s'", cfg.SendTimeout)) + + server.SendTimeout = 10 * time.Second + } + } + logger.Debug("making smtp connection") + smtpClient, err := server.Connect() if err != nil { return &protobufs.Empty{}, err } + logger.Debug("smtp connection done") email := mail.NewMSG() @@ 
-125,12 +156,14 @@ func (n *EmailPlugin) Notify(ctx context.Context, notification *protobufs.Notifi if err != nil { return &protobufs.Empty{}, err } + logger.Info(fmt.Sprintf("sent email to %v", cfg.ReceiverEmails)) + return &protobufs.Empty{}, nil } func main() { - var handshake = plugin.HandshakeConfig{ + handshake := plugin.HandshakeConfig{ ProtocolVersion: 1, MagicCookieKey: "CROWDSEC_PLUGIN_KEY", MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), @@ -139,7 +172,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "email": &protobufs.NotifierPlugin{ + "email": &csplugin.NotifierPlugin{ Impl: &EmailPlugin{ConfigByName: make(map[string]PluginConfig)}, }, }, diff --git a/plugins/notifications/splunk/Makefile b/cmd/notification-file/Makefile similarity index 53% rename from plugins/notifications/splunk/Makefile rename to cmd/notification-file/Makefile index a7f04f4d0fe..4504328c49a 100644 --- a/plugins/notifications/splunk/Makefile +++ b/cmd/notification-file/Makefile @@ -4,14 +4,13 @@ ifeq ($(OS), Windows_NT) EXT = .exe endif -PLUGIN=splunk -BINARY_NAME = notification-$(PLUGIN)$(EXT) +GO = go +GOBUILD = $(GO) build -GOCMD = go -GOBUILD = $(GOCMD) build +BINARY_NAME = notification-file$(EXT) build: clean - $(GOBUILD) $(LD_OPTS) $(BUILD_VENDOR_FLAGS) -o $(BINARY_NAME) + $(GOBUILD) $(LD_OPTS) -o $(BINARY_NAME) .PHONY: clean clean: diff --git a/cmd/notification-file/file.yaml b/cmd/notification-file/file.yaml new file mode 100644 index 00000000000..61c77b9eb49 --- /dev/null +++ b/cmd/notification-file/file.yaml @@ -0,0 +1,23 @@ +# Don't change this +type: file + +name: file_default # this must match with the registered plugin in the profile +log_level: info # Options include: trace, debug, info, warn, error, off + +# This template render all events as ndjson +format: | + {{range . -}} + { "time": "{{.StopAt}}", "program": "crowdsec", "alert": {{. | toJson }} } + {{ end -}} + +# group_wait: # duration to wait collecting alerts before sending to this plugin, eg "30s" +# group_threshold: # if alerts exceed this, then the plugin will be sent the message. 
eg "10" + +#Use full path EG /tmp/crowdsec_alerts.json or %TEMP%\crowdsec_alerts.json +log_path: "/tmp/crowdsec_alerts.json" +rotate: + enabled: true # Change to false if you want to handle log rotate on system basis + max_size: 500 # in MB + max_files: 5 + max_age: 5 + compress: true diff --git a/cmd/notification-file/main.go b/cmd/notification-file/main.go new file mode 100644 index 00000000000..a4dbb8ee5db --- /dev/null +++ b/cmd/notification-file/main.go @@ -0,0 +1,253 @@ +package main + +import ( + "compress/gzip" + "context" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "sync" + "time" + + "github.com/hashicorp/go-hclog" + plugin "github.com/hashicorp/go-plugin" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" +) + +var ( + FileWriter *os.File + FileWriteMutex *sync.Mutex + FileSize int64 +) + +type FileWriteCtx struct { + Ctx context.Context + Writer io.Writer +} + +func (w *FileWriteCtx) Write(p []byte) (n int, err error) { + if err := w.Ctx.Err(); err != nil { + return 0, err + } + return w.Writer.Write(p) +} + +type PluginConfig struct { + Name string `yaml:"name"` + LogLevel string `yaml:"log_level"` + LogPath string `yaml:"log_path"` + LogRotate LogRotate `yaml:"rotate"` +} + +type LogRotate struct { + MaxSize int `yaml:"max_size"` + MaxAge int `yaml:"max_age"` + MaxFiles int `yaml:"max_files"` + Enabled bool `yaml:"enabled"` + Compress bool `yaml:"compress"` +} + +type FilePlugin struct { + protobufs.UnimplementedNotifierServer + PluginConfigByName map[string]PluginConfig +} + +var logger hclog.Logger = hclog.New(&hclog.LoggerOptions{ + Name: "file-plugin", + Level: hclog.LevelFromString("INFO"), + Output: os.Stderr, + JSONFormat: true, +}) + +func (r *LogRotate) rotateLogs(cfg PluginConfig) { + // Rotate the log file + err := r.rotateLogFile(cfg.LogPath, r.MaxFiles) + if err != nil { + logger.Error("Failed to rotate log file", "error", err) + } + // Reopen the FileWriter + FileWriter.Close() + FileWriter, err = os.OpenFile(cfg.LogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + logger.Error("Failed to reopen log file", "error", err) + } + // Reset the file size + FileInfo, err := FileWriter.Stat() + if err != nil { + logger.Error("Failed to get file info", "error", err) + } + FileSize = FileInfo.Size() +} + +func (r *LogRotate) rotateLogFile(logPath string, maxBackups int) error { + // Rename the current log file + backupPath := logPath + "." 
+ time.Now().Format("20060102-150405") + err := os.Rename(logPath, backupPath) + if err != nil { + return err + } + glob := logPath + ".*" + if r.Compress { + glob = logPath + ".*.gz" + err = compressFile(backupPath) + if err != nil { + return err + } + } + + // Remove old backups + files, err := filepath.Glob(glob) + if err != nil { + return err + } + + sort.Sort(sort.Reverse(sort.StringSlice(files))) + + for i, file := range files { + logger.Trace("Checking file", "file", file, "index", i, "maxBackups", maxBackups) + if i >= maxBackups { + logger.Trace("Removing file as over max backup count", "file", file) + os.Remove(file) + } else { + // Check the age of the file + fileInfo, err := os.Stat(file) + if err != nil { + return err + } + age := time.Since(fileInfo.ModTime()).Hours() + if age > float64(r.MaxAge*24) { + logger.Trace("Removing file as age was over configured amount", "file", file, "age", age) + os.Remove(file) + } + } + } + + return nil +} + +func compressFile(src string) error { + // Open the source file for reading + srcFile, err := os.Open(src) + if err != nil { + return err + } + defer srcFile.Close() + + // Create the destination file + dstFile, err := os.Create(src + ".gz") + if err != nil { + return err + } + defer dstFile.Close() + + // Create a gzip writer + gw := gzip.NewWriter(dstFile) + defer gw.Close() + + // Read the source file and write its contents to the gzip writer + _, err = io.Copy(gw, srcFile) + if err != nil { + return err + } + + // Delete the original (uncompressed) backup file + err = os.Remove(src) + if err != nil { + return err + } + + return nil +} + +func WriteToFileWithCtx(ctx context.Context, cfg PluginConfig, log string) error { + FileWriteMutex.Lock() + defer FileWriteMutex.Unlock() + originalFileInfo, err := FileWriter.Stat() + if err != nil { + logger.Error("Failed to get file info", "error", err) + } + currentFileInfo, _ := os.Stat(cfg.LogPath) + if !os.SameFile(originalFileInfo, currentFileInfo) { + // The file has been rotated outside our control + logger.Info("Log file has been rotated or missing attempting to reopen it") + FileWriter.Close() + FileWriter, err = os.OpenFile(cfg.LogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return err + } + FileInfo, err := FileWriter.Stat() + if err != nil { + return err + } + FileSize = FileInfo.Size() + logger.Info("Log file has been reopened successfully") + } + n, err := io.WriteString(&FileWriteCtx{Ctx: ctx, Writer: FileWriter}, log) + if err == nil { + FileSize += int64(n) + if FileSize > int64(cfg.LogRotate.MaxSize)*1024*1024 && cfg.LogRotate.Enabled { + logger.Debug("Rotating log file", "file", cfg.LogPath) + // Rotate the log file + cfg.LogRotate.rotateLogs(cfg) + } + } + return err +} + +func (s *FilePlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { + if _, ok := s.PluginConfigByName[notification.Name]; !ok { + return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) + } + cfg := s.PluginConfigByName[notification.Name] + + return &protobufs.Empty{}, WriteToFileWithCtx(ctx, cfg, notification.Text) +} + +func (s *FilePlugin) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { + d := PluginConfig{} + err := yaml.Unmarshal(config.Config, &d) + if err != nil { + logger.Error("Failed to parse config", "error", err) + return &protobufs.Empty{}, err + } + FileWriteMutex = &sync.Mutex{} + FileWriter, err = os.OpenFile(d.LogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + 
if err != nil { + logger.Error("Failed to open log file", "error", err) + return &protobufs.Empty{}, err + } + FileInfo, err := FileWriter.Stat() + if err != nil { + logger.Error("Failed to get file info", "error", err) + return &protobufs.Empty{}, err + } + FileSize = FileInfo.Size() + s.PluginConfigByName[d.Name] = d + logger.SetLevel(hclog.LevelFromString(d.LogLevel)) + return &protobufs.Empty{}, err +} + +func main() { + handshake := plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "CROWDSEC_PLUGIN_KEY", + MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), + } + + sp := &FilePlugin{PluginConfigByName: make(map[string]PluginConfig)} + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: handshake, + Plugins: map[string]plugin.Plugin{ + "file": &csplugin.NotifierPlugin{ + Impl: sp, + }, + }, + GRPCServer: plugin.DefaultGRPCServer, + Logger: logger, + }) +} diff --git a/plugins/notifications/dummy/Makefile b/cmd/notification-http/Makefile similarity index 52% rename from plugins/notifications/dummy/Makefile rename to cmd/notification-http/Makefile index d45d6f19844..30ed43a694c 100644 --- a/plugins/notifications/dummy/Makefile +++ b/cmd/notification-http/Makefile @@ -4,14 +4,13 @@ ifeq ($(OS), Windows_NT) EXT = .exe endif -PLUGIN = dummy -BINARY_NAME = notification-$(PLUGIN)$(EXT) +GO = go +GOBUILD = $(GO) build -GOCMD = go -GOBUILD = $(GOCMD) build +BINARY_NAME = notification-http$(EXT) build: clean - $(GOBUILD) $(LD_OPTS) $(BUILD_VENDOR_FLAGS) -o $(BINARY_NAME) + $(GOBUILD) $(LD_OPTS) -o $(BINARY_NAME) .PHONY: clean clean: diff --git a/plugins/notifications/http/http.yaml b/cmd/notification-http/http.yaml similarity index 100% rename from plugins/notifications/http/http.yaml rename to cmd/notification-http/http.yaml diff --git a/plugins/notifications/http/main.go b/cmd/notification-http/main.go similarity index 56% rename from plugins/notifications/http/main.go rename to cmd/notification-http/main.go index 7e15fccae84..3f84984315b 100644 --- a/plugins/notifications/http/main.go +++ b/cmd/notification-http/main.go @@ -4,27 +4,38 @@ import ( "bytes" "context" "crypto/tls" + "crypto/x509" "fmt" "io" + "net" "net/http" "os" + "strings" - "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) type PluginConfig struct { Name string `yaml:"name"` URL string `yaml:"url"` + UnixSocket string `yaml:"unix_socket"` Headers map[string]string `yaml:"headers"` SkipTLSVerification bool `yaml:"skip_tls_verification"` Method string `yaml:"method"` LogLevel *string `yaml:"log_level"` + Client *http.Client `yaml:"-"` + CertPath string `yaml:"cert_path"` + KeyPath string `yaml:"key_path"` + CAPath string `yaml:"ca_cert_path"` } type HTTPPlugin struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig } @@ -35,10 +46,78 @@ var logger hclog.Logger = hclog.New(&hclog.LoggerOptions{ JSONFormat: true, }) +func getCertPool(caPath string) (*x509.CertPool, error) { + cp, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("unable to load system CA certificates: %w", err) + } + + if cp == nil { + cp = x509.NewCertPool() + } + + if caPath == "" { + return cp, nil + } + + logger.Info(fmt.Sprintf("Using CA cert '%s'", caPath)) + + caCert, err := os.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("unable to load CA certificate 
'%s': %w", caPath, err) + } + + cp.AppendCertsFromPEM(caCert) + + return cp, nil +} + +func getTLSClient(c *PluginConfig) error { + caCertPool, err := getCertPool(c.CAPath) + if err != nil { + return err + } + + tlsConfig := &tls.Config{ + RootCAs: caCertPool, + InsecureSkipVerify: c.SkipTLSVerification, + } + + if c.CertPath != "" && c.KeyPath != "" { + logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", c.CertPath, c.KeyPath)) + + cert, err := tls.LoadX509KeyPair(c.CertPath, c.KeyPath) + if err != nil { + return fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", c.CertPath, c.KeyPath, err) + } + + tlsConfig.Certificates = []tls.Certificate{cert} + } + + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + } + + if c.UnixSocket != "" { + logger.Info(fmt.Sprintf("Using socket '%s'", c.UnixSocket)) + + transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", strings.TrimSuffix(c.UnixSocket, "/")) + } + } + + c.Client = &http.Client{ + Transport: transport, + } + + return nil +} + func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { if _, ok := s.PluginConfigByName[notification.Name]; !ok { return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) } + cfg := s.PluginConfigByName[notification.Name] if cfg.LogLevel != nil && *cfg.LogLevel != "" { @@ -46,13 +125,6 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific } logger.Info(fmt.Sprintf("received signal for %s config", notification.Name)) - client := http.Client{} - - if cfg.SkipTLSVerification { - client.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - } request, err := http.NewRequest(cfg.Method, cfg.URL, bytes.NewReader([]byte(notification.Text))) if err != nil { @@ -63,8 +135,10 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific logger.Debug(fmt.Sprintf("adding header %s: %s", headerName, headerValue)) request.Header.Add(headerName, headerValue) } - logger.Debug(fmt.Sprintf("making HTTP %s call to %s with body %s", cfg.Method, cfg.URL, string(notification.Text))) - resp, err := client.Do(request) + + logger.Debug(fmt.Sprintf("making HTTP %s call to %s with body %s", cfg.Method, cfg.URL, notification.Text)) + + resp, err := cfg.Client.Do(request.WithContext(ctx)) if err != nil { logger.Error(fmt.Sprintf("Failed to make HTTP request : %s", err)) return nil, err @@ -73,13 +147,15 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific respData, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to read response body got error %s", err) + return nil, fmt.Errorf("failed to read response body got error %w", err) } logger.Debug(fmt.Sprintf("got response %s", string(respData))) if resp.StatusCode < 200 || resp.StatusCode >= 300 { logger.Warn(fmt.Sprintf("HTTP server returned non 200 status code: %d", resp.StatusCode)) + logger.Debug(fmt.Sprintf("HTTP server returned body: %s", string(respData))) + return &protobufs.Empty{}, nil } @@ -88,14 +164,25 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific func (s *HTTPPlugin) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { d := PluginConfig{} + err := yaml.Unmarshal(config.Config, &d) + if err != nil { + return nil, err + } + + err = getTLSClient(&d) + if err != nil { + return nil, err + } + 
s.PluginConfigByName[d.Name] = d logger.Debug(fmt.Sprintf("HTTP plugin '%s' use URL '%s'", d.Name, d.URL)) + return &protobufs.Empty{}, err } func main() { - var handshake = plugin.HandshakeConfig{ + handshake := plugin.HandshakeConfig{ ProtocolVersion: 1, MagicCookieKey: "CROWDSEC_PLUGIN_KEY", MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), @@ -105,7 +192,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "http": &protobufs.NotifierPlugin{ + "http": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/cmd/notification-sentinel/Makefile b/cmd/notification-sentinel/Makefile new file mode 100644 index 00000000000..21d176a9039 --- /dev/null +++ b/cmd/notification-sentinel/Makefile @@ -0,0 +1,17 @@ +ifeq ($(OS), Windows_NT) + SHELL := pwsh.exe + .SHELLFLAGS := -NoProfile -Command + EXT = .exe +endif + +GO = go +GOBUILD = $(GO) build + +BINARY_NAME = notification-sentinel$(EXT) + +build: clean + $(GOBUILD) $(LD_OPTS) -o $(BINARY_NAME) + +.PHONY: clean +clean: + @$(RM) $(BINARY_NAME) $(WIN_IGNORE_ERR) diff --git a/cmd/notification-sentinel/main.go b/cmd/notification-sentinel/main.go new file mode 100644 index 00000000000..0293d45b0a4 --- /dev/null +++ b/cmd/notification-sentinel/main.go @@ -0,0 +1,134 @@ +package main + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-plugin" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" +) + +type PluginConfig struct { + Name string `yaml:"name"` + CustomerID string `yaml:"customer_id"` + SharedKey string `yaml:"shared_key"` + LogType string `yaml:"log_type"` + LogLevel *string `yaml:"log_level"` +} + +type SentinelPlugin struct { + protobufs.UnimplementedNotifierServer + PluginConfigByName map[string]PluginConfig +} + +var logger hclog.Logger = hclog.New(&hclog.LoggerOptions{ + Name: "sentinel-plugin", + Level: hclog.LevelFromString("INFO"), + Output: os.Stderr, + JSONFormat: true, +}) + +func (s *SentinelPlugin) getAuthorizationHeader(now string, length int, pluginName string) (string, error) { + xHeaders := "X-Ms-Date:" + now + + stringToHash := fmt.Sprintf("POST\n%d\napplication/json\n%s\n/api/logs", length, xHeaders) + decodedKey, _ := base64.StdEncoding.DecodeString(s.PluginConfigByName[pluginName].SharedKey) + + h := hmac.New(sha256.New, decodedKey) + h.Write([]byte(stringToHash)) + + encodedHash := base64.StdEncoding.EncodeToString(h.Sum(nil)) + authorization := "SharedKey " + s.PluginConfigByName[pluginName].CustomerID + ":" + encodedHash + + logger.Trace("authorization header", "header", authorization) + + return authorization, nil +} + +func (s *SentinelPlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { + if _, ok := s.PluginConfigByName[notification.Name]; !ok { + return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) + } + cfg := s.PluginConfigByName[notification.Name] + + if cfg.LogLevel != nil && *cfg.LogLevel != "" { + logger.SetLevel(hclog.LevelFromString(*cfg.LogLevel)) + } + + logger.Info("received notification for sentinel config", "name", notification.Name) + + url := fmt.Sprintf("https://%s.ods.opinsights.azure.com/api/logs?api-version=2016-04-01", s.PluginConfigByName[notification.Name].CustomerID) + body := strings.NewReader(notification.Text) + + //Cannot use time.RFC1123 as azure 
wants GMT, not UTC + now := time.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT") + + authorization, err := s.getAuthorizationHeader(now, len(notification.Text), notification.Name) + if err != nil { + return &protobufs.Empty{}, err + } + + req, err := http.NewRequest(http.MethodPost, url, body) + if err != nil { + logger.Error("failed to create request", "error", err) + return &protobufs.Empty{}, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Log-Type", s.PluginConfigByName[notification.Name].LogType) + req.Header.Set("Authorization", authorization) + req.Header.Set("X-Ms-Date", now) + + client := &http.Client{} + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + logger.Error("failed to send request", "error", err) + return &protobufs.Empty{}, err + } + defer resp.Body.Close() + logger.Debug("sent notification to sentinel", "status", resp.Status) + + if resp.StatusCode != http.StatusOK { + return &protobufs.Empty{}, fmt.Errorf("failed to send notification to sentinel: %s", resp.Status) + } + + return &protobufs.Empty{}, nil +} + +func (s *SentinelPlugin) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { + d := PluginConfig{} + err := yaml.Unmarshal(config.Config, &d) + s.PluginConfigByName[d.Name] = d + return &protobufs.Empty{}, err +} + +func main() { + handshake := plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "CROWDSEC_PLUGIN_KEY", + MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), + } + + sp := &SentinelPlugin{PluginConfigByName: make(map[string]PluginConfig)} + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: handshake, + Plugins: map[string]plugin.Plugin{ + "sentinel": &csplugin.NotifierPlugin{ + Impl: sp, + }, + }, + GRPCServer: plugin.DefaultGRPCServer, + Logger: logger, + }) +} diff --git a/cmd/notification-sentinel/sentinel.yaml b/cmd/notification-sentinel/sentinel.yaml new file mode 100644 index 00000000000..8451c3ffb5d --- /dev/null +++ b/cmd/notification-sentinel/sentinel.yaml @@ -0,0 +1,21 @@ +type: sentinel # Don't change +name: sentinel_default # Must match the registered plugin in the profile + +# One of "trace", "debug", "info", "warn", "error", "off" +log_level: info +# group_wait: # Time to wait collecting alerts before relaying a message to this plugin, eg "30s" +# group_threshold: # Amount of alerts that triggers a message before has expired, eg "10" +# max_retry: # Number of attempts to relay messages to plugins in case of error +# timeout: # Time to wait for response from the plugin before considering the attempt a failure, eg "10s" + +#------------------------- +# plugin-specific options + +# The following template receives a list of models.Alert objects +# The output goes in the http request body +format: | + {{.|toJson}} + +customer_id: XXX-XXX +shared_key: XXXXXXX +log_type: crowdsec \ No newline at end of file diff --git a/cmd/notification-slack/Makefile b/cmd/notification-slack/Makefile new file mode 100644 index 00000000000..06c9ccc3fd4 --- /dev/null +++ b/cmd/notification-slack/Makefile @@ -0,0 +1,17 @@ +ifeq ($(OS), Windows_NT) + SHELL := pwsh.exe + .SHELLFLAGS := -NoProfile -Command + EXT = .exe +endif + +GO = go +GOBUILD = $(GO) build + +BINARY_NAME = notification-slack$(EXT) + +build: clean + $(GOBUILD) $(LD_OPTS) -o $(BINARY_NAME) + +.PHONY: clean +clean: + @$(RM) $(BINARY_NAME) $(WIN_IGNORE_ERR) diff --git a/plugins/notifications/slack/main.go b/cmd/notification-slack/main.go similarity index 73% rename from 
plugins/notifications/slack/main.go rename to cmd/notification-slack/main.go index 90183238119..34c7c0df361 100644 --- a/plugins/notifications/slack/main.go +++ b/cmd/notification-slack/main.go @@ -5,20 +5,26 @@ import ( "fmt" "os" - "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" - "github.com/slack-go/slack" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) type PluginConfig struct { - Name string `yaml:"name"` - Webhook string `yaml:"webhook"` - LogLevel *string `yaml:"log_level"` + Name string `yaml:"name"` + Webhook string `yaml:"webhook"` + Channel string `yaml:"channel"` + Username string `yaml:"username"` + IconEmoji string `yaml:"icon_emoji"` + IconURL string `yaml:"icon_url"` + LogLevel *string `yaml:"log_level"` } type Notify struct { + protobufs.UnimplementedNotifierServer ConfigByName map[string]PluginConfig } @@ -33,6 +39,7 @@ func (n *Notify) Notify(ctx context.Context, notification *protobufs.Notificatio if _, ok := n.ConfigByName[notification.Name]; !ok { return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) } + cfg := n.ConfigByName[notification.Name] if cfg.LogLevel != nil && *cfg.LogLevel != "" { @@ -41,8 +48,13 @@ func (n *Notify) Notify(ctx context.Context, notification *protobufs.Notificatio logger.Info(fmt.Sprintf("found notify signal for %s config", notification.Name)) logger.Debug(fmt.Sprintf("posting to %s webhook, message %s", cfg.Webhook, notification.Text)) - err := slack.PostWebhook(n.ConfigByName[notification.Name].Webhook, &slack.WebhookMessage{ - Text: notification.Text, + + err := slack.PostWebhookContext(ctx, cfg.Webhook, &slack.WebhookMessage{ + Text: notification.Text, + Channel: cfg.Channel, + Username: cfg.Username, + IconEmoji: cfg.IconEmoji, + IconURL: cfg.IconURL, }) if err != nil { logger.Error(err.Error()) @@ -53,16 +65,19 @@ func (n *Notify) Notify(ctx context.Context, notification *protobufs.Notificatio func (n *Notify) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { d := PluginConfig{} + if err := yaml.Unmarshal(config.Config, &d); err != nil { return nil, err } + n.ConfigByName[d.Name] = d logger.Debug(fmt.Sprintf("Slack plugin '%s' use URL '%s'", d.Name, d.Webhook)) + return &protobufs.Empty{}, nil } func main() { - var handshake = plugin.HandshakeConfig{ + handshake := plugin.HandshakeConfig{ ProtocolVersion: 1, MagicCookieKey: "CROWDSEC_PLUGIN_KEY", MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), @@ -71,7 +86,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "slack": &protobufs.NotifierPlugin{ + "slack": &csplugin.NotifierPlugin{ Impl: &Notify{ConfigByName: make(map[string]PluginConfig)}, }, }, diff --git a/plugins/notifications/slack/slack.yaml b/cmd/notification-slack/slack.yaml similarity index 90% rename from plugins/notifications/slack/slack.yaml rename to cmd/notification-slack/slack.yaml index 4768e869780..677d4b757c1 100644 --- a/plugins/notifications/slack/slack.yaml +++ b/cmd/notification-slack/slack.yaml @@ -28,6 +28,12 @@ format: | webhook: +# API request data as defined by the Slack webhook API. 
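For reference, the slack plugin above now posts through slack.PostWebhookContext, so cancelling the notification context aborts the HTTP call, and the new channel/username/icon fields override the webhook's defaults. A minimal standalone sketch of that call (the webhook URL and field values are placeholders, not real configuration):

package main

import (
	"context"
	"log"

	"github.com/slack-go/slack"
)

func main() {
	msg := &slack.WebhookMessage{
		Text:      "crowdsec: test alert",
		Channel:   "#alerts",  // overrides the webhook's default channel
		Username:  "crowdsec", // overrides the webhook's default username
		IconEmoji: ":shield:",
	}

	// placeholder URL, not a real webhook
	err := slack.PostWebhookContext(context.Background(), "https://hooks.slack.com/services/XXX/YYY/ZZZ", msg)
	if err != nil {
		log.Fatal(err)
	}
}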
+#channel: +#username: +#icon_emoji: +#icon_url: + --- # type: slack diff --git a/cmd/notification-splunk/Makefile b/cmd/notification-splunk/Makefile new file mode 100644 index 00000000000..aa15ecac918 --- /dev/null +++ b/cmd/notification-splunk/Makefile @@ -0,0 +1,17 @@ +ifeq ($(OS), Windows_NT) + SHELL := pwsh.exe + .SHELLFLAGS := -NoProfile -Command + EXT = .exe +endif + +GO = go +GOBUILD = $(GO) build + +BINARY_NAME = notification-splunk$(EXT) + +build: clean + $(GOBUILD) $(LD_OPTS) -o $(BINARY_NAME) + +.PHONY: clean +clean: + @$(RM) $(BINARY_NAME) $(WIN_IGNORE_ERR) diff --git a/plugins/notifications/splunk/main.go b/cmd/notification-splunk/main.go similarity index 87% rename from plugins/notifications/splunk/main.go rename to cmd/notification-splunk/main.go index a9b4be50af0..e18f416c14a 100644 --- a/plugins/notifications/splunk/main.go +++ b/cmd/notification-splunk/main.go @@ -10,11 +10,12 @@ import ( "os" "strings" - "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" + "gopkg.in/yaml.v3" - "gopkg.in/yaml.v2" + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) var logger hclog.Logger = hclog.New(&hclog.LoggerOptions{ @@ -32,6 +33,7 @@ type PluginConfig struct { } type Splunk struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig Client http.Client } @@ -44,6 +46,7 @@ func (s *Splunk) Notify(ctx context.Context, notification *protobufs.Notificatio if _, ok := s.PluginConfigByName[notification.Name]; !ok { return &protobufs.Empty{}, fmt.Errorf("splunk invalid config name %s", notification.Name) } + cfg := s.PluginConfigByName[notification.Name] if cfg.LogLevel != nil && *cfg.LogLevel != "" { @@ -53,35 +56,41 @@ func (s *Splunk) Notify(ctx context.Context, notification *protobufs.Notificatio logger.Info(fmt.Sprintf("received notify signal for %s config", notification.Name)) p := Payload{Event: notification.Text} + data, err := json.Marshal(p) if err != nil { return &protobufs.Empty{}, err } - req, err := http.NewRequest("POST", cfg.URL, strings.NewReader(string(data))) + req, err := http.NewRequest(http.MethodPost, cfg.URL, strings.NewReader(string(data))) if err != nil { return &protobufs.Empty{}, err } req.Header.Add("Authorization", fmt.Sprintf("Splunk %s", cfg.Token)) logger.Debug(fmt.Sprintf("posting event %s to %s", string(data), req.URL)) - resp, err := s.Client.Do(req) + + resp, err := s.Client.Do(req.WithContext(ctx)) if err != nil { return &protobufs.Empty{}, err } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { content, err := io.ReadAll(resp.Body) if err != nil { - return &protobufs.Empty{}, fmt.Errorf("got non 200 response and failed to read error %s", err) + return &protobufs.Empty{}, fmt.Errorf("got non 200 response and failed to read error %w", err) } + return &protobufs.Empty{}, fmt.Errorf("got non 200 response %s", string(content)) } + respData, err := io.ReadAll(resp.Body) if err != nil { - return &protobufs.Empty{}, fmt.Errorf("failed to read response body got error %s", err) + return &protobufs.Empty{}, fmt.Errorf("failed to read response body got error %w", err) } + logger.Debug(fmt.Sprintf("got response %s", string(respData))) + return &protobufs.Empty{}, nil } @@ -90,11 +99,12 @@ func (s *Splunk) Configure(ctx context.Context, config *protobufs.Config) (*prot err := yaml.Unmarshal(config.Config, &d) s.PluginConfigByName[d.Name] = d logger.Debug(fmt.Sprintf("Splunk plugin '%s' use URL 
'%s'", d.Name, d.URL)) + return &protobufs.Empty{}, err } func main() { - var handshake = plugin.HandshakeConfig{ + handshake := plugin.HandshakeConfig{ ProtocolVersion: 1, MagicCookieKey: "CROWDSEC_PLUGIN_KEY", MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), @@ -109,7 +119,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "splunk": &protobufs.NotifierPlugin{ + "splunk": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/plugins/notifications/splunk/splunk.yaml b/cmd/notification-splunk/splunk.yaml similarity index 100% rename from plugins/notifications/splunk/splunk.yaml rename to cmd/notification-splunk/splunk.yaml diff --git a/config/acquis_win.yaml b/config/acquis_win.yaml index 86d233cca8e..b198ac645ae 100644 --- a/config/acquis_win.yaml +++ b/config/acquis_win.yaml @@ -10,7 +10,7 @@ labels: --- ##Firewall filenames: - - C:\Windows\System32\LogFiles\Firewall\pfirewall.log + - C:\Windows\System32\LogFiles\Firewall\*.log labels: type: windows-firewall --- @@ -28,4 +28,4 @@ use_time_machine: true filenames: - C:\inetpub\logs\LogFiles\*\*.log labels: - type: iis \ No newline at end of file + type: iis diff --git a/config/config.yaml b/config/config.yaml index 232b0bc4389..2b0e4dfca1a 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -6,7 +6,6 @@ common: log_max_size: 20 compress_logs: true log_max_files: 10 - working_dir: . config_paths: config_dir: /etc/crowdsec/ data_dir: /var/lib/crowdsec/data/ diff --git a/config/config_win.yaml b/config/config_win.yaml index 7863f4fdd50..5c34c69a2c0 100644 --- a/config/config_win.yaml +++ b/config/config_win.yaml @@ -3,7 +3,6 @@ common: log_media: file log_level: info log_dir: C:\ProgramData\CrowdSec\log\ - working_dir: . config_paths: config_dir: C:\ProgramData\CrowdSec\config\ data_dir: C:\ProgramData\CrowdSec\data\ diff --git a/config/config_win_no_lapi.yaml b/config/config_win_no_lapi.yaml index 35c7f2c6f11..af240228bb5 100644 --- a/config/config_win_no_lapi.yaml +++ b/config/config_win_no_lapi.yaml @@ -3,7 +3,6 @@ common: log_media: file log_level: info log_dir: C:\ProgramData\CrowdSec\log\ - working_dir: . config_paths: config_dir: C:\ProgramData\CrowdSec\config\ data_dir: C:\ProgramData\CrowdSec\data\ diff --git a/config/crowdsec.cron.daily b/config/crowdsec.cron.daily index 1c110df38fc..9c488d29884 100644 --- a/config/crowdsec.cron.daily +++ b/config/crowdsec.cron.daily @@ -2,12 +2,13 @@ test -x /usr/bin/cscli || exit 0 +# splay hub upgrade and crowdsec reload +sleep "$(seq 1 300 | shuf -n 1)" + /usr/bin/cscli --error hub update upgraded=$(/usr/bin/cscli --error hub upgrade) if [ -n "$upgraded" ]; then - # splay initial metrics push - sleep $(seq 1 90 | shuf -n 1) systemctl reload crowdsec fi diff --git a/config/crowdsec.service b/config/crowdsec.service index 147cae4946e..65a8d30bc5f 100644 --- a/config/crowdsec.service +++ b/config/crowdsec.service @@ -8,6 +8,7 @@ Environment=LC_ALL=C LANG=C ExecStartPre=/usr/local/bin/crowdsec -c /etc/crowdsec/config.yaml -t -error ExecStart=/usr/local/bin/crowdsec -c /etc/crowdsec/config.yaml #ExecStartPost=/bin/sleep 0.1 +ExecReload=/usr/local/bin/crowdsec -c /etc/crowdsec/config.yaml -t -error ExecReload=/bin/kill -HUP $MAINPID Restart=always RestartSec=60 diff --git a/config/dev.yaml b/config/dev.yaml index 2ff62506017..ca1f35f32ff 100644 --- a/config/dev.yaml +++ b/config/dev.yaml @@ -2,7 +2,6 @@ common: daemonize: true log_media: stdout log_level: info - working_dir: . 
config_paths: config_dir: ./config data_dir: ./data/ @@ -34,6 +33,7 @@ api: client: credentials_path: ./config/local_api_credentials.yaml server: + console_path: ./config/console.yaml #insecure_skip_verify: true listen_uri: 127.0.0.1:8081 profiles_path: ./config/profiles.yaml diff --git a/config/profiles.yaml b/config/profiles.yaml index 9d81c9298a2..c4982acd978 100644 --- a/config/profiles.yaml +++ b/config/profiles.yaml @@ -12,3 +12,18 @@ decisions: # - http_default # Set the required http parameters in /etc/crowdsec/notifications/http.yaml before enabling this. # - email_default # Set the required email parameters in /etc/crowdsec/notifications/email.yaml before enabling this. on_success: break +--- +name: default_range_remediation +#debug: true +filters: + - Alert.Remediation == true && Alert.GetScope() == "Range" +decisions: + - type: ban + duration: 4h +#duration_expr: Sprintf('%dh', (GetDecisionsCount(Alert.GetValue()) + 1) * 4) +# notifications: +# - slack_default # Set the webhook in /etc/crowdsec/notifications/slack.yaml before enabling this. +# - splunk_default # Set the splunk url and token in /etc/crowdsec/notifications/splunk.yaml before enabling this. +# - http_default # Set the required http parameters in /etc/crowdsec/notifications/http.yaml before enabling this. +# - email_default # Set the required email parameters in /etc/crowdsec/notifications/email.yaml before enabling this. +on_success: break diff --git a/config/user.yaml b/config/user.yaml index 67bdfa3fc49..a1047dcd0f7 100644 --- a/config/user.yaml +++ b/config/user.yaml @@ -3,7 +3,6 @@ common: log_media: stdout log_level: info log_dir: /var/log/ - working_dir: . config_paths: config_dir: /etc/crowdsec/ data_dir: /var/lib/crowdsec/data diff --git a/debian/control b/debian/control index 4673284e7b4..0ee08b71f85 100644 --- a/debian/control +++ b/debian/control @@ -8,3 +8,4 @@ Package: crowdsec Architecture: any Description: Crowdsec - An open-source, lightweight agent to detect and respond to bad behaviors. 
It also automatically benefits from our global community-wide IP reputation database Depends: coreutils +Suggests: cron diff --git a/debian/crowdsec.service b/debian/crowdsec.service index 8743a03bced..c1a5e403745 100644 --- a/debian/crowdsec.service +++ b/debian/crowdsec.service @@ -5,9 +5,10 @@ After=syslog.target network.target remote-fs.target nss-lookup.target [Service] Type=notify Environment=LC_ALL=C LANG=C -ExecStartPre=/usr/bin/crowdsec -c /etc/crowdsec/config.yaml -t +ExecStartPre=/usr/bin/crowdsec -c /etc/crowdsec/config.yaml -t -error ExecStart=/usr/bin/crowdsec -c /etc/crowdsec/config.yaml #ExecStartPost=/bin/sleep 0.1 +ExecReload=/usr/bin/crowdsec -c /etc/crowdsec/config.yaml -t -error ExecReload=/bin/kill -HUP $MAINPID Restart=always RestartSec=60 diff --git a/debian/install b/debian/install index 11c82d01ecb..fa422cac8d9 100644 --- a/debian/install +++ b/debian/install @@ -6,7 +6,9 @@ config/patterns/* etc/crowdsec/patterns config/crowdsec.service lib/systemd/system # Referenced configs: -plugins/notifications/slack/slack.yaml etc/crowdsec/notifications/ -plugins/notifications/http/http.yaml etc/crowdsec/notifications/ -plugins/notifications/splunk/splunk.yaml etc/crowdsec/notifications/ -plugins/notifications/email/email.yaml etc/crowdsec/notifications/ +cmd/notification-slack/slack.yaml etc/crowdsec/notifications/ +cmd/notification-http/http.yaml etc/crowdsec/notifications/ +cmd/notification-splunk/splunk.yaml etc/crowdsec/notifications/ +cmd/notification-email/email.yaml etc/crowdsec/notifications/ +cmd/notification-sentinel/sentinel.yaml etc/crowdsec/notifications/ +cmd/notification-file/file.yaml etc/crowdsec/notifications/ diff --git a/debian/postinst b/debian/postinst index a862c88750d..77f2511f556 100644 --- a/debian/postinst +++ b/debian/postinst @@ -58,10 +58,10 @@ if [ "$1" = configure ]; then db_get crowdsec/capi CAPI=$RET - cscli machines add -a + [ -s /etc/crowdsec/local_api_credentials.yaml ] || cscli machines add -a --force --error if [ "$CAPI" = true ]; then - cscli capi register + cscli capi register --error fi else @@ -91,7 +91,7 @@ if [ "$1" = configure ]; then systemctl --quiet is-enabled crowdsec || systemctl unmask crowdsec && systemctl enable crowdsec API=$(cscli config show --key "Config.API.Server") - if [ "$API" = "" ] ; then + if [ "$API" = "nil" ] ; then LAPI=false else PORT=$(cscli config show --key "Config.API.Server.ListenURI"|cut -d ":" -f2) @@ -102,6 +102,13 @@ if [ "$1" = configure ]; then echo "Not attempting to start crowdsec, port ${PORT} is already used or lapi was disabled" echo "This port is configured through /etc/crowdsec/config.yaml and /etc/crowdsec/local_api_credentials.yaml" fi + + echo "Get started with CrowdSec:" + echo " * Detailed guides are available in our documentation: https://docs.crowdsec.net" + echo " * Configuration items created by the community can be found at the Hub: https://hub.crowdsec.net" + echo " * Gain insights into your use of CrowdSec with the help of the console https://app.crowdsec.net" + + fi -echo "You can always run the configuration again interactively by using '/usr/share/crowdsec/wizard.sh -c" +echo "You can always run the configuration again interactively by using '/usr/share/crowdsec/wizard.sh -c'" diff --git a/debian/preinst b/debian/preinst index e2485ce53eb..217b836caa6 100644 --- a/debian/preinst +++ b/debian/preinst @@ -40,4 +40,4 @@ if [ "$1" = upgrade ]; then fi fi -echo "You can always run the configuration again interactively by using '/usr/share/crowdsec/wizard.sh -c" +echo "You can 
always run the configuration again interactively by using '/usr/share/crowdsec/wizard.sh -c'" diff --git a/debian/prerm b/debian/prerm index eb4eb4ed7d6..a463a6a1c80 100644 --- a/debian/prerm +++ b/debian/prerm @@ -1,5 +1,5 @@ if [ "$1" = "remove" ]; then - cscli dashboard remove -f -y || true + cscli dashboard remove -f -y --error || echo "Ignore the above error if you never installed the local dashboard." systemctl stop crowdsec systemctl disable crowdsec fi diff --git a/debian/rules b/debian/rules index e6202a6f774..c11771282ea 100755 --- a/debian/rules +++ b/debian/rules @@ -17,6 +17,7 @@ override_dh_auto_install: mkdir -p debian/crowdsec/usr/bin mkdir -p debian/crowdsec/etc/crowdsec + mkdir -p debian/crowdsec/etc/crowdsec/acquis.d mkdir -p debian/crowdsec/usr/share/crowdsec mkdir -p debian/crowdsec/etc/crowdsec/hub/ mkdir -p debian/crowdsec/usr/share/crowdsec/config @@ -25,10 +26,12 @@ override_dh_auto_install: mkdir -p debian/crowdsec/usr/lib/crowdsec/plugins/ mkdir -p debian/crowdsec/etc/crowdsec/notifications/ - install -m 551 plugins/notifications/slack/notification-slack debian/crowdsec/usr/lib/crowdsec/plugins/ - install -m 551 plugins/notifications/http/notification-http debian/crowdsec/usr/lib/crowdsec/plugins/ - install -m 551 plugins/notifications/splunk/notification-splunk debian/crowdsec/usr/lib/crowdsec/plugins/ - install -m 551 plugins/notifications/email/notification-email debian/crowdsec/usr/lib/crowdsec/plugins/ + install -m 551 cmd/notification-slack/notification-slack debian/crowdsec/usr/lib/crowdsec/plugins/ + install -m 551 cmd/notification-http/notification-http debian/crowdsec/usr/lib/crowdsec/plugins/ + install -m 551 cmd/notification-splunk/notification-splunk debian/crowdsec/usr/lib/crowdsec/plugins/ + install -m 551 cmd/notification-email/notification-email debian/crowdsec/usr/lib/crowdsec/plugins/ + install -m 551 cmd/notification-sentinel/notification-sentinel debian/crowdsec/usr/lib/crowdsec/plugins/ + install -m 551 cmd/notification-file/notification-file debian/crowdsec/usr/lib/crowdsec/plugins/ cp cmd/crowdsec/crowdsec debian/crowdsec/usr/bin cp cmd/crowdsec-cli/cscli debian/crowdsec/usr/bin diff --git a/docker/README.md b/docker/README.md index e1c7b517e73..ad31d10aed6 100644 --- a/docker/README.md +++ b/docker/README.md @@ -19,11 +19,7 @@ All the following images are available on Docker Hub for the architectures - `crowdsecurity/crowdsec:{version}` -Recommended for production usage. Also available on GitHub (ghcr.io). - - - `crowdsecurity/crowdsec:dev` - -The latest stable release. +Latest stable release recommended for production usage. Also available on GitHub (ghcr.io). - `crowdsecurity/crowdsec:dev` @@ -190,6 +186,14 @@ It is not recommended anymore to bind-mount the full config.yaml file and you sh If you want to use the [notification system](https://docs.crowdsec.net/docs/notification_plugins/intro), you have to use the full image (not slim) and mount at least a custom `profiles.yaml` and a notification configuration to `/etc/crowdsec/notifications` +```shell +docker run -d \ + -v ./profiles.yaml:/etc/crowdsec/profiles.yaml \ + -v ./http_notification.yaml:/etc/crowdsec/notifications/http_notification.yaml \ + -p 8080:8080 -p 6060:6060 \ + --name crowdsec crowdsecurity/crowdsec +``` + # Deployment use cases Crowdsec is composed of an `agent` that parses logs and creates `alerts`, and a @@ -312,16 +316,26 @@ config.yaml) each time the container is run. 
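For example, hub items and log verbosity can be driven straight from `docker run` using the variables listed below (item names and values here are only illustrative):

```shell
docker run -d \
    -e COLLECTIONS="crowdsecurity/linux crowdsecurity/nginx" \
    -e DISABLE_PARSERS="crowdsecurity/whitelists" \
    -e LEVEL_DEBUG=true \
    -p 8080:8080 -p 6060:6060 \
    --name crowdsec crowdsecurity/crowdsec
```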
| `BOUNCERS_ALLOWED_OU` | bouncer-ou | OU values allowed for bouncers, separated by comma | | | | | | __Hub management__ | | | +| `NO_HUB_UPGRADE` | false | Skip hub update / upgrade when the container starts | | `COLLECTIONS` | | Collections to install, separated by space: `-e COLLECTIONS="crowdsecurity/linux crowdsecurity/apache2"` | | `PARSERS` | | Parsers to install, separated by space | | `SCENARIOS` | | Scenarios to install, separated by space | | `POSTOVERFLOWS` | | Postoverflows to install, separated by space | +| `CONTEXTS` | | Context files to install, separated by space | +| `APPSEC_CONFIGS` | | Appsec config files to install, separated by space | +| `APPSEC_RULES` | | Appsec rule files to install, separated by space | | `DISABLE_COLLECTIONS` | | Collections to remove, separated by space: `-e DISABLE_COLLECTIONS="crowdsecurity/linux crowdsecurity/nginx"` | | `DISABLE_PARSERS` | | Parsers to remove, separated by space | | `DISABLE_SCENARIOS` | | Scenarios to remove, separated by space | | `DISABLE_POSTOVERFLOWS` | | Postoverflows to remove, separated by space | +| `DISABLE_CONTEXTS` | | Context files to remove, separated by space | +| `DISABLE_APPSEC_CONFIGS` | | Appsec config files to remove, separated by space | +| `DISABLE_APPSEC_RULES` | | Appsec rule files to remove, separated by space | | | | | | __Log verbosity__ | | | +| `LEVEL_FATAL` | false | Force FATAL level for the container log | +| `LEVEL_ERROR` | false | Force ERROR level for the container log | +| `LEVEL_WARN` | false | Force WARN level for the container log | | `LEVEL_INFO` | false | Force INFO level for the container log | | `LEVEL_DEBUG` | false | Force DEBUG level for the container log | | `LEVEL_TRACE` | false | Force TRACE level (VERY verbose) for the container log | diff --git a/docker/config.yaml b/docker/config.yaml index 5259a0fe26e..6811329099a 100644 --- a/docker/config.yaml +++ b/docker/config.yaml @@ -3,7 +3,6 @@ common: log_media: stdout log_level: info log_dir: /var/log/ - working_dir: . config_paths: config_dir: /etc/crowdsec/ data_dir: /var/lib/crowdsec/data/ diff --git a/docker/docker_start.sh b/docker/docker_start.sh index 21b42dcb0b4..fb87c1eff9b 100755 --- a/docker/docker_start.sh +++ b/docker/docker_start.sh @@ -6,6 +6,9 @@ set -e shopt -s inherit_errexit +# Note that "if function_name" in bash matches when the function returns 0, +# meaning successful execution. + # match true, TRUE, True, tRuE, etc.
istrue() { case "$(echo "$1" | tr '[:upper:]' '[:lower:]')" in @@ -50,6 +53,52 @@ cscli() { command cscli -c "$CONFIG_FILE" "$@" } +run_hub_update() { + index_modification_time=$(stat -c %Y /etc/crowdsec/hub/.index.json 2>/dev/null) + # Run cscli hub update if no date or if the index file is older than 24h + if [ -z "$index_modification_time" ] || [ $(( $(date +%s) - index_modification_time )) -gt 86400 ]; then + cscli hub update --with-content + else + echo "Skipping hub update, index file is recent" + fi +} + +is_mounted() { + path=$(readlink -f "$1") + mounts=$(awk '{print $2}' /proc/mounts) + while true; do + if grep -qE ^"$path"$ <<< "$mounts"; then + echo "$path was found in a volume" + return 0 + fi + path=$(dirname "$path") + if [ "$path" = "/" ]; then + return 1 + fi + done + return 1 #unreachable +} + +run_hub_update_if_from_volume() { + if is_mounted "/etc/crowdsec/hub/.index.json"; then + echo "Running hub update" + run_hub_update + else + echo "Skipping hub update, index file is not in a volume" + fi +} + +run_hub_upgrade_if_from_volume() { + isfalse "$NO_HUB_UPGRADE" || return 0 + if is_mounted "/var/lib/crowdsec/data"; then + echo "Running hub upgrade" + cscli hub upgrade + else + echo "Skipping hub upgrade, data directory is not in a volume" + fi + +} + # conf_get [file_path] # retrieve a value from a file (by default $CONFIG_FILE) conf_get() { @@ -101,19 +150,30 @@ register_bouncer() { # $2 can be install, remove, upgrade # $3 is a list of object names separated by space cscli_if_clean() { + local itemtype="$1" + local action="$2" + local objs=$3 + shift 3 # loop over all objects - for obj in $3; do - if cscli "$1" inspect "$obj" -o json | yq -e '.tainted // false' >/dev/null 2>&1; then - echo "Object $1/$obj is tainted, skipping" + for obj in $objs; do + if cscli "$itemtype" inspect "$obj" -o json | yq -e '.tainted // false' >/dev/null 2>&1; then + echo "Object $itemtype/$obj is tainted, skipping" + elif cscli "$itemtype" inspect "$obj" -o json | yq -e '.local // false' >/dev/null 2>&1; then + echo "Object $itemtype/$obj is local, skipping" else # # Too verbose? Only show errors if not in debug mode # if [ "$DEBUG" != "true" ]; then # error_only=--error # fi error_only="" - echo "Running: cscli $error_only $1 $2 \"$obj\"" + echo "Running: cscli $error_only $itemtype $action \"$obj\" $*" # shellcheck disable=SC2086 - cscli $error_only "$1" "$2" "$obj" + if ! cscli $error_only "$itemtype" "$action" "$obj" "$@"; then + echo "Failed to $action $itemtype/$obj, running hub update before retrying" + run_hub_update + # shellcheck disable=SC2086 + cscli $error_only "$itemtype" "$action" "$obj" "$@" + fi fi done } @@ -153,15 +213,16 @@ if [ -n "$CERT_FILE" ] || [ -n "$KEY_FILE" ] ; then export LAPI_KEY_FILE=${LAPI_KEY_FILE:-$KEY_FILE} fi -# Check and prestage databases -for geodb in GeoLite2-ASN.mmdb GeoLite2-City.mmdb; do - # We keep the pre-populated geoip databases in /staging instead of /var, - # because if the data directory is bind-mounted from the host, it will be - # empty and the files will be out of reach, requiring a runtime download. - # We link to them to save about 80Mb compared to cp/mv. - if [ ! 
-e "/var/lib/crowdsec/data/$geodb" ] && [ -e "/staging/var/lib/crowdsec/data/$geodb" ]; then - mkdir -p /var/lib/crowdsec/data - ln -s "/staging/var/lib/crowdsec/data/$geodb" /var/lib/crowdsec/data/ +# Link the preloaded data files when the data dir is mounted (common case) +# The symlinks can be overridden by hub upgrade +for target in "/staging/var/lib/crowdsec/data"/*; do + fname="$(basename "$target")" + # skip the db and wal files + if [[ $fname == crowdsec.db* ]]; then + continue + fi + if [ ! -e "/var/lib/crowdsec/data/$fname" ]; then + ln -s "$target" "/var/lib/crowdsec/data/$fname" fi done @@ -174,7 +235,7 @@ if [ ! -e "/etc/crowdsec/local_api_credentials.yaml" ] && [ ! -e "/etc/crowdsec/ mkdir -p /etc/crowdsec/ # if you change this, check that it still works # under alpine and k8s, with and without tls - cp -an /staging/etc/crowdsec/* /etc/crowdsec/ + rsync -av --ignore-existing /staging/etc/crowdsec/* /etc/crowdsec fi fi @@ -198,7 +259,7 @@ if isfalse "$DISABLE_LOCAL_API"; then # if the db is persistent but the credentials are not, we need to # delete the old machine to generate new credentials cscli machines delete "$CUSTOM_HOSTNAME" >/dev/null 2>&1 || true - cscli machines add "$CUSTOM_HOSTNAME" --auto + cscli machines add "$CUSTOM_HOSTNAME" --auto --force fi fi @@ -243,7 +304,7 @@ if istrue "$DISABLE_ONLINE_API"; then fi # registration to online API for signal push -if isfalse "$DISABLE_ONLINE_API" ; then +if isfalse "$DISABLE_LOCAL_API" && isfalse "$DISABLE_ONLINE_API" ; then CONFIG_DIR=$(conf_get '.config_paths.config_dir') export CONFIG_DIR config_exists=$(conf_get '.api.server.online_client | has("credentials_path")') @@ -255,7 +316,7 @@ if isfalse "$DISABLE_ONLINE_API" ; then fi # Enroll instance if enroll key is provided -if isfalse "$DISABLE_ONLINE_API" && [ "$ENROLL_KEY" != "" ]; then +if isfalse "$DISABLE_LOCAL_API" && isfalse "$DISABLE_ONLINE_API" && [ "$ENROLL_KEY" != "" ]; then enroll_args="" if [ "$ENROLL_INSTANCE_NAME" != "" ]; then enroll_args="--name $ENROLL_INSTANCE_NAME" @@ -273,13 +334,16 @@ fi # crowdsec sqlite database permissions if [ "$GID" != "" ]; then if istrue "$(conf_get '.db_config.type == "sqlite"')"; then - chown ":$GID" "$(conf_get '.db_config.db_path')" - echo "sqlite database permissions updated" + # force the creation of the db file(s) + cscli machines inspect create-db --error >/dev/null 2>&1 || : + # don't fail if the db is not there yet + if chown -f ":$GID" "$(conf_get '.db_config.db_path')" 2>/dev/null; then + echo "sqlite database permissions updated" + fi fi fi -# XXX only with LAPI -if istrue "$USE_TLS"; then +if isfalse "$DISABLE_LOCAL_API" && istrue "$USE_TLS"; then agents_allowed_yaml=$(csv2yaml "$AGENTS_ALLOWED_OU") export agents_allowed_yaml bouncers_allowed_yaml=$(csv2yaml "$BOUNCERS_ALLOWED_OU") @@ -295,11 +359,11 @@ fi conf_set_if "$PLUGIN_DIR" '.config_paths.plugin_dir = strenv(PLUGIN_DIR)' -## Install collections, parsers, scenarios & postoverflows -cscli hub update +## Install hub items + +run_hub_update_if_from_volume || true +run_hub_upgrade_if_from_volume || true -cscli_if_clean collections upgrade crowdsecurity/linux -cscli_if_clean parsers upgrade crowdsecurity/whitelists cscli_if_clean parsers install crowdsecurity/docker-logs cscli_if_clean parsers install crowdsecurity/cri-logs @@ -323,25 +387,55 @@ if [ "$POSTOVERFLOWS" != "" ]; then cscli_if_clean postoverflows install "$(difference "$POSTOVERFLOWS" "$DISABLE_POSTOVERFLOWS")" fi +if [ "$CONTEXTS" != "" ]; then + # shellcheck disable=SC2086 + cscli_if_clean 
contexts install "$(difference "$CONTEXTS" "$DISABLE_CONTEXTS")" +fi + +if [ "$APPSEC_CONFIGS" != "" ]; then + # shellcheck disable=SC2086 + cscli_if_clean appsec-configs install "$(difference "$APPSEC_CONFIGS" "$DISABLE_APPSEC_CONFIGS")" +fi + +if [ "$APPSEC_RULES" != "" ]; then + # shellcheck disable=SC2086 + cscli_if_clean appsec-rules install "$(difference "$APPSEC_RULES" "$DISABLE_APPSEC_RULES")" +fi + ## Remove collections, parsers, scenarios & postoverflows if [ "$DISABLE_COLLECTIONS" != "" ]; then # shellcheck disable=SC2086 - cscli_if_clean collections remove "$DISABLE_COLLECTIONS" + cscli_if_clean collections remove "$DISABLE_COLLECTIONS" --force fi if [ "$DISABLE_PARSERS" != "" ]; then # shellcheck disable=SC2086 - cscli_if_clean parsers remove "$DISABLE_PARSERS" + cscli_if_clean parsers remove "$DISABLE_PARSERS" --force fi if [ "$DISABLE_SCENARIOS" != "" ]; then # shellcheck disable=SC2086 - cscli_if_clean scenarios remove "$DISABLE_SCENARIOS" + cscli_if_clean scenarios remove "$DISABLE_SCENARIOS" --force fi if [ "$DISABLE_POSTOVERFLOWS" != "" ]; then # shellcheck disable=SC2086 - cscli_if_clean postoverflows remove "$DISABLE_POSTOVERFLOWS" + cscli_if_clean postoverflows remove "$DISABLE_POSTOVERFLOWS" --force +fi + +if [ "$DISABLE_CONTEXTS" != "" ]; then + # shellcheck disable=SC2086 + cscli_if_clean contexts remove "$DISABLE_CONTEXTS" --force +fi + +if [ "$DISABLE_APPSEC_CONFIGS" != "" ]; then + # shellcheck disable=SC2086 + cscli_if_clean appsec-configs remove "$DISABLE_APPSEC_CONFIGS" --force +fi + +if [ "$DISABLE_APPSEC_RULES" != "" ]; then + # shellcheck disable=SC2086 + cscli_if_clean appsec-rules remove "$DISABLE_APPSEC_RULES" --force fi ## Register bouncers via env @@ -353,12 +447,17 @@ for BOUNCER in $(compgen -A variable | grep -i BOUNCER_KEY); do fi done +if [ "$ENABLE_CONSOLE_MANAGEMENT" != "" ]; then + # shellcheck disable=SC2086 + cscli console enable console_management +fi + ## Register bouncers via secrets (Swarm only) shopt -s nullglob extglob for BOUNCER in /run/secrets/@(bouncer_key|BOUNCER_KEY)* ; do KEY=$(cat "${BOUNCER}") NAME=$(echo "${BOUNCER}" | awk -F "/" '{printf $NF}' | cut -d_ -f2-) - if [[ -n $KEY ]] && [[ -n $NAME ]]; then + if [[ -n $KEY ]] && [[ -n $NAME ]]; then register_bouncer "$NAME" "$KEY" fi done @@ -369,6 +468,12 @@ shopt -u nullglob extglob conf_set_if "$CAPI_WHITELISTS_PATH" '.api.server.capi_whitelists_path = strenv(CAPI_WHITELISTS_PATH)' conf_set_if "$METRICS_PORT" '.prometheus.listen_port=env(METRICS_PORT)' +if istrue "$DISABLE_LOCAL_API"; then + conf_set '.api.server.enable=false' +else + conf_set '.api.server.enable=true' +fi + ARGS="" if [ "$CONFIG_FILE" != "" ]; then ARGS="-c $CONFIG_FILE" @@ -390,10 +495,6 @@ if istrue "$DISABLE_AGENT"; then ARGS="$ARGS -no-cs" fi -if istrue "$DISABLE_LOCAL_API"; then - ARGS="$ARGS -no-api" -fi - if istrue "$LEVEL_TRACE"; then ARGS="$ARGS -trace" fi @@ -406,5 +507,17 @@ if istrue "$LEVEL_INFO"; then ARGS="$ARGS -info" fi +if istrue "$LEVEL_WARN"; then + ARGS="$ARGS -warning" +fi + +if istrue "$LEVEL_ERROR"; then + ARGS="$ARGS -error" +fi + +if istrue "$LEVEL_FATAL"; then + ARGS="$ARGS -fatal" +fi + # shellcheck disable=SC2086 exec crowdsec $ARGS diff --git a/docker/preload-hub-items b/docker/preload-hub-items new file mode 100755 index 00000000000..45155d17af9 --- /dev/null +++ b/docker/preload-hub-items @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -eu + +# pre-download everything but don't install anything + +echo "Pre-downloading Hub content..." 
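The loop that follows enumerates every item type known to the hub (`cscli hub types -o raw`) and feeds each full item list to `install --download-only --error`, so the image ships with hub content cached but nothing enabled. The same flags work for a selective pre-seed, for instance (a hypothetical variant, not part of this script):

```shell
# cache a single collection in the image without enabling it
cscli collections install crowdsecurity/linux --download-only --error
```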
+ +types=$(cscli hub types -o raw) + +for itemtype in $types; do + ALL_ITEMS=$(cscli "$itemtype" list -a -o json | itemtype="$itemtype" yq '.[env(itemtype)][] | .name') + if [[ -n "${ALL_ITEMS}" ]]; then + #shellcheck disable=SC2086 + cscli "$itemtype" install \ + $ALL_ITEMS \ + --download-only \ + --error + fi +done + +echo " done." \ No newline at end of file diff --git a/docker/test/Pipfile b/docker/test/Pipfile index bffd8f2cc8f..c57ccb628e8 100644 --- a/docker/test/Pipfile +++ b/docker/test/Pipfile @@ -1,7 +1,7 @@ [packages] pytest-dotenv = "0.5.2" -pytest-xdist = "3.3.1" -pytest-cs = {ref = "0.7.16", git = "https://github.com/crowdsecurity/pytest-cs.git"} +pytest-xdist = "3.5.0" +pytest-cs = {ref = "0.7.19", git = "https://github.com/crowdsecurity/pytest-cs.git"} [dev-packages] gnureadline = "8.1.2" diff --git a/docker/test/Pipfile.lock b/docker/test/Pipfile.lock index 8763bbe99cd..99184d9f2a2 100644 --- a/docker/test/Pipfile.lock +++ b/docker/test/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "78f693678e411b7bdb5dd0280b7d6f8d9880069b331d44d96d32ba697275e30d" + "sha256": "b5d25a7199d15a900b285be1af97cf7b7083c6637d631ad777b454471c8319fe" }, "pipfile-spec": 6, "requires": { @@ -18,214 +18,237 @@ "default": { "certifi": { "hashes": [ - "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7", - "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716" + "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8", + "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9" ], "markers": "python_version >= '3.6'", - "version": "==2023.5.7" + "version": "==2024.8.30" }, "cffi": { "hashes": [ - "sha256:00a9ed42e88df81ffae7a8ab6d9356b371399b91dbdf0c3cb1e84c03a13aceb5", - "sha256:03425bdae262c76aad70202debd780501fabeaca237cdfddc008987c0e0f59ef", - "sha256:04ed324bda3cda42b9b695d51bb7d54b680b9719cfab04227cdd1e04e5de3104", - "sha256:0e2642fe3142e4cc4af0799748233ad6da94c62a8bec3a6648bf8ee68b1c7426", - "sha256:173379135477dc8cac4bc58f45db08ab45d228b3363adb7af79436135d028405", - "sha256:198caafb44239b60e252492445da556afafc7d1e3ab7a1fb3f0584ef6d742375", - "sha256:1e74c6b51a9ed6589199c787bf5f9875612ca4a8a0785fb2d4a84429badaf22a", - "sha256:2012c72d854c2d03e45d06ae57f40d78e5770d252f195b93f581acf3ba44496e", - "sha256:21157295583fe8943475029ed5abdcf71eb3911894724e360acff1d61c1d54bc", - "sha256:2470043b93ff09bf8fb1d46d1cb756ce6132c54826661a32d4e4d132e1977adf", - "sha256:285d29981935eb726a4399badae8f0ffdff4f5050eaa6d0cfc3f64b857b77185", - "sha256:30d78fbc8ebf9c92c9b7823ee18eb92f2e6ef79b45ac84db507f52fbe3ec4497", - "sha256:320dab6e7cb2eacdf0e658569d2575c4dad258c0fcc794f46215e1e39f90f2c3", - "sha256:33ab79603146aace82c2427da5ca6e58f2b3f2fb5da893ceac0c42218a40be35", - "sha256:3548db281cd7d2561c9ad9984681c95f7b0e38881201e157833a2342c30d5e8c", - "sha256:3799aecf2e17cf585d977b780ce79ff0dc9b78d799fc694221ce814c2c19db83", - "sha256:39d39875251ca8f612b6f33e6b1195af86d1b3e60086068be9cc053aa4376e21", - "sha256:3b926aa83d1edb5aa5b427b4053dc420ec295a08e40911296b9eb1b6170f6cca", - "sha256:3bcde07039e586f91b45c88f8583ea7cf7a0770df3a1649627bf598332cb6984", - "sha256:3d08afd128ddaa624a48cf2b859afef385b720bb4b43df214f85616922e6a5ac", - "sha256:3eb6971dcff08619f8d91607cfc726518b6fa2a9eba42856be181c6d0d9515fd", - "sha256:40f4774f5a9d4f5e344f31a32b5096977b5d48560c5592e2f3d2c4374bd543ee", - "sha256:4289fc34b2f5316fbb762d75362931e351941fa95fa18789191b33fc4cf9504a", - "sha256:470c103ae716238bbe698d67ad020e1db9d9dba34fa5a899b5e21577e6d52ed2", - 
"sha256:4f2c9f67e9821cad2e5f480bc8d83b8742896f1242dba247911072d4fa94c192", - "sha256:50a74364d85fd319352182ef59c5c790484a336f6db772c1a9231f1c3ed0cbd7", - "sha256:54a2db7b78338edd780e7ef7f9f6c442500fb0d41a5a4ea24fff1c929d5af585", - "sha256:5635bd9cb9731e6d4a1132a498dd34f764034a8ce60cef4f5319c0541159392f", - "sha256:59c0b02d0a6c384d453fece7566d1c7e6b7bae4fc5874ef2ef46d56776d61c9e", - "sha256:5d598b938678ebf3c67377cdd45e09d431369c3b1a5b331058c338e201f12b27", - "sha256:5df2768244d19ab7f60546d0c7c63ce1581f7af8b5de3eb3004b9b6fc8a9f84b", - "sha256:5ef34d190326c3b1f822a5b7a45f6c4535e2f47ed06fec77d3d799c450b2651e", - "sha256:6975a3fac6bc83c4a65c9f9fcab9e47019a11d3d2cf7f3c0d03431bf145a941e", - "sha256:6c9a799e985904922a4d207a94eae35c78ebae90e128f0c4e521ce339396be9d", - "sha256:70df4e3b545a17496c9b3f41f5115e69a4f2e77e94e1d2a8e1070bc0c38c8a3c", - "sha256:7473e861101c9e72452f9bf8acb984947aa1661a7704553a9f6e4baa5ba64415", - "sha256:8102eaf27e1e448db915d08afa8b41d6c7ca7a04b7d73af6514df10a3e74bd82", - "sha256:87c450779d0914f2861b8526e035c5e6da0a3199d8f1add1a665e1cbc6fc6d02", - "sha256:8b7ee99e510d7b66cdb6c593f21c043c248537a32e0bedf02e01e9553a172314", - "sha256:91fc98adde3d7881af9b59ed0294046f3806221863722ba7d8d120c575314325", - "sha256:94411f22c3985acaec6f83c6df553f2dbe17b698cc7f8ae751ff2237d96b9e3c", - "sha256:98d85c6a2bef81588d9227dde12db8a7f47f639f4a17c9ae08e773aa9c697bf3", - "sha256:9ad5db27f9cabae298d151c85cf2bad1d359a1b9c686a275df03385758e2f914", - "sha256:a0b71b1b8fbf2b96e41c4d990244165e2c9be83d54962a9a1d118fd8657d2045", - "sha256:a0f100c8912c114ff53e1202d0078b425bee3649ae34d7b070e9697f93c5d52d", - "sha256:a591fe9e525846e4d154205572a029f653ada1a78b93697f3b5a8f1f2bc055b9", - "sha256:a5c84c68147988265e60416b57fc83425a78058853509c1b0629c180094904a5", - "sha256:a66d3508133af6e8548451b25058d5812812ec3798c886bf38ed24a98216fab2", - "sha256:a8c4917bd7ad33e8eb21e9a5bbba979b49d9a97acb3a803092cbc1133e20343c", - "sha256:b3bbeb01c2b273cca1e1e0c5df57f12dce9a4dd331b4fa1635b8bec26350bde3", - "sha256:cba9d6b9a7d64d4bd46167096fc9d2f835e25d7e4c121fb2ddfc6528fb0413b2", - "sha256:cc4d65aeeaa04136a12677d3dd0b1c0c94dc43abac5860ab33cceb42b801c1e8", - "sha256:ce4bcc037df4fc5e3d184794f27bdaab018943698f4ca31630bc7f84a7b69c6d", - "sha256:cec7d9412a9102bdc577382c3929b337320c4c4c4849f2c5cdd14d7368c5562d", - "sha256:d400bfb9a37b1351253cb402671cea7e89bdecc294e8016a707f6d1d8ac934f9", - "sha256:d61f4695e6c866a23a21acab0509af1cdfd2c013cf256bbf5b6b5e2695827162", - "sha256:db0fbb9c62743ce59a9ff687eb5f4afbe77e5e8403d6697f7446e5f609976f76", - "sha256:dd86c085fae2efd48ac91dd7ccffcfc0571387fe1193d33b6394db7ef31fe2a4", - "sha256:e00b098126fd45523dd056d2efba6c5a63b71ffe9f2bbe1a4fe1716e1d0c331e", - "sha256:e229a521186c75c8ad9490854fd8bbdd9a0c9aa3a524326b55be83b54d4e0ad9", - "sha256:e263d77ee3dd201c3a142934a086a4450861778baaeeb45db4591ef65550b0a6", - "sha256:ed9cb427ba5504c1dc15ede7d516b84757c3e3d7868ccc85121d9310d27eed0b", - "sha256:fa6693661a4c91757f4412306191b6dc88c1703f780c8234035eac011922bc01", - "sha256:fcd131dd944808b5bdb38e6f5b53013c5aa4f334c5cad0c72742f6eba4b73db0" - ], - "version": "==1.15.1" + "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8", + "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2", + "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1", + "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15", + "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36", + 
"sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", + "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8", + "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36", + "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17", + "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf", + "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc", + "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3", + "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed", + "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702", + "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1", + "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8", + "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903", + "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6", + "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d", + "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b", + "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e", + "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be", + "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c", + "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683", + "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9", + "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c", + "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8", + "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1", + "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4", + "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655", + "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67", + "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595", + "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0", + "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65", + "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41", + "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6", + "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401", + "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6", + "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3", + "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16", + "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93", + "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e", + "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4", + "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964", + "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c", + "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576", + "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0", + "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3", + "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662", + "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3", + "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff", + 
"sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5", + "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd", + "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f", + "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5", + "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14", + "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d", + "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9", + "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7", + "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382", + "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a", + "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e", + "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", + "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4", + "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99", + "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87", + "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b" + ], + "markers": "platform_python_implementation != 'PyPy'", + "version": "==1.17.1" }, "charset-normalizer": { "hashes": [ - "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96", - "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c", - "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710", - "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706", - "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020", - "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252", - "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad", - "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329", - "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a", - "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f", - "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6", - "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4", - "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a", - "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46", - "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2", - "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23", - "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace", - "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd", - "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982", - "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10", - "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2", - "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea", - "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09", - "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5", - "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149", - "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489", - "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9", - "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80", - 
"sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592", - "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3", - "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6", - "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed", - "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c", - "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200", - "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a", - "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e", - "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d", - "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6", - "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623", - "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669", - "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3", - "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa", - "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9", - "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2", - "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f", - "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1", - "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4", - "sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a", - "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8", - "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3", - "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029", - "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f", - "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959", - "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22", - "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7", - "sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952", - "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346", - "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e", - "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d", - "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299", - "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd", - "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a", - "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3", - "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037", - "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94", - "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c", - "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858", - "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a", - "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449", - "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c", - "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918", - "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1", - "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c", - "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac", - 
"sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa" + "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027", + "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087", + "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786", + "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8", + "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09", + "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185", + "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574", + "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e", + "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519", + "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898", + "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269", + "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3", + "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f", + "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6", + "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8", + "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a", + "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73", + "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc", + "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714", + "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2", + "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc", + "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce", + "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d", + "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e", + "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6", + "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269", + "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96", + "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d", + "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a", + "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4", + "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77", + "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d", + "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0", + "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed", + "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068", + "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac", + "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25", + "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8", + "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab", + "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26", + "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2", + "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db", + "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f", + "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5", + "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99", + 
"sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c", + "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d", + "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811", + "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa", + "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a", + "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03", + "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b", + "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04", + "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c", + "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001", + "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458", + "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389", + "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99", + "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985", + "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537", + "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238", + "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f", + "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d", + "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796", + "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a", + "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143", + "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8", + "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c", + "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5", + "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5", + "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711", + "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4", + "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6", + "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c", + "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7", + "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4", + "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b", + "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae", + "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12", + "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c", + "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae", + "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8", + "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887", + "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b", + "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4", + "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f", + "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5", + "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33", + "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519", + "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561" ], "markers": "python_full_version >= '3.7.0'", - "version": "==3.2.0" + "version": 
"==3.3.2" }, "cryptography": { "hashes": [ - "sha256:01f1d9e537f9a15b037d5d9ee442b8c22e3ae11ce65ea1f3316a41c78756b711", - "sha256:079347de771f9282fbfe0e0236c716686950c19dee1b76240ab09ce1624d76d7", - "sha256:182be4171f9332b6741ee818ec27daff9fb00349f706629f5cbf417bd50e66fd", - "sha256:192255f539d7a89f2102d07d7375b1e0a81f7478925b3bc2e0549ebf739dae0e", - "sha256:2a034bf7d9ca894720f2ec1d8b7b5832d7e363571828037f9e0c4f18c1b58a58", - "sha256:342f3767e25876751e14f8459ad85e77e660537ca0a066e10e75df9c9e9099f0", - "sha256:439c3cc4c0d42fa999b83ded80a9a1fb54d53c58d6e59234cfe97f241e6c781d", - "sha256:49c3222bb8f8e800aead2e376cbef687bc9e3cb9b58b29a261210456a7783d83", - "sha256:674b669d5daa64206c38e507808aae49904c988fa0a71c935e7006a3e1e83831", - "sha256:7a9a3bced53b7f09da251685224d6a260c3cb291768f54954e28f03ef14e3766", - "sha256:7af244b012711a26196450d34f483357e42aeddb04128885d95a69bd8b14b69b", - "sha256:7d230bf856164de164ecb615ccc14c7fc6de6906ddd5b491f3af90d3514c925c", - "sha256:84609ade00a6ec59a89729e87a503c6e36af98ddcd566d5f3be52e29ba993182", - "sha256:9a6673c1828db6270b76b22cc696f40cde9043eb90373da5c2f8f2158957f42f", - "sha256:9b6d717393dbae53d4e52684ef4f022444fc1cce3c48c38cb74fca29e1f08eaa", - "sha256:9c3fe6534d59d071ee82081ca3d71eed3210f76ebd0361798c74abc2bcf347d4", - "sha256:a719399b99377b218dac6cf547b6ec54e6ef20207b6165126a280b0ce97e0d2a", - "sha256:b332cba64d99a70c1e0836902720887fb4529ea49ea7f5462cf6640e095e11d2", - "sha256:d124682c7a23c9764e54ca9ab5b308b14b18eba02722b8659fb238546de83a76", - "sha256:d73f419a56d74fef257955f51b18d046f3506270a5fd2ac5febbfa259d6c0fa5", - "sha256:f0dc40e6f7aa37af01aba07277d3d64d5a03dc66d682097541ec4da03cc140ee", - "sha256:f14ad275364c8b4e525d018f6716537ae7b6d369c094805cae45300847e0894f", - "sha256:f772610fe364372de33d76edcd313636a25684edb94cee53fd790195f5989d14" + "sha256:014f58110f53237ace6a408b5beb6c427b64e084eb451ef25a28308270086494", + "sha256:1bbcce1a551e262dfbafb6e6252f1ae36a248e615ca44ba302df077a846a8806", + "sha256:203e92a75716d8cfb491dc47c79e17d0d9207ccffcbcb35f598fbe463ae3444d", + "sha256:27e613d7077ac613e399270253259d9d53872aaf657471473ebfc9a52935c062", + "sha256:2bd51274dcd59f09dd952afb696bf9c61a7a49dfc764c04dd33ef7a6b502a1e2", + "sha256:38926c50cff6f533f8a2dae3d7f19541432610d114a70808f0926d5aaa7121e4", + "sha256:511f4273808ab590912a93ddb4e3914dfd8a388fed883361b02dea3791f292e1", + "sha256:58d4e9129985185a06d849aa6df265bdd5a74ca6e1b736a77959b498e0505b85", + "sha256:5b43d1ea6b378b54a1dc99dd8a2b5be47658fe9a7ce0a58ff0b55f4b43ef2b84", + "sha256:61ec41068b7b74268fa86e3e9e12b9f0c21fcf65434571dbb13d954bceb08042", + "sha256:666ae11966643886c2987b3b721899d250855718d6d9ce41b521252a17985f4d", + "sha256:68aaecc4178e90719e95298515979814bda0cbada1256a4485414860bd7ab962", + "sha256:7c05650fe8023c5ed0d46793d4b7d7e6cd9c04e68eabe5b0aeea836e37bdcec2", + "sha256:80eda8b3e173f0f247f711eef62be51b599b5d425c429b5d4ca6a05e9e856baa", + "sha256:8385d98f6a3bf8bb2d65a73e17ed87a3ba84f6991c155691c51112075f9ffc5d", + "sha256:88cce104c36870d70c49c7c8fd22885875d950d9ee6ab54df2745f83ba0dc365", + "sha256:9d3cdb25fa98afdd3d0892d132b8d7139e2c087da1712041f6b762e4f807cc96", + "sha256:a575913fb06e05e6b4b814d7f7468c2c660e8bb16d8d5a1faf9b33ccc569dd47", + "sha256:ac119bb76b9faa00f48128b7f5679e1d8d437365c5d26f1c2c3f0da4ce1b553d", + "sha256:c1332724be35d23a854994ff0b66530119500b6053d0bd3363265f7e5e77288d", + "sha256:d03a475165f3134f773d1388aeb19c2d25ba88b6a9733c5c590b9ff7bbfa2e0c", + "sha256:d75601ad10b059ec832e78823b348bfa1a59f6b8d545db3a24fd44362a1564cb", + 
"sha256:de41fd81a41e53267cb020bb3a7212861da53a7d39f863585d13ea11049cf277", + "sha256:e710bf40870f4db63c3d7d929aa9e09e4e7ee219e703f949ec4073b4294f6172", + "sha256:ea25acb556320250756e53f9e20a4177515f012c9eaea17eb7587a8c4d8ae034", + "sha256:f98bf604c82c416bc829e490c700ca1553eafdf2912a91e23a79d97d9801372a", + "sha256:fba1007b3ef89946dbbb515aeeb41e30203b004f0b4b00e5e16078b518563289" ], "markers": "python_version >= '3.7'", - "version": "==41.0.2" + "version": "==43.0.1" }, "docker": { "hashes": [ - "sha256:aa6d17830045ba5ef0168d5eaa34d37beeb113948c413affe1d5991fc11f9a20", - "sha256:aecd2277b8bf8e506e484f6ab7aec39abe0038e29fa4a6d3ba86c3fe01844ed9" + "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", + "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0" ], - "markers": "python_version >= '3.7'", - "version": "==6.1.3" + "markers": "python_version >= '3.8'", + "version": "==7.1.0" }, "execnet": { "hashes": [ - "sha256:88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41", - "sha256:cc59bc4423742fd71ad227122eb0dd44db51efb3dc4095b45ac9a08c770096af" + "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc", + "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3" ], - "markers": "python_version >= '3.7'", - "version": "==2.0.2" + "markers": "python_version >= '3.8'", + "version": "==2.1.1" }, "idna": { "hashes": [ - "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4", - "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2" + "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", + "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3" ], - "markers": "python_version >= '3.5'", - "version": "==3.4" + "markers": "python_version >= '3.6'", + "version": "==3.10" }, "iniconfig": { "hashes": [ @@ -237,66 +260,70 @@ }, "packaging": { "hashes": [ - "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61", - "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f" + "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002", + "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124" ], - "markers": "python_version >= '3.7'", - "version": "==23.1" + "markers": "python_version >= '3.8'", + "version": "==24.1" }, "pluggy": { "hashes": [ - "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849", - "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3" + "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", + "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669" ], - "markers": "python_version >= '3.7'", - "version": "==1.2.0" + "markers": "python_version >= '3.8'", + "version": "==1.5.0" }, "psutil": { "hashes": [ - "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d", - "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217", - "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4", - "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c", - "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f", - "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da", - "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4", - "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42", - "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5", - 
"sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4", - "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9", - "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f", - "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30", - "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48" - ], - "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==5.9.5" + "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35", + "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0", + "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c", + "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1", + "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3", + "sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c", + "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", + "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3", + "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", + "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", + "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6", + "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d", + "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c", + "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0", + "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132", + "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14", + "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0" + ], + "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'", + "version": "==6.0.0" }, "pycparser": { "hashes": [ - "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9", - "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206" + "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6", + "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc" ], - "version": "==2.21" + "markers": "python_version >= '3.8'", + "version": "==2.22" }, "pytest": { "hashes": [ - "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32", - "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a" + "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181", + "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2" ], - "markers": "python_version >= '3.7'", - "version": "==7.4.0" + "markers": "python_version >= '3.8'", + "version": "==8.3.3" }, "pytest-cs": { "git": "https://github.com/crowdsecurity/pytest-cs.git", - "ref": "4a3451084215053af8a48ff37507b4f86bf75c10" + "ref": "aea7e8549faa32f5e1d1f17755a5db3712396a2a" }, "pytest-datadir": { "hashes": [ - "sha256:095f441782b1b907587eca7227fdbae94be43f1c96b4b2cbcc6801a4645be1af", - "sha256:9f7a3c4def6ac4cac3cc8181139ab53bd2667231052bd40cb07081748d4420f0" + "sha256:1617ed92f9afda0c877e4eac91904b5f779d24ba8f5e438752e3ae39d8d2ee3f", + "sha256:34adf361bcc7b37961bbc1dfa8d25a4829e778bab461703c38a5c50ca9c36dc8" ], - "markers": "python_version >= '3.6'", - "version": "==1.4.1" + "markers": "python_version >= '3.8'", + "version": "==1.5.0" }, "pytest-dotenv": { "hashes": [ @@ -308,73 +335,87 @@ }, "pytest-xdist": { "hashes": [ - 
"sha256:d5ee0520eb1b7bcca50a60a518ab7a7707992812c578198f8b44fdfac78e8c93", - "sha256:ff9daa7793569e6a68544850fd3927cd257cc03a7ef76c95e86915355e82b5f2" + "sha256:cbb36f3d67e0c478baa57fa4edc8843887e0f6cfc42d677530a36d7472b32d8a", + "sha256:d075629c7e00b611df89f490a5063944bee7a4362a5ff11c7cc7824a03dfce24" ], "index": "pypi", - "version": "==3.3.1" + "markers": "python_version >= '3.7'", + "version": "==3.5.0" }, "python-dotenv": { "hashes": [ - "sha256:a8df96034aae6d2d50a4ebe8216326c61c3eb64836776504fcca410e5937a3ba", - "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a" + "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", + "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a" ], "markers": "python_version >= '3.8'", - "version": "==1.0.0" + "version": "==1.0.1" }, "pyyaml": { "hashes": [ - "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf", - "sha256:0283c35a6a9fbf047493e3a0ce8d79ef5030852c51e9d911a27badfde0605293", - "sha256:055d937d65826939cb044fc8c9b08889e8c743fdc6a32b33e2390f66013e449b", - "sha256:07751360502caac1c067a8132d150cf3d61339af5691fe9e87803040dbc5db57", - "sha256:0b4624f379dab24d3725ffde76559cff63d9ec94e1736b556dacdfebe5ab6d4b", - "sha256:0ce82d761c532fe4ec3f87fc45688bdd3a4c1dc5e0b4a19814b9009a29baefd4", - "sha256:1e4747bc279b4f613a09eb64bba2ba602d8a6664c6ce6396a4d0cd413a50ce07", - "sha256:213c60cd50106436cc818accf5baa1aba61c0189ff610f64f4a3e8c6726218ba", - "sha256:231710d57adfd809ef5d34183b8ed1eeae3f76459c18fb4a0b373ad56bedcdd9", - "sha256:277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287", - "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513", - "sha256:40527857252b61eacd1d9af500c3337ba8deb8fc298940291486c465c8b46ec0", - "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782", - "sha256:473f9edb243cb1935ab5a084eb238d842fb8f404ed2193a915d1784b5a6b5fc0", - "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92", - "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f", - "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2", - "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc", - "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1", - "sha256:819b3830a1543db06c4d4b865e70ded25be52a2e0631ccd2f6a47a2822f2fd7c", - "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86", - "sha256:98c4d36e99714e55cfbaaee6dd5badbc9a1ec339ebfc3b1f52e293aee6bb71a4", - "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c", - "sha256:9fa600030013c4de8165339db93d182b9431076eb98eb40ee068700c9c813e34", - "sha256:a80a78046a72361de73f8f395f1f1e49f956c6be882eed58505a15f3e430962b", - "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d", - "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c", - "sha256:b5b9eccad747aabaaffbc6064800670f0c297e52c12754eb1d976c57e4f74dcb", - "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7", - "sha256:c5687b8d43cf58545ade1fe3e055f70eac7a5a1a0bf42824308d868289a95737", - "sha256:cba8c411ef271aa037d7357a2bc8f9ee8b58b9965831d9e51baf703280dc73d3", - "sha256:d15a181d1ecd0d4270dc32edb46f7cb7733c7c508857278d3d378d14d606db2d", - "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358", - "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53", - 
"sha256:d4eccecf9adf6fbcc6861a38015c2a64f38b9d94838ac1810a9023a0609e1b78", - "sha256:d67d839ede4ed1b28a4e8909735fc992a923cdb84e618544973d7dfc71540803", - "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a", - "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f", - "sha256:e61ceaab6f49fb8bdfaa0f92c4b57bcfbea54c09277b1b4f7ac376bfb7a7c174", - "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5" + "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff", + "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", + "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086", + "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e", + "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133", + "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5", + "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484", + "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee", + "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", + "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68", + "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a", + "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf", + "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99", + "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8", + "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85", + "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19", + "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc", + "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a", + "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1", + "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317", + "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c", + "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631", + "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d", + "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652", + "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5", + "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e", + "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b", + "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", + "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", + "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706", + "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", + "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237", + "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", + "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083", + "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180", + "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", + "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e", + "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f", + "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725", + "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183", + 
"sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab", + "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774", + "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725", + "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", + "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5", + "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d", + "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290", + "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44", + "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed", + "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", + "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba", + "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12", + "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4" ], - "markers": "python_version >= '3.6'", - "version": "==6.0" + "markers": "python_version >= '3.8'", + "version": "==6.0.2" }, "requests": { "hashes": [ - "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f", - "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1" + "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", + "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6" ], - "markers": "python_version >= '3.7'", - "version": "==2.31.0" + "markers": "python_version >= '3.8'", + "version": "==2.32.3" }, "trustme": { "hashes": [ @@ -386,35 +427,20 @@ }, "urllib3": { "hashes": [ - "sha256:48e7fafa40319d358848e1bc6809b208340fafe2096f1725d05d67443d0483d1", - "sha256:bee28b5e56addb8226c96f7f13ac28cb4c301dd5ea8a6ca179c0b9835e032825" - ], - "markers": "python_version >= '3.7'", - "version": "==2.0.3" - }, - "websocket-client": { - "hashes": [ - "sha256:c951af98631d24f8df89ab1019fc365f2227c0892f12fd150e935607c79dd0dd", - "sha256:f1f9f2ad5291f0225a49efad77abf9e700b6fef553900623060dad6e26503b9d" + "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac", + "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9" ], - "markers": "python_version >= '3.7'", - "version": "==1.6.1" + "markers": "python_version >= '3.8'", + "version": "==2.2.3" } }, "develop": { "asttokens": { "hashes": [ - "sha256:4622110b2a6f30b77e1473affaa97e711bc2f07d3f10848420ff1898edbe94f3", - "sha256:6b0ac9e93fb0335014d382b8fa9b3afa7df546984258005da0b9e7095b3deb1c" - ], - "version": "==2.2.1" - }, - "backcall": { - "hashes": [ - "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e", - "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255" + "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24", + "sha256:b03869718ba9a6eb027e134bfdf69f38a236d681c83c160d510768af11254ba0" ], - "version": "==0.2.0" + "version": "==2.4.1" }, "decorator": { "hashes": [ @@ -426,10 +452,11 @@ }, "executing": { "hashes": [ - "sha256:0314a69e37426e3608aada02473b4161d4caf5a4b244d1d0c48072b8fee7bacc", - "sha256:19da64c18d2d851112f09c287f8d3dbbdf725ab0e569077efb6cdcbd3497c107" + "sha256:8d63781349375b5ebccc3142f4b30350c0cd9c79f921cde38be2be4637e98eaf", + "sha256:8ea27ddd260da8150fa5a708269c4a10e76161e2496ec3e587da9e3c0fe4b9ab" ], - "version": "==1.2.0" + "markers": "python_version >= '3.8'", + "version": "==2.1.0" }, "gnureadline": { "hashes": [ @@ -470,62 +497,56 @@ 
"sha256:e3ac6018ef05126d442af680aad863006ec19d02290561ac88b8b1c0b0cfc726" ], "index": "pypi", + "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", "version": "==0.13.13" }, "ipython": { "hashes": [ - "sha256:1d197b907b6ba441b692c48cf2a3a2de280dc0ac91a3405b39349a50272ca0a1", - "sha256:248aca623f5c99a6635bc3857677b7320b9b8039f99f070ee0d20a5ca5a8e6bf" + "sha256:0d0d15ca1e01faeb868ef56bc7ee5a0de5bd66885735682e8a322ae289a13d1a", + "sha256:530ef1e7bb693724d3cdc37287c80b07ad9b25986c007a53aa1857272dac3f35" ], "markers": "python_version >= '3.11'", - "version": "==8.14.0" + "version": "==8.28.0" }, "jedi": { "hashes": [ - "sha256:203c1fd9d969ab8f2119ec0a3342e0b49910045abe6af0a3ae83a5764d54639e", - "sha256:bae794c30d07f6d910d32a7048af09b5a39ed740918da923c6b780790ebac612" + "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd", + "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0" ], "markers": "python_version >= '3.6'", - "version": "==0.18.2" + "version": "==0.19.1" }, "matplotlib-inline": { "hashes": [ - "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311", - "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304" + "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", + "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca" ], - "markers": "python_version >= '3.5'", - "version": "==0.1.6" + "markers": "python_version >= '3.8'", + "version": "==0.1.7" }, "parso": { "hashes": [ - "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0", - "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75" + "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", + "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d" ], "markers": "python_version >= '3.6'", - "version": "==0.8.3" + "version": "==0.8.4" }, "pexpect": { "hashes": [ - "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937", - "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c" - ], - "markers": "sys_platform != 'win32'", - "version": "==4.8.0" - }, - "pickleshare": { - "hashes": [ - "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca", - "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56" + "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", + "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f" ], - "version": "==0.7.5" + "markers": "sys_platform != 'win32' and sys_platform != 'emscripten'", + "version": "==4.9.0" }, "prompt-toolkit": { "hashes": [ - "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac", - "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88" + "sha256:d6623ab0477a80df74e646bdbc93621143f5caf104206aa29294d53de1a03d90", + "sha256:f49a827f90062e411f1ce1f854f2aedb3c23353244f8108b89283587397ac10e" ], "markers": "python_full_version >= '3.7.0'", - "version": "==3.0.39" + "version": "==3.0.48" }, "ptyprocess": { "hashes": [ @@ -536,18 +557,18 @@ }, "pure-eval": { "hashes": [ - "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350", - "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3" + "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", + "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42" ], - "version": "==0.2.2" + "version": "==0.2.3" }, "pygments": { 
"hashes": [ - "sha256:8ace4d3c1dd481894b2005f560ead0f9f19ee64fe983366be1a21e171d12775c", - "sha256:db2db3deb4b4179f399a09054b023b6a586b76499d36965813c71aa8ed7b5fd1" + "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199", + "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a" ], - "markers": "python_version >= '3.7'", - "version": "==2.15.1" + "markers": "python_version >= '3.8'", + "version": "==2.18.0" }, "six": { "hashes": [ @@ -559,25 +580,25 @@ }, "stack-data": { "hashes": [ - "sha256:32d2dd0376772d01b6cb9fc996f3c8b57a357089dec328ed4b6553d037eaf815", - "sha256:cbb2a53eb64e5785878201a97ed7c7b94883f48b87bfb0bbe8b623c74679e4a8" + "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", + "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695" ], - "version": "==0.6.2" + "version": "==0.6.3" }, "traitlets": { "hashes": [ - "sha256:9e6ec080259b9a5940c797d58b613b5e31441c2257b87c2e795c5228ae80d2d8", - "sha256:f6cde21a9c68cf756af02035f72d5a723bf607e862e7be33ece505abf4a3bad9" + "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", + "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f" ], - "markers": "python_version >= '3.7'", - "version": "==5.9.0" + "markers": "python_version >= '3.8'", + "version": "==5.14.3" }, "wcwidth": { "hashes": [ - "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e", - "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0" + "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", + "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5" ], - "version": "==0.2.6" + "version": "==0.2.13" } } } diff --git a/docker/test/default.env b/docker/test/default.env index c46fdab7f1d..9607c8aaa5b 100644 --- a/docker/test/default.env +++ b/docker/test/default.env @@ -6,7 +6,7 @@ CROWDSEC_TEST_VERSION="dev" # All of the following flavors will be tested when using the "flavor" fixture CROWDSEC_TEST_FLAVORS="full" # CROWDSEC_TEST_FLAVORS="full,slim,debian" -# CROWDSEC_TEST_FLAVORS="full,slim,debian,geoip,plugins-debian-slim,debian-geoip,debian-plugins" +# CROWDSEC_TEST_FLAVORS="full,slim,debian,debian-slim" # network to use CROWDSEC_TEST_NETWORK="net-test" diff --git a/docker/test/tests/test_agent.py b/docker/test/tests/test_agent.py index e1ede3f8927..e55d11af850 100644 --- a/docker/test/tests/test_agent.py +++ b/docker/test/tests/test_agent.py @@ -13,7 +13,7 @@ def test_no_agent(crowdsec, flavor): 'DISABLE_AGENT': 'true', } with crowdsec(flavor=flavor, environment=env) as cs: - cs.wait_for_log("*CrowdSec Local API listening on 0.0.0.0:8080*") + cs.wait_for_log("*CrowdSec Local API listening on *:8080*") cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') assert res.exit_code == 0 @@ -37,7 +37,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory): with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: cs.wait_for_log([ "*Generate local agent credentials*", - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on *:8080*", ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') @@ -50,7 +50,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory): with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: cs.wait_for_log([ "*Generate local agent credentials*", - "*CrowdSec Local API listening on 
0.0.0.0:8080*", + "*CrowdSec Local API listening on *:8080*", ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') @@ -65,7 +65,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory): with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: cs.wait_for_log([ "*Generate local agent credentials*", - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on *:8080*", ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') @@ -78,7 +78,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory): with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: cs.wait_for_log([ "*Local agent already registered*", - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on *:8080*", ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') diff --git a/docker/test/tests/test_agent_only.py b/docker/test/tests/test_agent_only.py index d9db3ca3042..038b726e324 100644 --- a/docker/test/tests/test_agent_only.py +++ b/docker/test/tests/test_agent_only.py @@ -29,7 +29,7 @@ def test_split_lapi_agent(crowdsec, flavor): cs_agent = crowdsec(name=agentname, environment=agent_env, flavor=flavor) with cs_lapi as lapi: - lapi.wait_for_log("*CrowdSec Local API listening on 0.0.0.0:8080*") + lapi.wait_for_log("*CrowdSec Local API listening on *:8080*") lapi.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) with cs_agent as agent: agent.wait_for_log("*Starting processing data*") diff --git a/docker/test/tests/test_bouncer.py b/docker/test/tests/test_bouncer.py index 1324c3bd38c..98b86de858c 100644 --- a/docker/test/tests/test_bouncer.py +++ b/docker/test/tests/test_bouncer.py @@ -36,8 +36,6 @@ def test_register_bouncer_env(crowdsec, flavor): bouncer1, bouncer2 = j assert bouncer1['name'] == 'bouncer1name' assert bouncer2['name'] == 'bouncer2name' - assert bouncer1['api_key'] == hex512('bouncer1key') - assert bouncer2['api_key'] == hex512('bouncer2key') # add a second bouncer at runtime res = cs.cont.exec_run('cscli bouncers add bouncer3name -k bouncer3key') @@ -48,7 +46,6 @@ def test_register_bouncer_env(crowdsec, flavor): assert len(j) == 3 bouncer3 = j[2] assert bouncer3['name'] == 'bouncer3name' - assert bouncer3['api_key'] == hex512('bouncer3key') # remove all bouncers res = cs.cont.exec_run('cscli bouncers delete bouncer1name bouncer2name bouncer3name') diff --git a/docker/test/tests/test_capi_whitelists.py b/docker/test/tests/test_capi_whitelists.py index f8e3c17c026..19378ba86f0 100644 --- a/docker/test/tests/test_capi_whitelists.py +++ b/docker/test/tests/test_capi_whitelists.py @@ -25,7 +25,7 @@ def test_capi_whitelists(crowdsec, tmp_path_factory, flavor,): with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: cs.wait_for_log("*Starting processing data*") cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) - res = cs.cont.exec_run(f'cscli config show-yaml') + res = cs.cont.exec_run('cscli config show-yaml') assert res.exit_code == 0 stdout = res.output.decode() y = yaml.safe_load(stdout) diff --git a/docker/test/tests/test_flavors.py b/docker/test/tests/test_flavors.py index c6aba888dbe..7e78b8d681b 100644 --- a/docker/test/tests/test_flavors.py +++ b/docker/test/tests/test_flavors.py @@ -22,6 +22,7 @@ def test_cscli_lapi(crowdsec, flavor): assert "You can successfully interact with Local API (LAPI)" in stdout 
+@pytest.mark.skip(reason="currently broken by hub upgrade") def test_flavor_content(crowdsec, flavor): """Test flavor contents""" with crowdsec(flavor=flavor) as cs: @@ -41,7 +42,7 @@ def test_flavor_content(crowdsec, flavor): x = cs.cont.exec_run( 'ls -1 /usr/local/lib/crowdsec/plugins/') stdout = x.output.decode() - if 'slim' in flavor or 'geoip' in flavor: + if 'slim' in flavor: # the exact return code and full message depend # on the 'ls' implementation (busybox vs coreutils) assert x.exit_code != 0 @@ -50,9 +51,11 @@ assert 'notification-http' not in stdout assert 'notification-slack' not in stdout assert 'notification-splunk' not in stdout + assert 'notification-sentinel' not in stdout else: assert x.exit_code == 0 assert 'notification-email' in stdout assert 'notification-http' in stdout assert 'notification-slack' in stdout assert 'notification-splunk' in stdout + assert 'notification-sentinel' in stdout diff --git a/docker/test/tests/test_hub_collections.py b/docker/test/tests/test_hub_collections.py index b890bebb9c6..962f8ff8df4 100644 --- a/docker/test/tests/test_hub_collections.py +++ b/docker/test/tests/test_hub_collections.py @@ -30,8 +30,8 @@ def test_install_two_collections(crowdsec, flavor): cs.wait_for_log([ # f'*collections install "{it1}"*' # f'*collections install "{it2}"*' - f'*Enabled collections : {it1}*', - f'*Enabled collections : {it2}*', + f'*Enabled collections: {it1}*', + f'*Enabled collections: {it2}*', ]) @@ -72,7 +72,7 @@ def test_install_and_disable_collection(crowdsec, flavor): assert it not in items logs = cs.log_lines() # check that there was no attempt to install - assert not any(f'Enabled collections : {it}' in line for line in logs) + assert not any(f'Enabled collections: {it}' in line for line in logs) # already done in bats; provided here as an example of a somewhat complex test @@ -91,7 +91,7 @@ def test_taint_bubble_up(crowdsec, tmp_path_factory, flavor): # implicit check for tainted=False assert items[coll]['status'] == 'enabled' cs.wait_for_log([ - f'*Enabled collections : {coll}*', + f'*Enabled collections: {coll}*', ]) scenario = 'crowdsecurity/http-crawl-non_statics' diff --git a/docker/test/tests/test_hub_scenarios.py b/docker/test/tests/test_hub_scenarios.py index a60ede667b8..2a8c3a275f2 100644 --- a/docker/test/tests/test_hub_scenarios.py +++ b/docker/test/tests/test_hub_scenarios.py @@ -21,8 +21,8 @@ def test_install_two_scenarios(crowdsec, flavor): } with crowdsec(flavor=flavor, environment=env) as cs: cs.wait_for_log([ - f'*scenarios install "{it1}*"', - f'*scenarios install "{it2}*"', + f'*scenarios install "{it1}"*', + f'*scenarios install "{it2}"*', "*Starting processing data*" ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) diff --git a/docker/test/tests/test_local_api_url.py b/docker/test/tests/test_local_api_url.py index 262e8fbefce..aa90c9fb798 100644 --- a/docker/test/tests/test_local_api_url.py +++ b/docker/test/tests/test_local_api_url.py @@ -11,7 +11,7 @@ def test_local_api_url_default(crowdsec, flavor): """Test LOCAL_API_URL (default)""" with crowdsec(flavor=flavor) as cs: cs.wait_for_log([ - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on *:8080*", "*Starting processing data*" ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) @@ -29,7 +29,7 @@ def test_local_api_url(crowdsec, flavor): } with crowdsec(flavor=flavor, environment=env) as cs: cs.wait_for_log([ - "*CrowdSec Local API listening on 0.0.0.0:8080*", +
"*CrowdSec Local API listening on *:8080*", "*Starting processing data*" ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) @@ -54,7 +54,7 @@ def test_local_api_url_ipv6(crowdsec, flavor): with crowdsec(flavor=flavor, environment=env) as cs: cs.wait_for_log([ "*Starting processing data*", - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on [::1]:8080*", ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') diff --git a/docker/test/tests/test_local_item.py b/docker/test/tests/test_local_item.py new file mode 100644 index 00000000000..3d6ac2fc954 --- /dev/null +++ b/docker/test/tests/test_local_item.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python + +""" +Test bind-mounting local items +""" + +from http import HTTPStatus +import json + +import pytest + +pytestmark = pytest.mark.docker + + +def test_inject_local_item(crowdsec, tmp_path_factory, flavor): + """Test mounting a custom whitelist at startup""" + + localitems = tmp_path_factory.mktemp('localitems') + custom_whitelists = localitems / 'custom_whitelists.yaml' + + with open(custom_whitelists, 'w') as f: + f.write('{"whitelist":{"reason":"Good IPs","ip":["1.2.3.4"]}}') + + volumes = { + custom_whitelists: {'bind': '/etc/crowdsec/parsers/s02-enrich/custom_whitelists.yaml'} + } + + with crowdsec(flavor=flavor, volumes=volumes) as cs: + cs.wait_for_log([ + "*Starting processing data*" + ]) + cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) + + # the parser should be enabled + res = cs.cont.exec_run('cscli parsers list -o json') + assert res.exit_code == 0 + j = json.loads(res.output) + items = {c['name']: c for c in j['parsers']} + assert items['custom_whitelists.yaml']['status'] == 'enabled,local' + + # regression test: the linux collection should not be tainted + # (the parsers were not copied from /staging when using "cp -an" with local parsers) + res = cs.cont.exec_run('cscli collections inspect crowdsecurity/linux -o json') + assert res.exit_code == 0 + j = json.loads(res.output) + # crowdsec <= 1.5.5 omits a "tainted" when it's false + assert j.get('tainted', False) is False diff --git a/docker/test/tests/test_tls.py b/docker/test/tests/test_tls.py index cea29b9fca2..d2f512fcbc1 100644 --- a/docker/test/tests/test_tls.py +++ b/docker/test/tests/test_tls.py @@ -4,7 +4,7 @@ Test agent-lapi and cscli-lapi communication via TLS, on the same container. """ -import random +import uuid from pytest_cs import Status @@ -22,8 +22,7 @@ def test_missing_key_file(crowdsec, flavor): } with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs: - # XXX: this message appears twice, is that normal? 
- cs.wait_for_log("*while serving local API: missing TLS key file*") + cs.wait_for_log("*local API server stopped with error: missing TLS key file*") def test_missing_cert_file(crowdsec, flavor): @@ -35,7 +34,7 @@ def test_missing_cert_file(crowdsec, flavor): } with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs: - cs.wait_for_log("*while serving local API: missing TLS cert file*") + cs.wait_for_log("*local API server stopped with error: missing TLS cert file*") def test_tls_missing_ca(crowdsec, flavor, certs_dir): @@ -140,7 +139,7 @@ def test_tls_lapi_var(crowdsec, flavor, certs_dir): def test_tls_split_lapi_agent(crowdsec, flavor, certs_dir): """Server-only certificate, split containers""" - rand = random.randint(0, 10000) + rand = uuid.uuid1() lapiname = 'lapi-' + str(rand) agentname = 'agent-' + str(rand) @@ -174,7 +173,7 @@ def test_tls_split_lapi_agent(crowdsec, flavor, certs_dir): with cs_lapi as lapi: lapi.wait_for_log([ "*(tls) Client Auth Type set to VerifyClientCertIfGiven*", - "*CrowdSec Local API listening on 0.0.0.0:8080*" + "*CrowdSec Local API listening on *:8080*" ]) # TODO: wait_for_https lapi.wait_for_http(8080, '/health', want_status=None) @@ -193,7 +192,7 @@ def test_tls_split_lapi_agent(crowdsec, flavor, certs_dir): def test_tls_mutual_split_lapi_agent(crowdsec, flavor, certs_dir): """Server and client certificates, split containers""" - rand = random.randint(0, 10000) + rand = uuid.uuid1() lapiname = 'lapi-' + str(rand) agentname = 'agent-' + str(rand) @@ -225,7 +224,7 @@ def test_tls_mutual_split_lapi_agent(crowdsec, flavor, certs_dir): with cs_lapi as lapi: lapi.wait_for_log([ "*(tls) Client Auth Type set to VerifyClientCertIfGiven*", - "*CrowdSec Local API listening on 0.0.0.0:8080*" + "*CrowdSec Local API listening on *:8080*" ]) # TODO: wait_for_https lapi.wait_for_http(8080, '/health', want_status=None) @@ -241,10 +240,10 @@ def test_tls_mutual_split_lapi_agent(crowdsec, flavor, certs_dir): assert "You can successfully interact with Local API (LAPI)" in stdout -def test_tls_client_ou(crowdsec, certs_dir): +def test_tls_client_ou(crowdsec, flavor, certs_dir): """Check behavior of client certificate vs AGENTS_ALLOWED_OU""" - rand = random.randint(0, 10000) + rand = uuid.uuid1() lapiname = 'lapi-' + str(rand) agentname = 'agent-' + str(rand) @@ -270,30 +269,43 @@ def test_tls_client_ou(crowdsec, certs_dir): certs_dir(lapi_hostname=lapiname, agent_ou='custom-client-ou'): {'bind': '/etc/ssl/crowdsec', 'mode': 'ro'}, } - cs_lapi = crowdsec(name=lapiname, environment=lapi_env, volumes=volumes) - cs_agent = crowdsec(name=agentname, environment=agent_env, volumes=volumes) + cs_lapi = crowdsec(flavor=flavor, name=lapiname, environment=lapi_env, volumes=volumes) + cs_agent = crowdsec(flavor=flavor, name=agentname, environment=agent_env, volumes=volumes) with cs_lapi as lapi: lapi.wait_for_log([ "*(tls) Client Auth Type set to VerifyClientCertIfGiven*", - "*CrowdSec Local API listening on 0.0.0.0:8080*" + "*CrowdSec Local API listening on *:8080*" ]) # TODO: wait_for_https lapi.wait_for_http(8080, '/health', want_status=None) with cs_agent as agent: lapi.wait_for_log([ - "*client certificate OU (?custom-client-ou?) doesn't match expected OU (?agent-ou?)*", + "*client certificate OU ?custom-client-ou? 
doesn't match expected OU ?agent-ou?*", ]) lapi_env['AGENTS_ALLOWED_OU'] = 'custom-client-ou' - cs_lapi = crowdsec(name=lapiname, environment=lapi_env, volumes=volumes) - cs_agent = crowdsec(name=agentname, environment=agent_env, volumes=volumes) + # change container names to avoid conflict + # recreate certificates because they need the new hostname + + rand = uuid.uuid1() + lapiname = 'lapi-' + str(rand) + agentname = 'agent-' + str(rand) + + agent_env['LOCAL_API_URL'] = f'https://{lapiname}:8080' + + volumes = { + certs_dir(lapi_hostname=lapiname, agent_ou='custom-client-ou'): {'bind': '/etc/ssl/crowdsec', 'mode': 'ro'}, + } + + cs_lapi = crowdsec(flavor=flavor, name=lapiname, environment=lapi_env, volumes=volumes) + cs_agent = crowdsec(flavor=flavor, name=agentname, environment=agent_env, volumes=volumes) with cs_lapi as lapi: lapi.wait_for_log([ "*(tls) Client Auth Type set to VerifyClientCertIfGiven*", - "*CrowdSec Local API listening on 0.0.0.0:8080*" + "*CrowdSec Local API listening on *:8080*" ]) # TODO: wait_for_https lapi.wait_for_http(8080, '/health', want_status=None) diff --git a/go.mod b/go.mod index 2af7b3adf39..f28f21c6eb4 100644 --- a/go.mod +++ b/go.mod @@ -1,38 +1,42 @@ module github.com/crowdsecurity/crowdsec -go 1.20 +go 1.22 + +// Don't use the toolchain directive to avoid uncontrolled downloads during +// a build, especially in sandboxed environments (freebsd, gentoo...). +// toolchain go1.21.3 require ( - entgo.io/ent v0.11.3 - github.com/AlecAivazis/survey/v2 v2.2.7 - github.com/Masterminds/semver/v3 v3.1.1 - github.com/Masterminds/sprig/v3 v3.2.2 - github.com/agext/levenshtein v1.2.1 - github.com/alexliesenfeld/health v0.5.1 - github.com/antonmedv/expr v1.12.5 - github.com/appleboy/gin-jwt/v2 v2.8.0 - github.com/aquasecurity/table v1.8.0 - github.com/aws/aws-lambda-go v1.38.0 - github.com/aws/aws-sdk-go v1.42.25 - github.com/beevik/etree v1.1.0 - github.com/blackfireio/osinfo v1.0.3 + entgo.io/ent v0.13.1 + github.com/AlecAivazis/survey/v2 v2.3.7 + github.com/Masterminds/semver/v3 v3.2.1 + github.com/Masterminds/sprig/v3 v3.2.3 + github.com/agext/levenshtein v1.2.3 + github.com/alexliesenfeld/health v0.8.0 + github.com/appleboy/gin-jwt/v2 v2.9.2 + github.com/aws/aws-lambda-go v1.47.0 + github.com/aws/aws-sdk-go v1.52.0 + github.com/beevik/etree v1.4.1 + github.com/blackfireio/osinfo v1.0.5 github.com/bluele/gcache v0.0.2 github.com/buger/jsonparser v1.1.1 - github.com/c-robinson/iplib v1.0.3 - github.com/cespare/xxhash/v2 v2.2.0 + github.com/c-robinson/iplib v1.0.8 + github.com/cespare/xxhash/v2 v2.3.0 + github.com/corazawaf/libinjection-go v0.1.2 + github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607 github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 - github.com/crowdsecurity/go-cs-lib v0.0.2 - github.com/crowdsecurity/grokky v0.2.1 + github.com/crowdsecurity/go-cs-lib v0.0.15 + github.com/crowdsecurity/grokky v0.2.2 github.com/crowdsecurity/machineid v1.0.2 github.com/davecgh/go-spew v1.1.1 - github.com/dghubble/sling v1.3.0 - github.com/docker/docker v24.0.4+incompatible + github.com/dghubble/sling v1.4.2 + github.com/docker/docker v24.0.9+incompatible github.com/docker/go-connections v0.4.0 - github.com/enescakir/emoji v1.0.0 - github.com/fatih/color v1.15.0 - github.com/fsnotify/fsnotify v1.6.0 + github.com/expr-lang/expr v1.16.9 + github.com/fatih/color v1.16.0 + github.com/fsnotify/fsnotify v1.7.0 github.com/gin-gonic/gin v1.9.1 - github.com/go-co-op/gocron v1.17.0 + github.com/go-co-op/gocron v1.37.0 
github.com/go-openapi/errors v0.20.1 github.com/go-openapi/strfmt v0.19.11 github.com/go-openapi/swag v0.22.3 @@ -40,111 +44,118 @@ require ( github.com/go-sql-driver/mysql v1.6.0 github.com/goccy/go-yaml v1.11.0 github.com/gofrs/uuid v4.0.0+incompatible - github.com/golang-jwt/jwt/v4 v4.4.2 - github.com/google/go-querystring v1.0.0 - github.com/google/uuid v1.3.0 - github.com/google/winops v0.0.0-20211216095627-f0e86eb1453b + github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/google/go-querystring v1.1.0 + github.com/google/uuid v1.6.0 + github.com/google/winops v0.0.0-20230712152054-af9b550d0601 github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e + github.com/gorilla/websocket v1.5.0 github.com/hashicorp/go-hclog v1.5.0 github.com/hashicorp/go-plugin v1.4.10 github.com/hashicorp/go-version v1.2.1 + github.com/hexops/gotextdiff v1.0.3 github.com/ivanpirog/coloredcobra v1.0.1 - github.com/jackc/pgx/v4 v4.14.1 + github.com/jackc/pgx/v4 v4.18.2 github.com/jarcoal/httpmock v1.1.0 + github.com/jedib0t/go-pretty/v6 v6.5.9 github.com/jszwec/csvutil v1.5.1 github.com/lithammer/dedent v1.1.0 - github.com/mattn/go-isatty v0.0.19 + github.com/mattn/go-isatty v0.0.20 github.com/mattn/go-sqlite3 v1.14.16 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/nxadm/tail v1.4.8 - github.com/oschwald/geoip2-golang v1.4.0 - github.com/oschwald/maxminddb-golang v1.8.0 + github.com/oschwald/geoip2-golang v1.9.0 + github.com/oschwald/maxminddb-golang v1.12.0 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 github.com/pkg/errors v0.9.1 - github.com/prometheus/client_golang v1.14.0 - github.com/prometheus/client_model v0.3.0 + github.com/prometheus/client_golang v1.16.0 + github.com/prometheus/client_model v0.4.0 github.com/prometheus/prom2json v1.3.0 github.com/r3labs/diff/v2 v2.14.1 - github.com/segmentio/kafka-go v0.4.34 + github.com/sanity-io/litter v1.5.5 + github.com/segmentio/kafka-go v0.4.45 github.com/shirou/gopsutil/v3 v3.23.5 github.com/sirupsen/logrus v1.9.3 - github.com/spf13/cobra v1.7.0 - github.com/stretchr/testify v1.8.3 + github.com/slack-go/slack v0.12.2 + github.com/spf13/cobra v1.8.0 + github.com/stretchr/testify v1.9.0 github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 - github.com/wasilibs/go-re2 v1.3.0 - golang.org/x/crypto v0.9.0 - golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 - golang.org/x/mod v0.11.0 - golang.org/x/sys v0.9.0 - google.golang.org/grpc v1.56.1 - google.golang.org/protobuf v1.30.0 + github.com/wasilibs/go-re2 v1.7.0 + github.com/xhit/go-simple-mail/v2 v2.16.0 + golang.org/x/crypto v0.26.0 + golang.org/x/mod v0.17.0 + golang.org/x/sys v0.24.0 + golang.org/x/text v0.17.0 + google.golang.org/grpc v1.67.1 + google.golang.org/protobuf v1.34.2 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 - k8s.io/apiserver v0.27.3 + k8s.io/apiserver v0.28.4 ) require ( - ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a // indirect + ariga.io/atlas v0.19.1-0.20240203083654-5948b60a8e43 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26 // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/bytedance/sonic v1.9.1 // indirect - github.com/chenzhuoyu/base64x 
v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/bytedance/sonic v1.10.2 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect + github.com/chenzhuoyu/iasm v0.9.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect - github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect github.com/creack/pty v1.1.18 // indirect github.com/docker/distribution v2.8.2+incompatible // indirect github.com/docker/go-units v0.5.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect - github.com/go-logr/logr v1.2.3 // indirect + github.com/go-logr/logr v1.2.4 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-openapi/analysis v0.19.16 // indirect github.com/go-openapi/inflect v0.19.0 // indirect github.com/go-openapi/jsonpointer v0.19.6 // indirect - github.com/go-openapi/jsonreference v0.20.1 // indirect + github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/loads v0.20.0 // indirect github.com/go-openapi/runtime v0.19.24 // indirect github.com/go-openapi/spec v0.20.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/go-playground/validator/v10 v10.17.0 // indirect github.com/go-stack/stack v1.8.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang/glog v1.1.0 // indirect + github.com/golang/glog v1.2.2 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/google/go-cmp v0.5.9 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/hashicorp/hcl/v2 v2.13.0 // indirect github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb // indirect - github.com/huandu/xstrings v1.3.2 // indirect + github.com/huandu/xstrings v1.3.3 // indirect github.com/imdario/mergo v0.3.12 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect - github.com/jackc/pgconn v1.10.1 // indirect + github.com/jackc/pgconn v1.14.3 // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgproto3/v2 v2.2.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect - github.com/jackc/pgtype v1.9.1 // indirect + github.com/jackc/pgproto3/v2 v2.3.3 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgtype v1.14.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/klauspost/compress v1.15.7 // indirect - github.com/klauspost/cpuid/v2 v2.2.4 // indirect - github.com/leodido/go-urn v1.2.4 // indirect + github.com/klauspost/compress v1.17.3 // indirect + github.com/klauspost/cpuid/v2 v2.2.6 // indirect + github.com/leodido/go-urn v1.3.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect - github.com/magefile/mage v1.14.0 // indirect + github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a // indirect github.com/mailru/easyjson v0.7.7 
// indirect github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-runewidth v0.0.13 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect github.com/mitchellh/copystructure v1.2.0 // indirect @@ -159,45 +170,53 @@ require ( github.com/oklog/run v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 // indirect - github.com/pelletier/go-toml/v2 v2.0.8 // indirect - github.com/pierrec/lz4/v4 v4.1.15 // indirect + github.com/pelletier/go-toml/v2 v2.1.1 // indirect + github.com/petar-dambovaliev/aho-corasick v0.0.0-20230725210150-fb29fc3c913e // indirect + github.com/pierrec/lz4/v4 v4.1.18 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect - github.com/prometheus/common v0.37.0 // indirect - github.com/prometheus/procfs v0.8.0 // indirect + github.com/prometheus/common v0.44.0 // indirect + github.com/prometheus/procfs v0.10.1 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/sergi/go-diff v1.3.1 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/shopspring/decimal v1.2.0 // indirect github.com/spf13/cast v1.3.1 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/tetratelabs/wazero v1.2.1 // indirect - github.com/tidwall/gjson v1.13.0 // indirect + github.com/tetratelabs/wazero v1.8.0 // indirect + github.com/tidwall/gjson v1.17.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect github.com/tklauser/go-sysconf v0.3.11 // indirect github.com/tklauser/numcpus v0.6.0 // indirect + github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.2.11 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect + github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/zclconf/go-cty v1.8.0 // indirect go.mongodb.org/mongo-driver v1.9.4 // indirect - golang.org/x/arch v0.3.0 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sync v0.1.0 // indirect - golang.org/x/term v0.8.0 // indirect - golang.org/x/text v0.9.0 // indirect - golang.org/x/tools v0.7.0 // indirect + go.uber.org/atomic v1.10.0 // indirect + golang.org/x/arch v0.7.0 // indirect + golang.org/x/net v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/term v0.23.0 // indirect + golang.org/x/time v0.3.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gotest.tools/v3 v3.5.0 // indirect - k8s.io/api v0.27.3 // indirect - k8s.io/apimachinery v0.27.3 // indirect - k8s.io/klog/v2 v2.90.1 // indirect - k8s.io/utils 
v0.0.0-20230209194617-a36077c30491 // indirect + k8s.io/api v0.28.4 // indirect + k8s.io/apimachinery v0.28.4 // indirect + k8s.io/klog/v2 v2.100.1 // indirect + k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 // indirect + rsc.io/binaryregexp v0.2.0 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect ) diff --git a/go.sum b/go.sum index db249b654d1..b2bd77c9915 100644 --- a/go.sum +++ b/go.sum @@ -1,63 +1,33 @@ -ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a h1:6/nt4DODfgxzHTTg3tYy7YkVzruGQGZ/kRvXpA45KUo= -ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE= +ariga.io/atlas v0.19.1-0.20240203083654-5948b60a8e43 h1:GwdJbXydHCYPedeeLt4x/lrlIISQ4JTH1mRWuE5ZZ14= +ariga.io/atlas v0.19.1-0.20240203083654-5948b60a8e43/go.mod h1:uj3pm+hUTVN/X5yfdBexHlZv+1Xu5u5ZbZx7+CDavNU= bitbucket.org/creachadair/stringset v0.0.9 h1:L4vld9nzPt90UZNrXjNelTshD74ps4P5NGs3Iq6yN3o= bitbucket.org/creachadair/stringset v0.0.9/go.mod h1:t+4WcQ4+PXTa8aQdNKe40ZP6iwesoMFWAxPGd3UGjyY= -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= -cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= -cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= -cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= -cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= -cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To= -cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= -cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M= -cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc= -cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk= -cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= -cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= -cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= -cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= -cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= -cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= -cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= -cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= -cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= -cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= -cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= -cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= -cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= -cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= -cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= -cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= 
-cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos=
-cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk=
-cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
-cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
-dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
-entgo.io/ent v0.11.3 h1:F5FBGAWiDCGder7YT+lqMnyzXl6d0xU3xMBM/SO3CMc=
-entgo.io/ent v0.11.3/go.mod h1:mvDhvynOzAsOe7anH7ynPPtMjA/eeXP96kAfweevyxc=
-github.com/AlecAivazis/survey/v2 v2.2.7 h1:5NbxkF4RSKmpywYdcRgUmos1o+roJY8duCLZXbVjoig=
-github.com/AlecAivazis/survey/v2 v2.2.7/go.mod h1:9DYvHgXtiXm6nCn+jXnOXLKbH+Yo9u8fAS/SduGdoPk=
+entgo.io/ent v0.13.1 h1:uD8QwN1h6SNphdCCzmkMN3feSUzNnVvV/WIkHKMbzOE=
+entgo.io/ent v0.13.1/go.mod h1:qCEmo+biw3ccBn9OyL4ZK5dfpwg++l1Gxwac5B1206A=
+github.com/AlecAivazis/survey/v2 v2.3.7 h1:6I/u8FvytdGsgonrYsVn2t8t4QiRnh6QSTqkkhIiSjQ=
+github.com/AlecAivazis/survey/v2 v2.3.7/go.mod h1:xUTIdE4KCOIjsBAE1JYsUPoCqYdZ1reCfTwbto0Fduo=
 github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
+github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
-github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
 github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
+github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
 github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
 github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
-github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc=
 github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
-github.com/Masterminds/sprig/v3 v3.2.2 h1:17jRggJu518dr3QaafizSXOjKYp94wKfABxUmyxvxX8=
-github.com/Masterminds/sprig/v3 v3.2.2/go.mod h1:UoaO7Yp8KlPnJIYWTFkMaqPUYKTfGFPhxNuwnnxkKlk=
+github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
+github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
+github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
+github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA=
+github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM=
 github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
 github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
-github.com/Netflix/go-expect v0.0.0-20180615182759-c93bf25de8e8 h1:xzYJEypr/85nBpB11F9br+3HUrpgb+fcm5iADzXXYEw=
-github.com/Netflix/go-expect v0.0.0-20180615182759-c93bf25de8e8/go.mod h1:oX5x61PbNXchhh0oikYAH+4Pcfw5LKv21+Jnpr6r6Pc=
+github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63nhn5WAunQHLTznkw5W8b1Xc0dNjp83s=
+github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w=
 github.com/PuerkitoBio/purell v1.1.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
 github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
 github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
-github.com/agext/levenshtein v1.2.1 h1:QmvMAjj2aEICytGiWzmxoE0x2KZvE0fvmqMOfy2tjT8=
-github.com/agext/levenshtein v1.2.1/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
+github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
+github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
 github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM=
 github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:pzStYMLAXM7CNQjS/Wn+zK9MUxDhSUNfVvnHsyQyjs0=
 github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ilK+u7u1HoqaDk0mjhh27QJB7PyWMreGffEvOCoEKiY=
@@ -66,134 +36,126 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy
 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
 github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
 github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
-github.com/alexliesenfeld/health v0.5.1 h1:cohQdtQbJdA6bj0aMD4gdXA9xQyvh9NxWO9XLGYTYcY=
-github.com/alexliesenfeld/health v0.5.1/go.mod h1:N4NDIeQtlWumG+6z1ne1v62eQxktz5ylEgGgH9emdMw=
+github.com/alexliesenfeld/health v0.8.0 h1:lCV0i+ZJPTbqP7LfKG7p3qZBl5VhelwUFCIVWl77fgk=
+github.com/alexliesenfeld/health v0.8.0/go.mod h1:TfNP0f+9WQVWMQRzvMUjlws4ceXKEL3WR+6Hp95HUFc=
 github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
-github.com/antonmedv/expr v1.12.5 h1:Fq4okale9swwL3OeLLs9WD9H6GbgBLJyN/NUHRv+n0E=
-github.com/antonmedv/expr v1.12.5/go.mod h1:FPC8iWArxls7axbVLsW+kpg1mz29A1b2M6jt+hZfDkU=
 github.com/apparentlymart/go-textseg/v13 v13.0.0 h1:Y+KvPE1NYz0xl601PVImeQfFyEy6iT90AvPUL1NNfNw=
 github.com/apparentlymart/go-textseg/v13 v13.0.0/go.mod h1:ZK2fH7c4NqDTLtiYLvIkEghdlcqw7yxLeM89kiTRPUo=
-github.com/appleboy/gin-jwt/v2 v2.8.0 h1:Glo7cb9eBR+hj8Y7WzgfkOlqCaNLjP+RV4dNO3fpdps=
-github.com/appleboy/gin-jwt/v2 v2.8.0/go.mod h1:KsK7E8HTvRg3vOiumTsr/ntNTHbZ3IbHLe4Eto31p7k=
+github.com/appleboy/gin-jwt/v2 v2.9.2 h1:GeS3lm9mb9HMmj7+GNjYUtpp3V1DAQ1TkUFa5poiZ7Y=
+github.com/appleboy/gin-jwt/v2 v2.9.2/go.mod h1:mxGjKt9Lrx9Xusy1SrnmsCJMZG6UJwmdHN9bN27/QDw=
 github.com/appleboy/gofight/v2 v2.1.2 h1:VOy3jow4vIK8BRQJoC/I9muxyYlJ2yb9ht2hZoS3rf4=
 github.com/appleboy/gofight/v2 v2.1.2/go.mod h1:frW+U1QZEdDgixycTj4CygQ48yLTUhplt43+Wczp3rw=
-github.com/aquasecurity/table v1.8.0 h1:9ntpSwrUfjrM6/YviArlx/ZBGd6ix8W+MtojQcM7tv0=
-github.com/aquasecurity/table v1.8.0/go.mod h1:eqOmvjjB7AhXFgFqpJUEE/ietg7RrMSJZXyTN8E/wZw=
 github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
 github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
 github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg=
 github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg=
 github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef h1:46PFijGLmAjMPwCCCo7Jf0W6f9slllCkkv7vyc1yOSg=
 github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
-github.com/aws/aws-lambda-go v1.38.0 h1:4CUdxGzvuQp0o8Zh7KtupB9XvCiiY8yKqJtzco+gsDw=
-github.com/aws/aws-lambda-go v1.38.0/go.mod h1:jwFe2KmMsHmffA1X2R09hH6lFzJQxzI8qK17ewzbQMM=
+github.com/aws/aws-lambda-go v1.47.0 h1:0H8s0vumYx/YKs4sE7YM0ktwL2eWse+kfopsRI1sXVI=
+github.com/aws/aws-lambda-go v1.47.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A=
 github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48=
-github.com/aws/aws-sdk-go v1.42.25 h1:BbdvHAi+t9LRiaYUyd53noq9jcaAcfzOhSVbKfr6Avs=
-github.com/aws/aws-sdk-go v1.42.25/go.mod h1:gyRszuZ/icHmHAVE4gc/r+cfCmhA1AD+vqfWbgI+eHs=
-github.com/beevik/etree v1.1.0 h1:T0xke/WvNtMoCqgzPhkX2r4rjY3GDZFi+FjpRZY2Jbs=
-github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A=
+github.com/aws/aws-sdk-go v1.52.0 h1:ptgek/4B2v/ljsjYSEvLQ8LTD+SQyrqhOOWvHc/VGPI=
+github.com/aws/aws-sdk-go v1.52.0/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
+github.com/beevik/etree v1.3.0 h1:hQTc+pylzIKDb23yYprodCWWTt+ojFfUZyzU09a/hmU=
+github.com/beevik/etree v1.3.0/go.mod h1:aiPf89g/1k3AShMVAzriilpcE4R/Vuor90y83zVZWFc=
+github.com/beevik/etree v1.4.1 h1:PmQJDDYahBGNKDcpdX8uPy1xRCwoCGVUiW669MEirVI=
+github.com/beevik/etree v1.4.1/go.mod h1:gPNJNaBGVZ9AwsidazFZyygnd+0pAU38N4D+WemwKNs=
 github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
-github.com/blackfireio/osinfo v1.0.3 h1:Yk2t2GTPjBcESv6nDSWZKO87bGMQgO+Hi9OoXPpxX8c=
-github.com/blackfireio/osinfo v1.0.3/go.mod h1:Pd987poVNmd5Wsx6PRPw4+w7kLlf9iJxoRKPtPAjOrA=
+github.com/blackfireio/osinfo v1.0.5 h1:6hlaWzfcpb87gRmznVf7wSdhysGqLRz9V/xuSdCEXrA=
+github.com/blackfireio/osinfo v1.0.5/go.mod h1:Pd987poVNmd5Wsx6PRPw4+w7kLlf9iJxoRKPtPAjOrA=
 github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw=
 github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0t/fp0=
 github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
 github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
 github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
-github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
-github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
-github.com/c-robinson/iplib v1.0.3 h1:NG0UF0GoEsrC1/vyfX1Lx2Ss7CySWl3KqqXh3q4DdPU=
-github.com/c-robinson/iplib v1.0.3/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szNDIbF8pgo=
-github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
-github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
-github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM=
+github.com/bytedance/sonic v1.10.2 h1:GQebETVBxYB7JGWJtLBi07OVzWwt+8dWA00gEVW2ZFE=
+github.com/bytedance/sonic v1.10.2/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4=
+github.com/c-robinson/iplib v1.0.8 h1:exDRViDyL9UBLcfmlxxkY5odWX5092nPsQIykHXhIn4=
+github.com/c-robinson/iplib v1.0.8/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szNDIbF8pgo=
+github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
+github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
-github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
 github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
-github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
-github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
-github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
-github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
-github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
+github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0=
+github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA=
+github.com/chenzhuoyu/iasm v0.9.0/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog=
+github.com/chenzhuoyu/iasm v0.9.1 h1:tUHQJXo3NhBqw6s33wkGn9SP3bvrWLdlVIJ3hQBL7P0=
+github.com/chenzhuoyu/iasm v0.9.1/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog=
 github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
 github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
+github.com/corazawaf/libinjection-go v0.1.2 h1:oeiV9pc5rvJ+2oqOqXEAMJousPpGiup6f7Y3nZj5GoM=
+github.com/corazawaf/libinjection-go v0.1.2/go.mod h1:OP4TM7xdJ2skyXqNX1AN1wN5nNZEmJNuWbNPOItn7aw=
 github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
 github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
 github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
 github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
 github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
-github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
-github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
-github.com/creachadair/staticfile v0.1.3/go.mod h1:a3qySzCIXEprDGxk6tSxSI+dBBdLzqeBOMhZ+o2d3pM=
+github.com/cpuguy83/go-md2man/v2 v2.0.3 h1:qMCsGGgs+MAzDFyp9LpAe1Lqy/fY/qCovCm0qnXZOBM=
+github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
 github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
 github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
+github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
 github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
 github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
+github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607 h1:hyrYw3h8clMcRL2u5ooZ3tmwnmJftmhb9Ws1MKmavvI=
+github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607/go.mod h1:br36fEqurGYZQGit+iDYsIzW0FF6VufMbDzyyLxEuPA=
 github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:r97WNVC30Uen+7WnLs4xDScS/Ex988+id2k6mDf8psU=
 github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:zpv7r+7KXwgVUZnUNjyP22zc/D7LKjyoY02weH2RBbk=
-github.com/crowdsecurity/go-cs-lib v0.0.2 h1:+Tjmf/IclOXNzU9sxKVQvUl9CkMfbM60xQ0zA05NWps=
-github.com/crowdsecurity/go-cs-lib v0.0.2/go.mod h1:iznTJ19qLTYdZBcRb5RVDlcUdSlayBCivBkWsXlOY3g=
-github.com/crowdsecurity/grokky v0.2.1 h1:t4VYnDlAd0RjDM2SlILalbwfCrQxtJSMGdQOR0zwkE4=
-github.com/crowdsecurity/grokky v0.2.1/go.mod h1:33usDIYzGDsgX1kHAThCbseso6JuWNJXOzRQDGXHtWM=
+github.com/crowdsecurity/go-cs-lib v0.0.15 h1:zNWqOPVLHgKUstlr6clom9d66S0eIIW66jQG3Y7FEvo=
+github.com/crowdsecurity/go-cs-lib v0.0.15/go.mod h1:ePyQyJBxp1W/1bq4YpVAilnLSz7HkzmtI7TRhX187EU=
+github.com/crowdsecurity/grokky v0.2.2 h1:yALsI9zqpDArYzmSSxfBq2dhYuGUTKMJq8KOEIAsuo4=
+github.com/crowdsecurity/grokky v0.2.2/go.mod h1:33usDIYzGDsgX1kHAThCbseso6JuWNJXOzRQDGXHtWM=
 github.com/crowdsecurity/machineid v1.0.2 h1:wpkpsUghJF8Khtmn/tg6GxgdhLA1Xflerh5lirI+bdc=
 github.com/crowdsecurity/machineid v1.0.2/go.mod h1:XWUSlnS0R0+u/JK5ulidwlbceNT3ZOCKteoVQEn6Luo=
+github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dghubble/sling v1.3.0 h1:pZHjCJq4zJvc6qVQ5wN1jo5oNZlNE0+8T/h0XeXBUKU=
-github.com/dghubble/sling v1.3.0/go.mod h1:XXShWaBWKzNLhu2OxikSNFrlsvowtz4kyRuXUG7oQKY=
+github.com/dghubble/sling v1.4.2 h1:vs1HIGBbSl2SEALyU+irpYFLZMfc49Fp+jYryFebQjM=
+github.com/dghubble/sling v1.4.2/go.mod h1:o0arCOz0HwfqYQJLrRtqunaWOn4X6jxE/6ORKRpVTD4=
 github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8=
 github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
-github.com/docker/docker v24.0.4+incompatible h1:s/LVDftw9hjblvqIeTiGYXBCD95nOEEl7qRsRrIOuQI=
-github.com/docker/docker v24.0.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
+github.com/docker/docker v24.0.9+incompatible h1:HPGzNmwfLZWdxHqK9/II92pyi1EpYKsAqcl4G0Of9v0=
+github.com/docker/docker v24.0.9+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
 github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ=
 github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
 github.com/docker/go-units v0.3.3/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
 github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
 github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
-github.com/enescakir/emoji v1.0.0 h1:W+HsNql8swfCQFtioDGDHCHri8nudlK1n5p2rHCJoog=
-github.com/enescakir/emoji v1.0.0/go.mod h1:Bt1EKuLnKDTYpLALApstIkAjdDrS/8IAgTkKp+WKFD0=
-github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
-github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
-github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
-github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
+github.com/expr-lang/expr v1.16.9 h1:WUAzmR0JNI9JCiF0/ewwHB1gmcGw5wW7nWt8gc6PpCI=
+github.com/expr-lang/expr v1.16.9/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4=
 github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
-github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs=
-github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw=
+github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM=
+github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE=
+github.com/foxcpp/go-mockdns v1.0.0 h1:7jBqxd3WDWwi/6WhDvacvH1XsN3rOLXyHM1uhvIx6FI=
+github.com/foxcpp/go-mockdns v1.0.0/go.mod h1:lgRN6+KxQBawyIghpnl5CezHFGS9VLzvtVlwxvzXTQ4=
 github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
-github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
-github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
-github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
-github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
+github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
+github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
+github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
+github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
 github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
 github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
-github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U=
 github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
 github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
 github.com/globalsign/mgo v0.0.0-20180905125535-1ca0a4f7cbcb/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
 github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
-github.com/go-co-op/gocron v1.17.0 h1:IixLXsti+Qo0wMvmn6Kmjp2csk2ykpkcL+EmHmST18w=
-github.com/go-co-op/gocron v1.17.0/go.mod h1:IpDBSaJOVfFw7hXZuTag3SCSkqazXBBUkbQ1m1aesBs=
-github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
-github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
-github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
+github.com/go-co-op/gocron v1.37.0 h1:ZYDJGtQ4OMhTLKOKMIch+/CY70Brbb1dGdooLEhh7b0=
+github.com/go-co-op/gocron v1.37.0/go.mod h1:3L/n6BkO7ABj+TrfSVXLRzsP26zmikL4ISkLQ0O8iNY=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
-github.com/go-kit/log v0.2.0/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0=
 github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
 github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
-github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
 github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
-github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
-github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
+github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
+github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
 github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
 github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
 github.com/go-openapi/analysis v0.0.0-20180825180245-b006789cd277/go.mod h1:k70tL6pCuVxPJOHXQ+wIac1FUrvNkHolPie/cLEU6hI=
@@ -229,8 +191,8 @@ github.com/go-openapi/jsonreference v0.18.0/go.mod h1:g4xxGn04lDIRh0GJb5QlpE3Hfo
 github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwohSTlpa0o73RUL1owJc=
 github.com/go-openapi/jsonreference v0.19.3/go.mod h1:rjx6GuL8TTa9VaixXglHmQmIL98+wF9xc8zWvFonSJ8=
 github.com/go-openapi/jsonreference v0.19.5/go.mod h1:RdybgQwPxbL4UEjuAruzK1x3nE69AqPYEJeo/TWfEeg=
-github.com/go-openapi/jsonreference v0.20.1 h1:FBLnyygC4/IZZr893oiomc9XaghoveYTrLC1F86HID8=
-github.com/go-openapi/jsonreference v0.20.1/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k=
+github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2KvnJRumpMGbE=
+github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k=
 github.com/go-openapi/loads v0.17.0/go.mod h1:72tmFy5wsWx89uEVddd0RjRWPZm92WRLhf7AC+0+OOU=
 github.com/go-openapi/loads v0.18.0/go.mod h1:72tmFy5wsWx89uEVddd0RjRWPZm92WRLhf7AC+0+OOU=
 github.com/go-openapi/loads v0.19.0/go.mod h1:72tmFy5wsWx89uEVddd0RjRWPZm92WRLhf7AC+0+OOU=
@@ -283,23 +245,21 @@ github.com/go-openapi/validate v0.19.12/go.mod h1:Rzou8hA/CBw8donlS6WNEUQupNvUZ0
 github.com/go-openapi/validate v0.19.15/go.mod h1:tbn/fdOwYHgrhPBzidZfJC2MIVvs9GA7monOmWBbeCI=
 github.com/go-openapi/validate v0.20.0 h1:pzutNCCBZGZlE+u8HD3JZyWdc/TVbtVwlWUp8/vgUKk=
 github.com/go-openapi/validate v0.20.0/go.mod h1:b60iJT+xNNLfaQJUqLI7946tYiFEOuE9E4k54HpKcJ0=
-github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
 github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
-github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
+github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
 github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
 github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
-github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
 github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
 github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
-github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4=
-github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
-github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
+github.com/go-playground/validator/v10 v10.17.0 h1:SmVVlfAOtlZncTxRuinDPomC2DkXJ4E5T9gDA0AIH74=
+github.com/go-playground/validator/v10 v10.17.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
 github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
 github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
 github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
 github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
 github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
-github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68=
+github.com/go-test/deep v1.0.4 h1:u2CU3YKy9I2pmu9pX0eq50wCgjfGIt539SqR7FbHiho=
+github.com/go-test/deep v1.0.4/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA=
 github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0=
 github.com/gobuffalo/depgen v0.0.0-20190329151759-d478694a28d3/go.mod h1:3STtPUQYuzV0gBVOY3vy6CfMm/ljR4pABfrTeHNLHUY=
 github.com/gobuffalo/depgen v0.1.0/go.mod h1:+ifsuy7fhi15RWncXQQKjWS9JPkdah5sZvtHc2RXGlg=
@@ -328,109 +288,68 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
 github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
 github.com/goccy/go-yaml v1.11.0 h1:n7Z+zx8S9f9KgzG6KtQKf+kwqXZlLNR2F6018Dgau54=
 github.com/goccy/go-yaml v1.11.0/go.mod h1:H+mJrWtjPTJAHvRbV09MCK9xYwODM+wRTVFFTWckfng=
-github.com/godbus/dbus v4.1.0+incompatible/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw=
 github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
 github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
 github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
 github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
 github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
 github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
-github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
-github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs=
-github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
-github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
-github.com/golang/glog v0.0.0-20210429001901-424d2337a529/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
-github.com/golang/glog v1.1.0 h1:/d3pCKDPWNnvIWe0vVUpNP32qc8U3PDVxySP/y360qE=
-github.com/golang/glog v1.1.0/go.mod h1:pfYeQZ3JWZoXTV5sFc986z3HTpwQs9At6P4ImfuP3NQ=
-github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
-github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
-github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
-github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
-github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
-github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
-github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
+github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
+github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
+github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY=
+github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
 github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
 github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
-github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk=
-github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
-github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
-github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
-github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
-github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
-github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
-github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
-github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
-github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
 github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
 github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
 github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
-github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
-github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
 github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
 github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
-github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
 github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
-github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
-github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
+github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
+github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
+github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
 github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
-github.com/google/logger v1.1.1/go.mod h1:BkeJZ+1FhQ+/d087r4dzojEg1u2ZX+ZqG1jTUrLM+zQ=
-github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
-github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
-github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
-github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
-github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
 github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
 github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
-github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/google/winops v0.0.0-20211216095627-f0e86eb1453b h1:THwEE9J2wPxF3BZm7WjLCASMcM7ctFzqLpTsCGh7gDY=
-github.com/google/winops v0.0.0-20211216095627-f0e86eb1453b/go.mod h1:ShbX8v8clPm/3chw9zHVwtW3QhrFpL8mXOwNxClt4pg=
-github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
-github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
+github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/google/winops v0.0.0-20230712152054-af9b550d0601 h1:XvlrmqZIuwxuRE88S9mkxX+FkV+YakqbiAC5Z4OzDnM=
+github.com/google/winops v0.0.0-20230712152054-af9b550d0601/go.mod h1:rT1mcjzuvcDDbRmUTsoH6kV0DG91AkFe9UCjASraK5I=
 github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e h1:XmA6L9IPRdUr28a+SK/oMchGgQy159wvzXA5tJ7l+40=
 github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e/go.mod h1:AFIo+02s+12CEg8Gzz9kzhCbmbq6JcKNrhHffCGA9z4=
-github.com/groob/plist v0.0.0-20210519001750-9f754062e6d6/go.mod h1:itkABA+w2cw7x5nYUS/pLRef6ludkZKOigbROmCTaFw=
+github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
+github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+13c=
 github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
 github.com/hashicorp/go-plugin v1.4.10 h1:xUbmA4jC6Dq163/fWcp8P3JuHilrHHMLNRxzGQJ9hNk=
 github.com/hashicorp/go-plugin v1.4.10/go.mod h1:6/1TEzT0eQznvI/gV2CM29DLSkAK/e58mUWKVsPaph0=
 github.com/hashicorp/go-version v1.2.1 h1:zEfKbn2+PDgroKdiOzqiE8rsmLqU2uwi5PB5pBJ3TkI=
 github.com/hashicorp/go-version v1.2.1/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
-github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
-github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
 github.com/hashicorp/hcl/v2 v2.13.0 h1:0Apadu1w6M11dyGFxWnmhhcMjkbAiKCv7G1r/2QgCNc=
 github.com/hashicorp/hcl/v2 v2.13.0/go.mod h1:e4z5nxYlWNPdDSNYX+ph14EvWYMFm3eP0zIUqPc2jr0=
 github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M=
 github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM=
-github.com/hinshun/vt10x v0.0.0-20180616224451-1954e6464174 h1:WlZsjVhE8Af9IcZDGgJGQpNflI3+MJSBhsgT5PCtzBQ=
-github.com/hinshun/vt10x v0.0.0-20180616224451-1954e6464174/go.mod h1:DqJ97dSdRW1W22yXSB90986pcOyQ7r45iio1KN2ez1A=
-github.com/huandu/xstrings v1.3.1/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
-github.com/huandu/xstrings v1.3.2 h1:L18LIDzqlW6xN2rEkpdV8+oL/IXWJ1APd+vsdYy4Wdw=
-github.com/huandu/xstrings v1.3.2/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
-github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
+github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
+github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
+github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
+github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
+github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4=
+github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
 github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
 github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU=
 github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
@@ -449,8 +368,8 @@ github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsU
 github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o=
 github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY=
 github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI=
-github.com/jackc/pgconn v1.10.1 h1:DzdIHIjG1AxGwoEEqS+mGsURyjt4enSmqzACXvVzOT8=
-github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI=
+github.com/jackc/pgconn v1.14.3 h1:bVoTr12EGANZz66nZPkMInAV/KHD2TxH9npjXXgiB3w=
+github.com/jackc/pgconn v1.14.3/go.mod h1:RZbme4uasqzybK2RK5c65VsHxoyaml09lx3tXOcO/VM=
 github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
 github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
@@ -466,29 +385,32 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvW
 github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
 github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
 github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
-github.com/jackc/pgproto3/v2 v2.2.0 h1:r7JypeP2D3onoQTCxWdTpCtJ4D+qpKr0TxvoyMhZ5ns=
-github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
-github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg=
+github.com/jackc/pgproto3/v2 v2.3.3 h1:1HLSx5H+tXR9pW3in3zaztoEwQYRC9SQaYUHjTSUOag=
+github.com/jackc/pgproto3/v2 v2.3.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
+github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
+github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
 github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
 github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
 github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
 github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM=
-github.com/jackc/pgtype v1.9.1 h1:MJc2s0MFS8C3ok1wQTdQxWuXQcB6+HwAm5x1CzW7mf0=
-github.com/jackc/pgtype v1.9.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4=
+github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw=
+github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4=
 github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
 github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
 github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
 github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs=
-github.com/jackc/pgx/v4 v4.14.1 h1:71oo1KAGI6mXhLiTMn6iDFcp3e7+zon/capWjl2OEFU=
-github.com/jackc/pgx/v4 v4.14.1/go.mod h1:RgDuE4Z34o7XE92RpLsvFiOEfrAUT0Xt2KxvX73W06M=
+github.com/jackc/pgx/v4 v4.18.2 h1:xVpYkNR5pk5bMCZGfClbO962UIqVABcAGt7ha1s/FeU=
+github.com/jackc/pgx/v4 v4.18.2/go.mod h1:Ey4Oru5tH5sB6tV7hDmfWFahwF15Eb7DNXlRKx2CkVw=
 github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
 github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
 github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
-github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
 github.com/jarcoal/httpmock v1.1.0 h1:F47ChZj1Y2zFsCXxNkBPwNNKnAyOATcdQibk0qEdVCE=
 github.com/jarcoal/httpmock v1.1.0/go.mod h1:ATjnClrvW/3tijVmpL/va5Z3aAyGvqU3gCT8nX0Txik=
+github.com/jedib0t/go-pretty/v6 v6.5.9 h1:ACteMBRrrmm1gMsXe9PSTOClQ63IXDUt03H5U+UV8OU=
+github.com/jedib0t/go-pretty/v6 v6.5.9/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E=
 github.com/jhump/protoreflect v1.6.0 h1:h5jfMVslIg6l29nsMs0D8Wj17RDVdNYti0vDN/PZZoE=
+github.com/jhump/protoreflect v1.6.0/go.mod h1:eaTn3RZAmMBcV0fifFvlm6VHNz3wSkYyXYWUh7ymB74=
 github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
 github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
 github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
@@ -496,19 +418,12 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC
 github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
 github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
 github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
-github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
 github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
-github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
-github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
-github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
 github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
 github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
-github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
-github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
 github.com/jszwec/csvutil v1.5.1 h1:c3GFBhj6DFMUl4dMK3+B6rz2+LWWS/e9VJiVJ9t9kfQ=
 github.com/jszwec/csvutil v1.5.1/go.mod h1:Rpu7Uu9giO9subDyMCIQfHVDuLrcaC36UA4YcJjGBkg=
 github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
-github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
 github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4=
 github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA=
 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
@@ -518,41 +433,42 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
 github.com/klauspost/compress v1.9.5/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
 github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
-github.com/klauspost/compress v1.15.7 h1:7cgTQxJCU/vy+oP/E3B9RGbQTgbiVzIJWIKOLoAsPok=
-github.com/klauspost/compress v1.15.7/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
+github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
+github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA=
+github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
 github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
-github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
-github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
+github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
+github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
+github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
-github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
 github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
 github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
-github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
+github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
+github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
 github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/pty v1.1.4/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
 github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA=
-github.com/kr/pty v1.1.8 h1:AkaSdXYQOWeaO3neb8EM634ahkXXe3jYbVh/F9lq+GI=
 github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
 github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
-github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4=
-github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
-github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
-github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
+github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
+github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
+github.com/leodido/go-urn v1.3.0 h1:jX8FDLfW4ThVXctBNZ+3cIWnCSnrACDV73r76dy0aQQ=
+github.com/leodido/go-urn v1.3.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
 github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
 github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
 github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
+github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
 github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
-github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw=
 github.com/lithammer/dedent v1.1.0 h1:VNzHMVCBNG1j0fh3OrsFRkVUwStdDArbgBWoPAffktY=
 github.com/lithammer/dedent v1.1.0/go.mod h1:jrXYCQtgg0nJiN+StA2KgR7w6CiQNv9Fd/Z9BP0jIOc=
 github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
 github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
-github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
-github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
+github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a h1:tdPcGgyiH0K+SbsJBBm2oPyEIOTAvLBwD9TuUwVtZho=
+github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
 github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
 github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
 github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
@@ -576,10 +492,10 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd
 github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
 github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
 github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
-github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
-github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU=
-github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
+github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
 github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
 github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
@@ -588,6 +504,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
 github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI=
 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
+github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA=
+github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
 github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
 github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
 github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
@@ -609,7 +527,6 @@ github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
-github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
 github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
 github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
 github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
@@ -619,7 +536,6 @@ github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJ
 github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
 github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
 github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
-github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
 github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
 github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
@@ -629,56 +545,49 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
 github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 h1:rc3tiVYb5z54aKaDfakKn0dDjIyPpTtszkjuMzyt7ec=
 github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0=
-github.com/oschwald/geoip2-golang v1.4.0 h1:5RlrjCgRyIGDz/mBmPfnAF4h8k0IAcRv9PvrpOfz+Ug=
-github.com/oschwald/geoip2-golang v1.4.0/go.mod h1:8QwxJvRImBH+Zl6Aa6MaIcs5YdlZSTKtzmPGzQqi9ng=
-github.com/oschwald/maxminddb-golang v1.6.0/go.mod h1:DUJFucBg2cvqx42YmDa/+xHvb0elJtOm3o4aFQ/nb/w=
-github.com/oschwald/maxminddb-golang v1.8.0 h1:Uh/DSnGoxsyp/KYbY1AuP0tYEwfs0sCph9p/UMXK/Hk=
-github.com/oschwald/maxminddb-golang v1.8.0/go.mod h1:RXZtst0N6+FY/3qCNmZMBApR19cdQj43/NM9VkrNAis=
+github.com/oschwald/geoip2-golang v1.9.0 h1:uvD3O6fXAXs+usU+UGExshpdP13GAqp4GBrzN7IgKZc=
+github.com/oschwald/geoip2-golang v1.9.0/go.mod h1:BHK6TvDyATVQhKNbQBdrj9eAvuwOMi2zSFXizL3K81Y=
+github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
+github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=
 github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
 github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo=
 github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE=
-github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
-github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
-github.com/pierrec/lz4/v4 v4.1.15 h1:MO0/ucJhngq7299dKLwIMtgTfbkoSPF6AoMYDd8Q4q0=
+github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI=
+github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
+github.com/petar-dambovaliev/aho-corasick v0.0.0-20230725210150-fb29fc3c913e h1:POJco99aNgosh92lGqmx7L1ei+kCymivB/419SD15PQ=
+github.com/petar-dambovaliev/aho-corasick v0.0.0-20230725210150-fb29fc3c913e/go.mod h1:EHPiTAKtiFmrMldLUNswFwfZ2eJIYBHktdaUTZxYWRw=
 github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
+github.com/pierrec/lz4/v4 v4.1.18 h1:xaKrnTkyoqfh1YItXl56+6KJNVYWlEEPuAQW9xsplYQ=
+github.com/pierrec/lz4/v4 v4.1.18/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
+github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
 github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
 github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
 github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
-github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
-github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
-github.com/prometheus/client_golang v1.14.0 h1:nJdhIvne2eSX/XRAFV9PcvFFRbrjbcTUj0VP62TMhnw=
-github.com/prometheus/client_golang v1.14.0/go.mod h1:8vpkKitgIVNcqrRBWh1C4TIUQgYNtG/XQE4E/Zae36Y=
+github.com/prometheus/client_golang v1.16.0 h1:yk/hx9hDbrGHovbci4BY+pRMfSuuat626eFsHb7tmT8=
+github.com/prometheus/client_golang v1.16.0/go.mod h1:Zsulrv/L9oM40tJ7T815tM89lFEugiJ9HzIqaAx4LKc=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.1.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.3.0 h1:UBgGFHqYdG/TPFD1B1ogZywDqEkwp3fBMvqdiQ7Xew4=
-github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w=
+github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY=
+github.com/prometheus/client_model v0.4.0/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
 github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
 github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA=
-github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
-github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
-github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls=
-github.com/prometheus/common v0.37.0 h1:ccBbHCgIiT9uSoFY0vX8H3zsNR5eLt17/RQLUvn8pXE=
-github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJFhYO5B3mfA=
+github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY=
+github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY=
 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
 github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
-github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
-github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
-github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
+github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg=
+github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
 github.com/prometheus/prom2json v1.3.0 h1:BlqrtbT9lLH3ZsOVhXPsHzFrApCTKRifB7gjJuypu6Y=
 github.com/prometheus/prom2json v1.3.0/go.mod h1:rMN7m0ApCowcoDlypBHlkNbp5eJQf/+1isKykIP5ZnM=
 github.com/r3labs/diff/v2 v2.14.1 h1:wRZ3jB44Ny50DSXsoIcFQ27l2x+n5P31K/Pk+b9B0Ic=
@@ -690,17 +599,23 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG
 github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
 github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
 github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
+github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
+github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
 github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
+github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
 github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
 github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
 github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
 github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/sanity-io/litter v1.5.5 h1:iE+sBxPBzoK6uaEP5Lt3fHNgpKcHXc/A2HGETy0uJQo=
+github.com/sanity-io/litter v1.5.5/go.mod h1:9gzJgR2i4ZpjZHsKvUXIRQVk7P+yM3e+jAF7bU2UI5U=
 github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
-github.com/segmentio/kafka-go v0.4.34 h1:Dm6YlLMiVSiwwav20KY0AoY63s661FXevwJ3CVHUERo=
-github.com/segmentio/kafka-go v0.4.34/go.mod h1:GAjxBQJdQMB5zfNA21AhpaqOB2Mu+w3De4ni3Gbm8y0=
-github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ=
+github.com/segmentio/kafka-go v0.4.45 h1:prqrZp1mMId4kI6pyPolkLsH6sWOUmDxmmucbL4WS6E=
+github.com/segmentio/kafka-go v0.4.45/go.mod h1:HjF6XbOKh0Pjlkr5GVZxt6CsjjwnmhVOfURM5KMd8qg=
 github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
+github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
+github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
 github.com/shirou/gopsutil/v3 v3.23.5 h1:5SgDCeQ0KW0S4N0znjeM/eFHXXOKyv2dVNgRq/c9P6Y=
 github.com/shirou/gopsutil/v3 v3.23.5/go.mod h1:Ng3Maa27Q2KARVJ0SPZF5NdrQSC3XHKP8IIWrHgMeLY=
 github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
@@ -714,15 +629,16 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
 github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
 github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
 github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
-github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
 github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
 github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
+github.com/slack-go/slack v0.12.2 h1:x3OppyMyGIbbiyFhsBmpf9pwkUzMhthJMRNmNlA4LaQ=
+github.com/slack-go/slack v0.12.2/go.mod h1:hlGi5oXA+Gt+yWTPP0plCdRKmjsDxecdHxYQdlMQKOw=
 github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng=
 github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
 github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
 github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g=
-github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
-github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
+github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
+github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
 github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
 github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
 github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
@@ -730,9 +646,10 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
 github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
 github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
-github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
-github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
+github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
+github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
+github.com/stretchr/testify v0.0.0-20161117074351-18a02ba4a312/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
 github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
@@ -744,53 +661,61 @@ github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1F
 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
 github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
-github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
 github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/tetratelabs/wazero v1.2.1 h1:J4X2hrGzJvt+wqltuvcSjHQ7ujQxA9gb6PeMs4qlUWs=
-github.com/tetratelabs/wazero v1.2.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
-github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
-github.com/tidwall/gjson v1.13.0 h1:3TFY9yxOQShrvmjdM76K+jc66zJeT6D3/VFFYCGQf7M=
-github.com/tidwall/gjson v1.13.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
+github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/tetratelabs/wazero v1.8.0 h1:iEKu0d4c2Pd+QSRieYbnQC9yiFlMS9D+Jr0LsRmcF4g=
+github.com/tetratelabs/wazero v1.8.0/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
+github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM=
+github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
 github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
 github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
 github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
-github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
 github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
+github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
 github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
 github.com/tklauser/go-sysconf v0.3.11/go.mod h1:GqXfhXY3kiPa0nAXPDIQIWzJbMCB7AmcWpGR8lSZfqI=
 github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms=
 github.com/tklauser/numcpus v0.6.0/go.mod h1:FEZLMke0lhOUG6w2JadTzp0a+Nl8PF/GFkQ5UVIcaL4=
+github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208 h1:PM5hJF7HVfNWmCjMdEfbuOBNXSVF2cMFGgQTPdKCbwM=
+github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208/go.mod h1:BzWtXXrXzZUvMacR0oF/fbDDgUPO8L36tDMmRAf14ns=
 github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
 github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
-github.com/ugorji/go v1.1.7/go.mod
h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= -github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 h1:UFHFmFfixpmfRBcxuu+LA9l8MdURWVdVNUHxO5n1d2w= github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26/go.mod h1:IGhd0qMDsUa9acVjsbsT7bu3ktadtGOHI79+idTew/M= github.com/vektah/gqlparser v1.1.2/go.mod h1:1ycwN7Ij5njmMkPPAOaRFY4rET2Enx7IkVv3vaXspKw= github.com/vjeantet/grok v1.0.1 h1:2rhIR7J4gThTgcZ1m2JY4TrJZNgjn985U28kT2wQrJ4= +github.com/vjeantet/grok v1.0.1/go.mod h1:ax1aAchzC6/QMXMcyzHQGZWaW1l195+uMYIkCWPCNIo= github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI= github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= github.com/vmihailenco/msgpack/v4 v4.3.12/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+NXzzngzBKDPIqw4= github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= -github.com/wasilibs/go-re2 v1.3.0 h1:LFhBNzoStM3wMie6rN2slD1cuYH2CGiHpvNL3UtcsMw= -github.com/wasilibs/go-re2 v1.3.0/go.mod h1:AafrCXVvGRJJOImMajgJ2M7rVmWyisVK7sFshbxnVrg= +github.com/wasilibs/go-re2 v1.7.0 h1:bYhl8gn+a9h01dxwotNycxkiFPTiSgwUrIz8KZJ90Lc= +github.com/wasilibs/go-re2 v1.7.0/go.mod h1:sUsZMLflgl+LNivDE229omtmvjICmOseT9xOy199VDU= github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ= +github.com/wasilibs/nottinygc v0.4.0/go.mod h1:oDcIotskuYNMpqMF23l7Z8uzD4TC0WXHK8jetlB3HIo= +github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52 h1:OvLBa8SqJnZ6P+mjlzc2K7PM22rRUPE1x32G9DTPrC4= +github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52/go.mod h1:jMeV4Vpbi8osrE/pKUxRZkVaA0EX7NZN0A9/oRzgpgY= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= -github.com/xdg/scram v1.0.5 h1:TuS0RFmt5Is5qm9Tm2SoD89OPqe4IRiFtyFY4iwWXsw= -github.com/xdg/scram v1.0.5/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= github.com/xdg/stringprep v0.0.0-20180714160509-73f8eece6fdc/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= -github.com/xdg/stringprep v1.0.3 h1:cmL5Enob4W83ti/ZHuZLuKD/xqJfus4fVPwE+/BDm+4= -github.com/xdg/stringprep v1.0.3/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= +github.com/xhit/go-simple-mail/v2 v2.16.0 h1:ouGy/Ww4kuaqu2E2UrDw7SvLaziWTB60ICLkIkNVccA= +github.com/xhit/go-simple-mail/v2 v2.16.0/go.mod 
h1:b7P5ygho6SYE+VIqpxA6QkYfv4teeyG4MKqB3utRu98= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= -github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zclconf/go-cty v1.8.0 h1:s4AvqaeQzJIu3ndv4gVIhplVD0krU+bgrcLSVUnaWuA= @@ -804,15 +729,13 @@ go.mongodb.org/mongo-driver v1.4.3/go.mod h1:WcMNYLx/IlOxLe6JRJiv2uXuCz6zBLndR4S go.mongodb.org/mongo-driver v1.4.4/go.mod h1:WcMNYLx/IlOxLe6JRJiv2uXuCz6zBLndR4SoGjYphSc= go.mongodb.org/mongo-driver v1.9.4 h1:qXWlnK2WCOWSxJ/Hm3XyYOGKv3ujA2btBsCyuIFvQjc= go.mongodb.org/mongo-driver v1.9.4/go.mod h1:0sQWfOeY63QTntERDJJ/0SuKK0T1uVSgKCuAROlKEPY= -go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= -go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= -go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= @@ -821,8 +744,8 @@ go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= -golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc= +golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -830,265 +753,150 @@ golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaE 
golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200414173820-0848c9571904/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= -golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= -golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= -golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= -golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= -golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= -golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= -golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= 
-golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= -golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= -golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= -golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= -golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20181005035420-146acd28ed58/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190320064053-1272bf9dcd53/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod 
h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod 
h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220706163947-c90051bbdb60/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b/go.mod h1:DAh4E804XQdzx2j+YRIaUnCqCV2RuMz24cGBJ5QYIrc= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys 
v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190321052220-f7bb7a8bee54/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190530182044-ad28b68e88f1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190531175056-4c3a928424d2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190616124812-15dcb6c0061f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210426230700-d19ff857e887/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= -golang.org/x/sys v0.9.0/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= +golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= +golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= +golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190125232054-d66bd3c5d5a6/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190329151228-23e29df326fe/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190416151739-9c9e1878f421/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190614205625-5aca471b1d59/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190617190820-da514acc4774/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= 
-golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= -golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4= -golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1096,89 +904,18 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= -google.golang.org/api v0.7.0/go.mod 
h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= -google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= -google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= -google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto 
v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA= -google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= -google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= -google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60= -google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= -google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.56.1 h1:z0dNfjIl0VpaZ9iSVjA6daGatAYwPGstTjt5vkRMFkQ= -google.golang.org/grpc v1.56.1/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= 
-google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 h1:e7S5W7MGGLaSu8j3YjdezkZ+m1/Nm0uRVRMEMGk26Xs= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= +google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -1199,7 +936,6 @@ gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637/go.mod h1:BHsqpu/nsuzkT5BpiH gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= @@ -1211,29 +947,24 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod 
h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -k8s.io/api v0.27.3 h1:yR6oQXXnUEBWEWcvPWS0jQL575KoAboQPfJAuKNrw5Y= -k8s.io/api v0.27.3/go.mod h1:C4BNvZnQOF7JA/0Xed2S+aUyJSfTGkGFxLXz9MnpIpg= -k8s.io/apimachinery v0.27.3 h1:Ubye8oBufD04l9QnNtW05idcOe9Z3GQN8+7PqmuVcUM= -k8s.io/apimachinery v0.27.3/go.mod h1:XNfZ6xklnMCOGGFNqXG7bUrQCoR04dh/E7FprV6pb+E= -k8s.io/apiserver v0.27.3 h1:AxLvq9JYtveYWK+D/Dz/uoPCfz8JC9asR5z7+I/bbQ4= -k8s.io/apiserver v0.27.3/go.mod h1:Y61+EaBMVWUBJtxD5//cZ48cHZbQD+yIyV/4iEBhhNA= -k8s.io/klog/v2 v2.90.1 h1:m4bYOKall2MmOiRaR1J+We67Do7vm9KiQVlT96lnHUw= -k8s.io/klog/v2 v2.90.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= -k8s.io/utils v0.0.0-20230209194617-a36077c30491 h1:r0BAOLElQnnFhE/ApUsg3iHdVYYPBjNSSOMowRZxxsY= -k8s.io/utils v0.0.0-20230209194617-a36077c30491/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +k8s.io/api v0.28.4 h1:8ZBrLjwosLl/NYgv1P7EQLqoO8MGQApnbgH8tu3BMzY= +k8s.io/api v0.28.4/go.mod h1:axWTGrY88s/5YE+JSt4uUi6NMM+gur1en2REMR7IRj0= +k8s.io/apimachinery v0.28.4 h1:zOSJe1mc+GxuMnFzD4Z/U1wst50X28ZNsn5bhgIIao8= +k8s.io/apimachinery v0.28.4/go.mod h1:wI37ncBvfAoswfq626yPTe6Bz1c22L7uaJ8dho83mgg= +k8s.io/apiserver v0.28.4 h1:BJXlaQbAU/RXYX2lRz+E1oPe3G3TKlozMMCZWu5GMgg= +k8s.io/apiserver v0.28.4/go.mod h1:Idq71oXugKZoVGUUL2wgBCTHbUR+FYTWa4rq9j4n23w= +k8s.io/klog/v2 v2.100.1 h1:7WCHKK6K8fNhTqfBhISHQ97KrnJNFZMcQvKp7gP/tmg= +k8s.io/klog/v2 v2.100.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= +k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 h1:qY1Ad8PODbnymg2pRbkyMT/ylpTrCM8P2RJ0yroCyIk= +k8s.io/utils v0.0.0-20230406110748-d93618cff8a2/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= +rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= -rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= -rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= sigs.k8s.io/structured-merge-diff/v4 v4.2.3 h1:PRbqxJClWWYMNV1dhaG4NsibJbArud9kFxnAMREiWFE= sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E= sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= +sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= diff --git a/make_chocolatey.ps1 b/make_chocolatey.ps1 index 67f85c33d89..cceed28402f 100644 --- a/make_chocolatey.ps1 +++ b/make_chocolatey.ps1 @@ -15,4 +15,6 @@ if ($version.Contains("-")) Set-Location .\windows\Chocolatey\crowdsec Copy-Item ..\..\..\crowdsec_$version.msi tools\crowdsec.msi -choco pack --version $version \ No newline at end of file +choco pack --version $version + +Copy-Item crowdsec.$version.nupkg ..\..\..\ \ No newline at end of file diff --git a/make_installer.ps1 b/make_installer.ps1 index a20ffaf55b5..c927452ff72 100644 --- a/make_installer.ps1 +++ b/make_installer.ps1 @@ -1,7 +1,7 @@ param ( $version ) -$env:Path += ";C:\Program Files 
(x86)\WiX Toolset v3.11\bin" +$env:Path += ";C:\Program Files (x86)\WiX Toolset v3.14\bin" if ($version.StartsWith("v")) { $version = $version.Substring(1) diff --git a/mk/check_go_version.ps1 b/mk/check_go_version.ps1 deleted file mode 100644 index 6060cb22751..00000000000 --- a/mk/check_go_version.ps1 +++ /dev/null @@ -1,19 +0,0 @@ -##This must be called with $(MINIMUM_SUPPORTED_GO_MAJOR_VERSION) $(MINIMUM_SUPPORTED_GO_MINOR_VERSION) in this order -$min_major=$args[0] -$min_minor=$args[1] -$goversion = (go env GOVERSION).replace("go","").split(".") -$goversion_major=$goversion[0] -$goversion_minor=$goversion[1] -$err_msg="Golang version $goversion_major.$goversion_minor is not supported, please use least $min_major.$min_minor" - -if ( $goversion_major -gt $min_major ) { - exit 0; -} -elseif ($goversion_major -lt $min_major) { - Write-Output $err_msg; - exit 1; -} -elseif ($goversion_minor -lt $min_minor) { - Write-Output $(GO_VERSION_VALIDATION_ERR_MSG); - exit 1; -} diff --git a/mk/goversion.mk b/mk/goversion.mk deleted file mode 100644 index dd99549283f..00000000000 --- a/mk/goversion.mk +++ /dev/null @@ -1,36 +0,0 @@ - -BUILD_GOVERSION = $(subst go,,$(shell go env GOVERSION)) - -go_major_minor = $(subst ., ,$(BUILD_GOVERSION)) -GO_MAJOR_VERSION = $(word 1, $(go_major_minor)) -GO_MINOR_VERSION = $(word 2, $(go_major_minor)) - -GO_VERSION_VALIDATION_ERR_MSG = Golang version ($(BUILD_GOVERSION)) is not supported, please use at least $(BUILD_REQUIRE_GO_MAJOR).$(BUILD_REQUIRE_GO_MINOR) - - -.PHONY: goversion -goversion: $(if $(findstring devel,$(shell go env GOVERSION)),goversion_devel,goversion_check) - - -.PHONY: goversion_devel -goversion_devel: - $(warning WARNING: You are using a development version of Golang ($(BUILD_GOVERSION)) which is not supported. 
For production environments, use a stable version (at least $(BUILD_REQUIRE_GO_MAJOR).$(BUILD_REQUIRE_GO_MINOR))) - $(info ) - - -.PHONY: goversion_check -goversion_check: -ifneq ($(OS), Windows_NT) - @if [ $(GO_MAJOR_VERSION) -gt $(BUILD_REQUIRE_GO_MAJOR) ]; then \ - exit 0; \ - elif [ $(GO_MAJOR_VERSION) -lt $(BUILD_REQUIRE_GO_MAJOR) ]; then \ - echo '$(GO_VERSION_VALIDATION_ERR_MSG)';\ - exit 1; \ - elif [ $(GO_MINOR_VERSION) -lt $(BUILD_REQUIRE_GO_MINOR) ] ; then \ - echo '$(GO_VERSION_VALIDATION_ERR_MSG)';\ - exit 1; \ - fi -else - # This needs Set-ExecutionPolicy -Scope CurrentUser Unrestricted - @$(CURDIR)/mk/check_go_version.ps1 $(BUILD_REQUIRE_GO_MAJOR) $(BUILD_REQUIRE_GO_MINOR) -endif diff --git a/mk/help.mk b/mk/help.mk new file mode 100644 index 00000000000..36392255efb --- /dev/null +++ b/mk/help.mk @@ -0,0 +1,5 @@ +.PHONY: help +help: + @grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) \ + | sed -n 's/^.*:\(.*\): \(.*\)##\(.*\)/\1:\3/p' \ + | awk 'BEGIN {FS = ":"; printf "\033[33m"} {printf "%-20s \033[32m %s\033[0m\n", $$1, $$2}' diff --git a/mk/platform.mk b/mk/platform.mk index 9e375de3e77..b639723b612 100644 --- a/mk/platform.mk +++ b/mk/platform.mk @@ -1,7 +1,7 @@ BUILD_CODENAME ?= alphaga GOARCH ?= $(shell go env GOARCH) -BUILD_TAG ?= $(shell git rev-parse HEAD) +BUILD_TAG ?= $(shell git rev-parse --short HEAD) ifeq ($(OS), Windows_NT) SHELL := pwsh.exe diff --git a/mk/platform/unix_common.mk b/mk/platform/unix_common.mk index 8f06c93284f..5e5b5de3a43 100644 --- a/mk/platform/unix_common.mk +++ b/mk/platform/unix_common.mk @@ -8,7 +8,11 @@ MKDIR=mkdir -p GOOS ?= $(shell go env GOOS) # Current versioning information from env -BUILD_VERSION?=$(shell git describe --tags) +# The $(or) is used to ignore an empty BUILD_VERSION when it's an envvar, +# like inside a docker build: docker build --build-arg BUILD_VERSION=1.2.3 +# as opposed to a make parameter: make BUILD_VERSION=1.2.3 +BUILD_VERSION:=$(or $(BUILD_VERSION),$(shell git describe --tags --dirty)) + BUILD_TIMESTAMP=$(shell date +%F"_"%T) DEFAULT_CONFIGDIR?=/etc/crowdsec DEFAULT_DATADIR?=/var/lib/crowdsec/data diff --git a/pkg/acquisition/acquisition.go b/pkg/acquisition/acquisition.go index 11d5320e5f1..4519ea7392b 100644 --- a/pkg/acquisition/acquisition.go +++ b/pkg/acquisition/acquisition.go @@ -1,37 +1,27 @@ package acquisition import ( + "context" "errors" "fmt" "io" "os" "strings" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" tomb "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" - cloudwatchacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/cloudwatch" - dockeracquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/docker" - fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file" - httpacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/http" - journalctlacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/journalctl" - kafkaacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kafka" - kinesisacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kinesis" - k8sauditacquisition 
"github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kubernetesaudit" - s3acquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/s3" - syslogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog" - wineventlogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/wineventlog" - "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/component" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -50,83 +40,110 @@ func (e *DataSourceUnavailableError) Unwrap() error { // The interface each datasource must implement type DataSource interface { - GetMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module - GetAggregMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module (aggregated mode, limits cardinality) - UnmarshalConfig([]byte) error // Decode and pre-validate the YAML datasource - anything that can be checked before runtime - Configure([]byte, *log.Entry) error // Complete the YAML datasource configuration and perform runtime checks. - ConfigureByDSN(string, map[string]string, *log.Entry, string) error // Configure the datasource - GetMode() string // Get the mode (TAIL, CAT or SERVER) - GetName() string // Get the name of the module - OneShotAcquisition(chan types.Event, *tomb.Tomb) error // Start one shot acquisition(eg, cat a file) - StreamingAcquisition(chan types.Event, *tomb.Tomb) error // Start live acquisition (eg, tail a file) - CanRun() error // Whether the datasource can run or not (eg, journalctl on BSD is a non-sense) - GetUuid() string // Get the unique identifier of the datasource + GetMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module + GetAggregMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module (aggregated mode, limits cardinality) + UnmarshalConfig([]byte) error // Decode and pre-validate the YAML datasource - anything that can be checked before runtime + Configure([]byte, *log.Entry, int) error // Complete the YAML datasource configuration and perform runtime checks. 
+ ConfigureByDSN(string, map[string]string, *log.Entry, string) error // Configure the datasource + GetMode() string // Get the mode (TAIL, CAT or SERVER) + GetName() string // Get the name of the module + OneShotAcquisition(chan types.Event, *tomb.Tomb) error // Start one shot acquisition(eg, cat a file) + StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error // Start live acquisition (eg, tail a file) + CanRun() error // Whether the datasource can run or not (eg, journalctl on BSD is a non-sense) + GetUuid() string // Get the unique identifier of the datasource Dump() interface{} } -var AcquisitionSources = map[string]func() DataSource{ - "file": func() DataSource { return &fileacquisition.FileSource{} }, - "journalctl": func() DataSource { return &journalctlacquisition.JournalCtlSource{} }, - "cloudwatch": func() DataSource { return &cloudwatchacquisition.CloudwatchSource{} }, - "syslog": func() DataSource { return &syslogacquisition.SyslogSource{} }, - "docker": func() DataSource { return &dockeracquisition.DockerSource{} }, - "kinesis": func() DataSource { return &kinesisacquisition.KinesisSource{} }, - "wineventlog": func() DataSource { return &wineventlogacquisition.WinEventLogSource{} }, - "kafka": func() DataSource { return &kafkaacquisition.KafkaSource{} }, - "k8s-audit": func() DataSource { return &k8sauditacquisition.KubernetesAuditSource{} }, - "s3": func() DataSource { return &s3acquisition.S3Source{} }, - "http": func() DataSource { return &httpacquisition.HTTPSource{} }, +var ( + // We declare everything here so we can tell if they are unsupported, or excluded from the build + AcquisitionSources = map[string]func() DataSource{} + transformRuntimes = map[string]*vm.Program{} +) + +func GetDataSourceIface(dataSourceType string) (DataSource, error) { + source, registered := AcquisitionSources[dataSourceType] + if registered { + return source(), nil + } + + built, known := component.Built["datasource_"+dataSourceType] + + if !known { + return nil, fmt.Errorf("unknown data source %s", dataSourceType) + } + + if built { + panic("datasource " + dataSourceType + " is built but not registered") + } + + return nil, fmt.Errorf("data source %s is not built in this version of crowdsec", dataSourceType) } -var transformRuntimes = map[string]*vm.Program{} +// registerDataSource registers a datasource in the AcquisitionSources map. +// It must be called in the init() function of the datasource package, and the datasource name +// must be declared with a nil value in the map, to allow for conditional compilation. +func registerDataSource(dataSourceType string, dsGetter func() DataSource) { + component.Register("datasource_" + dataSourceType) -func GetDataSourceIface(dataSourceType string) DataSource { - source := AcquisitionSources[dataSourceType] - if source == nil { - return nil + AcquisitionSources[dataSourceType] = dsGetter +} + +// setupLogger creates a logger for the datasource to use at runtime. 
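GetDataSourceIface above distinguishes three outcomes: the datasource is registered and usable, it is known but was excluded from the build (a user-facing error), or it is marked as built yet missing from the map (a programming error, hence the panic). A minimal, self-contained sketch of that lookup pattern, using local stand-in names rather than the real crowdsec types (sources stands in for AcquisitionSources, built for component.Built):

```go
package main

import "fmt"

// Every known datasource is pre-declared as false; init()-time
// registration flips it to true and installs the constructor.
var (
	sources = map[string]func() string{}
	built   = map[string]bool{"datasource_file": false, "datasource_kafka": false}
)

func register(name string, ctor func() string) {
	built["datasource_"+name] = true
	sources[name] = ctor
}

func lookup(name string) (func() string, error) {
	if ctor, ok := sources[name]; ok {
		return ctor, nil
	}

	isBuilt, known := built["datasource_"+name]
	if !known {
		return nil, fmt.Errorf("unknown data source %s", name)
	}

	if isBuilt {
		panic("datasource " + name + " is built but not registered")
	}

	return nil, fmt.Errorf("data source %s is not built in this version", name)
}

func main() {
	register("file", func() string { return "file source" })

	if ctor, err := lookup("file"); err == nil {
		fmt.Println(ctor()) // file source
	}

	if _, err := lookup("kafka"); err != nil {
		fmt.Println(err) // data source kafka is not built in this version
	}

	if _, err := lookup("nope"); err != nil {
		fmt.Println(err) // unknown data source nope
	}
}
```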
+func setupLogger(source, name string, level *log.Level) (*log.Entry, error) { + clog := log.New() + if err := types.ConfigureLogger(clog); err != nil { + return nil, fmt.Errorf("while configuring datasource logger: %w", err) + } + + if level != nil { + clog.SetLevel(*level) + } + + fields := log.Fields{ + "type": source, + } + + if name != "" { + fields["name"] = name } - return source() + + subLogger := clog.WithFields(fields) + + return subLogger, nil } // DataSourceConfigure creates and returns a DataSource object from a configuration, // if the configuration is not valid it returns an error. // If the datasource can't be run (eg. journalctl not available), it still returns an error which // can be checked for the appropriate action. -func DataSourceConfigure(commonConfig configuration.DataSourceCommonCfg) (*DataSource, error) { +func DataSourceConfigure(commonConfig configuration.DataSourceCommonCfg, metricsLevel int) (*DataSource, error) { // we dump it back to []byte, because we want to decode the yaml blob twice: // once to DataSourceCommonCfg, and then later to the dedicated type of the datasource yamlConfig, err := yaml.Marshal(commonConfig) if err != nil { - return nil, fmt.Errorf("unable to marshal back interface: %w", err) + return nil, fmt.Errorf("unable to serialize back interface: %w", err) } - if dataSrc := GetDataSourceIface(commonConfig.Source); dataSrc != nil { - /* this logger will then be used by the datasource at runtime */ - clog := log.New() - if err := types.ConfigureLogger(clog); err != nil { - return nil, fmt.Errorf("while configuring datasource logger: %w", err) - } - if commonConfig.LogLevel != nil { - clog.SetLevel(*commonConfig.LogLevel) - } - customLog := log.Fields{ - "type": commonConfig.Source, - } - if commonConfig.Name != "" { - customLog["name"] = commonConfig.Name - } - subLogger := clog.WithFields(customLog) - /* check eventual dependencies are satisfied (ie. journald will check journalctl availability) */ - if err := dataSrc.CanRun(); err != nil { - return nil, &DataSourceUnavailableError{Name: commonConfig.Source, Err: err} - } - /* configure the actual datasource */ - if err := dataSrc.Configure(yamlConfig, subLogger); err != nil { - return nil, fmt.Errorf("failed to configure datasource %s: %w", commonConfig.Source, err) - } - return &dataSrc, nil + dataSrc, err := GetDataSourceIface(commonConfig.Source) + if err != nil { + return nil, err + } + + subLogger, err := setupLogger(commonConfig.Source, commonConfig.Name, commonConfig.LogLevel) + if err != nil { + return nil, err + } + + /* check eventual dependencies are satisfied (ie. 
journald will check journalctl availability) */ + if err := dataSrc.CanRun(); err != nil { + return nil, &DataSourceUnavailableError{Name: commonConfig.Source, Err: err} + } + /* configure the actual datasource */ + if err := dataSrc.Configure(yamlConfig, subLogger, metricsLevel); err != nil { + return nil, fmt.Errorf("failed to configure datasource %s: %w", commonConfig.Source, err) } - return nil, fmt.Errorf("cannot find source %s", commonConfig.Source) + + return &dataSrc, nil } // detectBackwardCompatAcquis: try to magically detect the type for backward compat (type was not mandatory then) @@ -134,12 +151,15 @@ func detectBackwardCompatAcquis(sub configuration.DataSourceCommonCfg) string { if _, ok := sub.Config["filename"]; ok { return "file" } + if _, ok := sub.Config["filenames"]; ok { return "file" } + if _, ok := sub.Config["journalctl_filter"]; ok { return "journalctl" } + return "" } @@ -150,109 +170,160 @@ func LoadAcquisitionFromDSN(dsn string, labels map[string]string, transformExpr if len(frags) == 1 { return nil, fmt.Errorf("%s isn't valid dsn (no protocol)", dsn) } - dataSrc := GetDataSourceIface(frags[0]) - if dataSrc == nil { - return nil, fmt.Errorf("no acquisition for protocol %s://", frags[0]) + + dataSrc, err := GetDataSourceIface(frags[0]) + if err != nil { + return nil, fmt.Errorf("no acquisition for protocol %s:// - %w", frags[0], err) } - /* this logger will then be used by the datasource at runtime */ - clog := log.New() - if err := types.ConfigureLogger(clog); err != nil { - return nil, fmt.Errorf("while configuring datasource logger: %w", err) + + subLogger, err := setupLogger(dsn, "", nil) + if err != nil { + return nil, err } - subLogger := clog.WithFields(log.Fields{ - "type": dsn, - }) + uniqueId := uuid.NewString() + if transformExpr != "" { vm, err := expr.Compile(transformExpr, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) 
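The expression compiled here is later evaluated by transform() against every event, and may return a string, a []string, or a []interface{}. A standalone sketch of that compile-once, run-per-event cycle with expr-lang (simplified: a local Event type and a plain map env instead of types.Event and the exprhelpers option set):

```go
package main

import (
	"fmt"

	"github.com/expr-lang/expr"
)

// Event is a stand-in for types.Event; transforms only touch Line.Raw here.
type Event struct {
	Line struct{ Raw string }
}

func main() {
	// Compile once per datasource, as LoadAcquisitionFromDSN does.
	program, err := expr.Compile(`"[transformed] " + evt.Line.Raw`,
		expr.Env(map[string]interface{}{"evt": &Event{}}))
	if err != nil {
		panic(err)
	}

	// Run once per event, as transform() does; a string result
	// replaces evt.Line.Raw before the event is forwarded.
	evt := &Event{}
	evt.Line.Raw = "some log line"

	out, err := expr.Run(program, map[string]interface{}{"evt": evt})
	if err != nil {
		panic(err)
	}

	fmt.Println(out) // [transformed] some log line
}
```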
if err != nil { return nil, fmt.Errorf("while compiling transform expression '%s': %w", transformExpr, err) } + transformRuntimes[uniqueId] = vm } - err := dataSrc.ConfigureByDSN(dsn, labels, subLogger, uniqueId) + + err = dataSrc.ConfigureByDSN(dsn, labels, subLogger, uniqueId) if err != nil { return nil, fmt.Errorf("while configuration datasource for %s: %w", dsn, err) } + sources = append(sources, dataSrc) + return sources, nil } +func GetMetricsLevelFromPromCfg(prom *csconfig.PrometheusCfg) int { + if prom == nil { + return configuration.METRICS_FULL + } + + if !prom.Enabled { + return configuration.METRICS_NONE + } + + if prom.Level == configuration.CFG_METRICS_AGGREGATE { + return configuration.METRICS_AGGREGATE + } + + if prom.Level == configuration.CFG_METRICS_FULL { + return configuration.METRICS_FULL + } + + return configuration.METRICS_FULL +} + // LoadAcquisitionFromFile unmarshals the configuration item and checks its availability -func LoadAcquisitionFromFile(config *csconfig.CrowdsecServiceCfg) ([]DataSource, error) { +func LoadAcquisitionFromFile(config *csconfig.CrowdsecServiceCfg, prom *csconfig.PrometheusCfg) ([]DataSource, error) { var sources []DataSource + metrics_level := GetMetricsLevelFromPromCfg(prom) + for _, acquisFile := range config.AcquisitionFiles { log.Infof("loading acquisition file : %s", acquisFile) + yamlFile, err := os.Open(acquisFile) if err != nil { return nil, err } + dec := yaml.NewDecoder(yamlFile) dec.SetStrict(true) + idx := -1 + for { var sub configuration.DataSourceCommonCfg - err = dec.Decode(&sub) + idx += 1 + + err = dec.Decode(&sub) if err != nil { if !errors.Is(err, io.EOF) { return nil, fmt.Errorf("failed to yaml decode %s: %w", acquisFile, err) } + log.Tracef("End of yaml file") + break } - //for backward compat ('type' was not mandatory, detect it) + // for backward compat ('type' was not mandatory, detect it) if guessType := detectBackwardCompatAcquis(sub); guessType != "" { sub.Source = guessType } - //it's an empty item, skip it + // it's an empty item, skip it if len(sub.Labels) == 0 { if sub.Source == "" { log.Debugf("skipping empty item in %s", acquisFile) continue } - return nil, fmt.Errorf("missing labels in %s (position: %d)", acquisFile, idx) + + if sub.Source != "docker" { + // docker is the only source that can be empty + return nil, fmt.Errorf("missing labels in %s (position: %d)", acquisFile, idx) + } } + if sub.Source == "" { return nil, fmt.Errorf("data source type is empty ('source') in %s (position: %d)", acquisFile, idx) } - if GetDataSourceIface(sub.Source) == nil { - return nil, fmt.Errorf("unknown data source %s in %s (position: %d)", sub.Source, acquisFile, idx) + + // pre-check that the source is valid + _, err := GetDataSourceIface(sub.Source) + if err != nil { + return nil, fmt.Errorf("in file %s (position: %d) - %w", acquisFile, idx, err) } + uniqueId := uuid.NewString() sub.UniqueId = uniqueId - src, err := DataSourceConfigure(sub) + + src, err := DataSourceConfigure(sub, metrics_level) if err != nil { var dserr *DataSourceUnavailableError if errors.As(err, &dserr) { log.Error(err) continue } + return nil, fmt.Errorf("while configuring datasource of type %s from %s (position: %d): %w", sub.Source, acquisFile, idx, err) } + if sub.TransformExpr != "" { vm, err := expr.Compile(sub.TransformExpr, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) 
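An acquisition file can hold several YAML documents separated by ---; the loader above walks them with a strict decoder until io.EOF, tracking the document index for error messages. A standalone sketch of that loop (the item struct is illustrative, not the real DataSourceCommonCfg):

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"strings"

	"gopkg.in/yaml.v2"
)

// item carries just enough fields for the demo.
type item struct {
	Source string            `yaml:"source"`
	Labels map[string]string `yaml:"labels"`
}

func main() {
	acquis := "source: file\nlabels:\n  type: syslog\n---\nsource: journalctl\nlabels:\n  type: journald\n"

	dec := yaml.NewDecoder(strings.NewReader(acquis))
	dec.SetStrict(true) // unknown keys are errors, as in LoadAcquisitionFromFile

	for idx := 0; ; idx++ {
		var sub item
		if err := dec.Decode(&sub); err != nil {
			if errors.Is(err, io.EOF) {
				break // end of the multi-document file
			}
			panic(fmt.Sprintf("document %d: %v", idx, err))
		}
		fmt.Printf("%d: %s -> %v\n", idx, sub.Source, sub.Labels)
	}
}
```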
if err != nil { return nil, fmt.Errorf("while compiling transform expression '%s' for datasource %s in %s (position: %d): %w", sub.TransformExpr, sub.Source, acquisFile, idx, err) } + transformRuntimes[uniqueId] = vm } + sources = append(sources, *src) } } + return sources, nil } func GetMetrics(sources []DataSource, aggregated bool) error { var metrics []prometheus.Collector - for i := 0; i < len(sources); i++ { + + for i := range sources { if aggregated { metrics = sources[i].GetMetrics() } else { metrics = sources[i].GetAggregMetrics() } + for _, metric := range metrics { if err := prometheus.Register(metric); err != nil { if _, ok := err.(prometheus.AlreadyRegisteredError); !ok { @@ -262,12 +333,14 @@ func GetMetrics(sources []DataSource, aggregated bool) error { } } } + return nil } func transform(transformChan chan types.Event, output chan types.Event, AcquisTomb *tomb.Tomb, transformRuntime *vm.Program, logger *log.Entry) { defer trace.CatchPanic("crowdsec/acquis") logger.Infof("transformer started") + for { select { case <-AcquisTomb.Dying(): @@ -275,15 +348,18 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo return case evt := <-transformChan: logger.Tracef("Received event %s", evt.Line.Raw) + out, err := expr.Run(transformRuntime, map[string]interface{}{"evt": &evt}) if err != nil { logger.Errorf("while running transform expression: %s, sending event as-is", err) output <- evt } + if out == nil { logger.Errorf("transform expression returned nil, sending event as-is") output <- evt } + switch v := out.(type) { case string: logger.Tracef("transform expression returned %s", v) @@ -291,18 +367,22 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo output <- evt case []interface{}: logger.Tracef("transform expression returned %v", v) //nolint:asasalint // We actually want to log the slice content + for _, line := range v { l, ok := line.(string) if !ok { logger.Errorf("transform expression returned []interface{}, but cannot assert an element to string") output <- evt + continue } + evt.Line.Raw = l output <- evt } case []string: logger.Tracef("transform expression returned %v", v) + for _, line := range v { evt.Line.Raw = line output <- evt @@ -315,49 +395,58 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo } } -func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { +func StartAcquisition(ctx context.Context, sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { // Don't wait if we have no sources, as it will hang forever if len(sources) == 0 { return nil } - for i := 0; i < len(sources); i++ { - subsrc := sources[i] //ensure its a copy + for i := range sources { + subsrc := sources[i] // ensure its a copy log.Debugf("starting one source %d/%d ->> %T", i, len(sources), subsrc) AcquisTomb.Go(func() error { defer trace.CatchPanic("crowdsec/acquis") + var err error outChan := output + log.Debugf("datasource %s UUID: %s", subsrc.GetName(), subsrc.GetUuid()) + if transformRuntime, ok := transformRuntimes[subsrc.GetUuid()]; ok { log.Infof("transform expression found for datasource %s", subsrc.GetName()) + transformChan := make(chan types.Event) outChan = transformChan transformLogger := log.WithFields(log.Fields{ "component": "transform", "datasource": subsrc.GetName(), }) + AcquisTomb.Go(func() error { transform(outChan, output, AcquisTomb, transformRuntime, transformLogger) return nil }) } + if subsrc.GetMode() == 
configuration.TAIL_MODE { - err = subsrc.StreamingAcquisition(outChan, AcquisTomb) + err = subsrc.StreamingAcquisition(ctx, outChan, AcquisTomb) } else { err = subsrc.OneShotAcquisition(outChan, AcquisTomb) } + if err != nil { - //if one of the acqusition returns an error, we kill the others to properly shutdown + // if one of the acquisitions returns an error, we kill the others to properly shut down AcquisTomb.Kill(err) } + return nil }) } /*return only when acquisition is over (cat) or never (tail)*/ err := AcquisTomb.Wait() + return err } diff --git a/pkg/acquisition/acquisition_test.go b/pkg/acquisition/acquisition_test.go index 548ecc04bb5..e82b3df54c2 100644 --- a/pkg/acquisition/acquisition_test.go +++ b/pkg/acquisition/acquisition_test.go @@ -1,6 +1,8 @@ package acquisition import ( + "context" + "errors" "fmt" "strings" "testing" @@ -13,7 +15,7 @@ import ( tomb "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -35,32 +37,38 @@ func (f *MockSource) UnmarshalConfig(cfg []byte) error { return nil } -func (f *MockSource) Configure(cfg []byte, logger *log.Entry) error { +func (f *MockSource) Configure(cfg []byte, logger *log.Entry, metricsLevel int) error { f.logger = logger if err := f.UnmarshalConfig(cfg); err != nil { return err } + if f.Mode == "" { f.Mode = configuration.CAT_MODE } + if f.Mode != configuration.CAT_MODE && f.Mode != configuration.TAIL_MODE { return fmt.Errorf("mode %s is not supported", f.Mode) } + if f.Toto == "" { - return fmt.Errorf("expect non-empty toto") + return errors.New("expect non-empty toto") } + return nil } -func (f *MockSource) GetMode() string { return f.Mode } -func (f *MockSource) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSource) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSource) CanRun() error { return nil } -func (f *MockSource) GetMetrics() []prometheus.Collector { return nil } -func (f *MockSource) GetAggregMetrics() []prometheus.Collector { return nil } -func (f *MockSource) Dump() interface{} { return f } -func (f *MockSource) GetName() string { return "mock" } +func (f *MockSource) GetMode() string { return f.Mode } +func (f *MockSource) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } +func (f *MockSource) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return nil +} +func (f *MockSource) CanRun() error { return nil } +func (f *MockSource) GetMetrics() []prometheus.Collector { return nil } +func (f *MockSource) GetAggregMetrics() []prometheus.Collector { return nil } +func (f *MockSource) Dump() interface{} { return f } +func (f *MockSource) GetName() string { return "mock" } func (f *MockSource) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { - return fmt.Errorf("not supported") + return errors.New("not supported") } func (f *MockSource) GetUuid() string { return "" } @@ -69,21 +77,18 @@ type MockSourceCantRun struct { MockSource } -func (f *MockSourceCantRun) CanRun() error { return fmt.Errorf("can't run bro") } +func (f *MockSourceCantRun) CanRun() error { return errors.New("can't run bro") } func (f *MockSourceCantRun) GetName() string { return "mock_cant_run" } // appendMockSource is only used to add mock source for tests func appendMockSource() { - if GetDataSourceIface("mock") ==
nil { - AcquisitionSources["mock"] = func() DataSource { return &MockSource{} } - } - if GetDataSourceIface("mock_cant_run") == nil { - AcquisitionSources["mock_cant_run"] = func() DataSource { return &MockSourceCantRun{} } - } + AcquisitionSources["mock"] = func() DataSource { return &MockSource{} } + AcquisitionSources["mock_cant_run"] = func() DataSource { return &MockSourceCantRun{} } } func TestDataSourceConfigure(t *testing.T) { appendMockSource() + tests := []struct { TestName string String string @@ -143,7 +148,7 @@ labels: log_level: debug source: tutu `, - ExpectedError: "cannot find source tutu", + ExpectedError: "unknown data source tutu", }, { TestName: "mismatch_config", @@ -172,12 +177,12 @@ wowo: ajsajasjas } for _, tc := range tests { - tc := tc t.Run(tc.TestName, func(t *testing.T) { common := configuration.DataSourceCommonCfg{} yaml.Unmarshal([]byte(tc.String), &common) - ds, err := DataSourceConfigure(common) + ds, err := DataSourceConfigure(common, configuration.METRICS_NONE) cstest.RequireErrorContains(t, err, tc.ExpectedError) + if tc.ExpectedError != "" { return } @@ -185,22 +190,22 @@ wowo: ajsajasjas switch tc.TestName { case "basic_valid_config": mock := (*ds).Dump().(*MockSource) - assert.Equal(t, mock.Toto, "test_value1") - assert.Equal(t, mock.Mode, "cat") - assert.Equal(t, mock.logger.Logger.Level, log.InfoLevel) - assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) + assert.Equal(t, "test_value1", mock.Toto) + assert.Equal(t, "cat", mock.Mode) + assert.Equal(t, log.InfoLevel, mock.logger.Logger.Level) + assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels) case "basic_debug_config": mock := (*ds).Dump().(*MockSource) - assert.Equal(t, mock.Toto, "test_value1") - assert.Equal(t, mock.Mode, "cat") - assert.Equal(t, mock.logger.Logger.Level, log.DebugLevel) - assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) + assert.Equal(t, "test_value1", mock.Toto) + assert.Equal(t, "cat", mock.Mode) + assert.Equal(t, log.DebugLevel, mock.logger.Logger.Level) + assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels) case "basic_tailmode_config": mock := (*ds).Dump().(*MockSource) - assert.Equal(t, mock.Toto, "test_value1") - assert.Equal(t, mock.Mode, "tail") - assert.Equal(t, mock.logger.Logger.Level, log.DebugLevel) - assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) + assert.Equal(t, "test_value1", mock.Toto) + assert.Equal(t, "tail", mock.Mode) + assert.Equal(t, log.DebugLevel, mock.logger.Logger.Level) + assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels) } }) } @@ -208,6 +213,7 @@ wowo: ajsajasjas func TestLoadAcquisitionFromFile(t *testing.T) { appendMockSource() + tests := []struct { TestName string Config csconfig.CrowdsecServiceCfg @@ -263,7 +269,7 @@ func TestLoadAcquisitionFromFile(t *testing.T) { Config: csconfig.CrowdsecServiceCfg{ AcquisitionFiles: []string{"test_files/bad_source.yaml"}, }, - ExpectedError: "unknown data source does_not_exist in test_files/bad_source.yaml", + ExpectedError: "in file test_files/bad_source.yaml (position: 0) - unknown data source does_not_exist", }, { TestName: "invalid_filetype_config", @@ -274,17 +280,16 @@ func TestLoadAcquisitionFromFile(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.TestName, func(t *testing.T) { - dss, err := LoadAcquisitionFromFile(&tc.Config) + dss, err := LoadAcquisitionFromFile(&tc.Config, nil) cstest.RequireErrorContains(t, err, tc.ExpectedError) + if tc.ExpectedError != "" { return } assert.Len(t, 
dss, tc.ExpectedLen) }) - } } @@ -299,14 +304,16 @@ type MockCat struct { logger *log.Entry } -func (f *MockCat) Configure(cfg []byte, logger *log.Entry) error { +func (f *MockCat) Configure(cfg []byte, logger *log.Entry, metricsLevel int) error { f.logger = logger if f.Mode == "" { f.Mode = configuration.CAT_MODE } + if f.Mode != configuration.CAT_MODE { return fmt.Errorf("mode %s is not supported", f.Mode) } + return nil } @@ -314,22 +321,24 @@ func (f *MockCat) UnmarshalConfig(cfg []byte) error { return nil } func (f *MockCat) GetName() string { return "mock_cat" } func (f *MockCat) GetMode() string { return "cat" } func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) error { - for i := 0; i < 10; i++ { + for range 10 { evt := types.Event{} evt.Line.Src = "test" out <- evt } + return nil } -func (f *MockCat) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { - return fmt.Errorf("can't run in tail") + +func (f *MockCat) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return errors.New("can't run in tail") } func (f *MockCat) CanRun() error { return nil } func (f *MockCat) GetMetrics() []prometheus.Collector { return nil } func (f *MockCat) GetAggregMetrics() []prometheus.Collector { return nil } func (f *MockCat) Dump() interface{} { return f } func (f *MockCat) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { - return fmt.Errorf("not supported") + return errors.New("not supported") } func (f *MockCat) GetUuid() string { return "" } @@ -340,14 +349,16 @@ type MockTail struct { logger *log.Entry } -func (f *MockTail) Configure(cfg []byte, logger *log.Entry) error { +func (f *MockTail) Configure(cfg []byte, logger *log.Entry, metricsLevel int) error { f.logger = logger if f.Mode == "" { f.Mode = configuration.TAIL_MODE } + if f.Mode != configuration.TAIL_MODE { return fmt.Errorf("mode %s is not supported", f.Mode) } + return nil } @@ -355,15 +366,18 @@ func (f *MockTail) UnmarshalConfig(cfg []byte) error { return nil } func (f *MockTail) GetName() string { return "mock_tail" } func (f *MockTail) GetMode() string { return "tail" } func (f *MockTail) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) error { - return fmt.Errorf("can't run in cat mode") + return errors.New("can't run in cat mode") } -func (f *MockTail) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { - for i := 0; i < 10; i++ { + +func (f *MockTail) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + for range 10 { evt := types.Event{} evt.Line.Src = "test" out <- evt } + <-t.Dying() + return nil } func (f *MockTail) CanRun() error { return nil } @@ -371,13 +385,14 @@ func (f *MockTail) GetMetrics() []prometheus.Collector { return nil } func (f *MockTail) GetAggregMetrics() []prometheus.Collector { return nil } func (f *MockTail) Dump() interface{} { return f } func (f *MockTail) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { - return fmt.Errorf("not supported") + return errors.New("not supported") } func (f *MockTail) GetUuid() string { return "" } -//func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { +// func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { func TestStartAcquisitionCat(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockCat{}, } @@ -385,7 +400,7 @@ func TestStartAcquisitionCat(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { 
- if err := StartAcquisition(sources, out, &acquisTomb); err != nil { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil { t.Errorf("unexpected error") } }() @@ -405,6 +420,7 @@ READLOOP: } func TestStartAcquisitionTail(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockTail{}, } @@ -412,7 +428,7 @@ func TestStartAcquisitionTail(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil { t.Errorf("unexpected error") } }() @@ -439,17 +455,20 @@ type MockTailError struct { MockTail } -func (f *MockTailError) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { - for i := 0; i < 10; i++ { +func (f *MockTailError) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + for range 10 { evt := types.Event{} evt.Line.Src = "test" out <- evt } - t.Kill(fmt.Errorf("got error (tomb)")) - return fmt.Errorf("got error") + + t.Kill(errors.New("got error (tomb)")) + + return errors.New("got error") } func TestStartAcquisitionTailError(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockTailError{}, } @@ -457,7 +476,7 @@ func TestStartAcquisitionTailError(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil && err.Error() != "got error (tomb)" { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil && err.Error() != "got error (tomb)" { t.Errorf("expected error, got '%s'", err) } }() @@ -473,7 +492,7 @@ READLOOP: } } assert.Equal(t, 10, count) - //acquisTomb.Kill(nil) + // acquisTomb.Kill(nil) time.Sleep(1 * time.Second) cstest.RequireErrorContains(t, acquisTomb.Err(), "got error (tomb)") } @@ -484,21 +503,26 @@ type MockSourceByDSN struct { logger *log.Entry //nolint: unused } -func (f *MockSourceByDSN) UnmarshalConfig(cfg []byte) error { return nil } -func (f *MockSourceByDSN) Configure(cfg []byte, logger *log.Entry) error { return nil } -func (f *MockSourceByDSN) GetMode() string { return f.Mode } -func (f *MockSourceByDSN) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSourceByDSN) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSourceByDSN) CanRun() error { return nil } -func (f *MockSourceByDSN) GetMetrics() []prometheus.Collector { return nil } -func (f *MockSourceByDSN) GetAggregMetrics() []prometheus.Collector { return nil } -func (f *MockSourceByDSN) Dump() interface{} { return f } -func (f *MockSourceByDSN) GetName() string { return "mockdsn" } +func (f *MockSourceByDSN) UnmarshalConfig(cfg []byte) error { return nil } +func (f *MockSourceByDSN) Configure(cfg []byte, logger *log.Entry, metricsLevel int) error { + return nil +} +func (f *MockSourceByDSN) GetMode() string { return f.Mode } +func (f *MockSourceByDSN) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } +func (f *MockSourceByDSN) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return nil +} +func (f *MockSourceByDSN) CanRun() error { return nil } +func (f *MockSourceByDSN) GetMetrics() []prometheus.Collector { return nil } +func (f *MockSourceByDSN) GetAggregMetrics() []prometheus.Collector { return nil } +func (f *MockSourceByDSN) Dump() interface{} { return f } +func (f *MockSourceByDSN) GetName() string { return "mockdsn" } func (f *MockSourceByDSN) ConfigureByDSN(dsn 
string, labels map[string]string, logger *log.Entry, uuid string) error { dsn = strings.TrimPrefix(dsn, "mockdsn://") if dsn != "test_expect" { - return fmt.Errorf("unexpected value") + return errors.New("unexpected value") } + return nil } func (f *MockSourceByDSN) GetUuid() string { return "" } @@ -527,12 +551,9 @@ func TestConfigureByDSN(t *testing.T) { }, } - if GetDataSourceIface("mockdsn") == nil { - AcquisitionSources["mockdsn"] = func() DataSource { return &MockSourceByDSN{} } - } + AcquisitionSources["mockdsn"] = func() DataSource { return &MockSourceByDSN{} } for _, tc := range tests { - tc := tc t.Run(tc.dsn, func(t *testing.T) { srcs, err := LoadAcquisitionFromDSN(tc.dsn, map[string]string{"type": "test_label"}, "") cstest.RequireErrorContains(t, err, tc.ExpectedError) diff --git a/pkg/acquisition/appsec.go b/pkg/acquisition/appsec.go new file mode 100644 index 00000000000..81616d3d2b8 --- /dev/null +++ b/pkg/acquisition/appsec.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_appsec + +package acquisition + +import ( + appsecacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/appsec" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("appsec", func() DataSource { return &appsecacquisition.AppsecSource{} }) +} diff --git a/pkg/acquisition/cloudwatch.go b/pkg/acquisition/cloudwatch.go new file mode 100644 index 00000000000..e6b3d3e3e53 --- /dev/null +++ b/pkg/acquisition/cloudwatch.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_cloudwatch + +package acquisition + +import ( + cloudwatchacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/cloudwatch" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("cloudwatch", func() DataSource { return &cloudwatchacquisition.CloudwatchSource{} }) +} diff --git a/pkg/acquisition/configuration/configuration.go b/pkg/acquisition/configuration/configuration.go index 5ec1a4ac4c3..3e27da1b9e6 100644 --- a/pkg/acquisition/configuration/configuration.go +++ b/pkg/acquisition/configuration/configuration.go @@ -19,3 +19,14 @@ type DataSourceCommonCfg struct { var TAIL_MODE = "tail" var CAT_MODE = "cat" var SERVER_MODE = "server" // No difference with tail, just a bit more verbose + +const ( + METRICS_NONE = iota + METRICS_AGGREGATE + METRICS_FULL +) + +const ( + CFG_METRICS_AGGREGATE = "aggregated" + CFG_METRICS_FULL = "full" +) diff --git a/pkg/acquisition/docker.go b/pkg/acquisition/docker.go new file mode 100644 index 00000000000..3bf792a039a --- /dev/null +++ b/pkg/acquisition/docker.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_docker + +package acquisition + +import ( + dockeracquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/docker" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("docker", func() DataSource { return &dockeracquisition.DockerSource{} }) +} diff --git a/pkg/acquisition/file.go b/pkg/acquisition/file.go new file mode 100644 index 00000000000..1ff2e4a3c0e --- /dev/null +++ b/pkg/acquisition/file.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_file + +package acquisition + +import ( + fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("file", func() DataSource { return &fileacquisition.FileSource{} }) +} diff --git a/pkg/acquisition/journalctl.go b/pkg/acquisition/journalctl.go new file mode 100644 index 00000000000..691f961ae77 --- /dev/null +++ b/pkg/acquisition/journalctl.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_journalctl + 
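Each of these stub files follows the same shape: a negative build constraint, then an init() that registers the constructor. Compiling with the tag set removes both, so the datasource disappears from the binary while GetDataSourceIface can still report it as known but not built. A template for a hypothetical new datasource (mysource, its import path, and the MySource type are placeholders, not part of the patch; the kafka/kinesis tags in the comment match the //go:build lines in this diff):

```go
//go:build !no_datasource_mysource

// Exclude this datasource at build time with, e.g.:
//   go build -tags no_datasource_mysource ./...
// Several tags can be combined:
//   go build -tags "no_datasource_kafka,no_datasource_kinesis" ./...

package acquisition

import (
	mysourceacquisition "github.com/example/crowdsec/pkg/acquisition/modules/mysource" // placeholder path
)

//nolint:gochecknoinits
func init() {
	registerDataSource("mysource", func() DataSource { return &mysourceacquisition.MySource{} })
}
```

The tags are opt-out rather than opt-in, so a default build still ships every datasource; stripping one is a deliberate choice by whoever builds the binary.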
+package acquisition + +import ( + journalctlacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/journalctl" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("journalctl", func() DataSource { return &journalctlacquisition.JournalCtlSource{} }) +} diff --git a/pkg/acquisition/k8s.go b/pkg/acquisition/k8s.go new file mode 100644 index 00000000000..cb9446be285 --- /dev/null +++ b/pkg/acquisition/k8s.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_k8saudit + +package acquisition + +import ( + k8sauditacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kubernetesaudit" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("k8s-audit", func() DataSource { return &k8sauditacquisition.KubernetesAuditSource{} }) +} diff --git a/pkg/acquisition/kafka.go b/pkg/acquisition/kafka.go new file mode 100644 index 00000000000..7d315d87feb --- /dev/null +++ b/pkg/acquisition/kafka.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_kafka + +package acquisition + +import ( + kafkaacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kafka" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("kafka", func() DataSource { return &kafkaacquisition.KafkaSource{} }) +} diff --git a/pkg/acquisition/kinesis.go b/pkg/acquisition/kinesis.go new file mode 100644 index 00000000000..b41372e7fb9 --- /dev/null +++ b/pkg/acquisition/kinesis.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_kinesis + +package acquisition + +import ( + kinesisacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kinesis" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("kinesis", func() DataSource { return &kinesisacquisition.KinesisSource{} }) +} diff --git a/pkg/acquisition/loki.go b/pkg/acquisition/loki.go new file mode 100644 index 00000000000..1eed6686591 --- /dev/null +++ b/pkg/acquisition/loki.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_loki + +package acquisition + +import ( + "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("loki", func() DataSource { return &loki.LokiSource{} }) +} diff --git a/pkg/acquisition/modules/appsec/appsec.go b/pkg/acquisition/modules/appsec/appsec.go new file mode 100644 index 00000000000..5161b631c33 --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec.go @@ -0,0 +1,401 @@ +package appsecacquisition + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "os" + "sync" + "time" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "gopkg.in/tomb.v2" + "gopkg.in/yaml.v2" + + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +const ( + InBand = "inband" + OutOfBand = "outofband" +) + +var DefaultAuthCacheDuration = (1 * time.Minute) + +// configuration structure of the acquis for the application security engine +type AppsecSourceConfig struct { + ListenAddr string `yaml:"listen_addr"` + ListenSocket string `yaml:"listen_socket"` + CertFilePath string `yaml:"cert_file"` + KeyFilePath string `yaml:"key_file"` + Path string `yaml:"path"` + Routines int `yaml:"routines"` + AppsecConfig string `yaml:"appsec_config"` + AppsecConfigPath string `yaml:"appsec_config_path"` + AuthCacheDuration *time.Duration 
`yaml:"auth_cache_duration"` + configuration.DataSourceCommonCfg `yaml:",inline"` +} + +// runtime structure of AppsecSourceConfig +type AppsecSource struct { + metricsLevel int + config AppsecSourceConfig + logger *log.Entry + mux *http.ServeMux + server *http.Server + outChan chan types.Event + InChan chan appsec.ParsedRequest + AppsecRuntime *appsec.AppsecRuntimeConfig + AppsecConfigs map[string]appsec.AppsecConfig + lapiURL string + AuthCache AuthCache + AppsecRunners []AppsecRunner // one for each go-routine +} + +// Struct to handle cache of authentication +type AuthCache struct { + APIKeys map[string]time.Time + mu sync.RWMutex +} + +func NewAuthCache() AuthCache { + return AuthCache{ + APIKeys: make(map[string]time.Time, 0), + mu: sync.RWMutex{}, + } +} + +func (ac *AuthCache) Set(apiKey string, expiration time.Time) { + ac.mu.Lock() + ac.APIKeys[apiKey] = expiration + ac.mu.Unlock() +} + +func (ac *AuthCache) Get(apiKey string) (time.Time, bool) { + ac.mu.RLock() + expiration, exists := ac.APIKeys[apiKey] + ac.mu.RUnlock() + return expiration, exists +} + +// @tko + @sbl : we might want to get rid of that or improve it +type BodyResponse struct { + Action string `json:"action"` +} + +func (w *AppsecSource) UnmarshalConfig(yamlConfig []byte) error { + err := yaml.UnmarshalStrict(yamlConfig, &w.config) + if err != nil { + return fmt.Errorf("cannot parse appsec configuration: %w", err) + } + + if w.config.ListenAddr == "" && w.config.ListenSocket == "" { + w.config.ListenAddr = "127.0.0.1:7422" + } + + if w.config.Path == "" { + w.config.Path = "/" + } + + if w.config.Path[0] != '/' { + w.config.Path = "/" + w.config.Path + } + + if w.config.Mode == "" { + w.config.Mode = configuration.TAIL_MODE + } + + // always have at least one appsec routine + if w.config.Routines == 0 { + w.config.Routines = 1 + } + + if w.config.AppsecConfig == "" && w.config.AppsecConfigPath == "" { + return errors.New("appsec_config or appsec_config_path must be set") + } + + if w.config.Name == "" { + if w.config.ListenSocket != "" && w.config.ListenAddr == "" { + w.config.Name = w.config.ListenSocket + } + if w.config.ListenSocket == "" { + w.config.Name = fmt.Sprintf("%s%s", w.config.ListenAddr, w.config.Path) + } + } + + csConfig := csconfig.GetConfig() + w.lapiURL = fmt.Sprintf("%sv1/decisions/stream", csConfig.API.Client.Credentials.URL) + w.AuthCache = NewAuthCache() + + return nil +} + +func (w *AppsecSource) GetMetrics() []prometheus.Collector { + return []prometheus.Collector{AppsecReqCounter, AppsecBlockCounter, AppsecRuleHits, AppsecOutbandParsingHistogram, AppsecInbandParsingHistogram, AppsecGlobalParsingHistogram} +} + +func (w *AppsecSource) GetAggregMetrics() []prometheus.Collector { + return []prometheus.Collector{AppsecReqCounter, AppsecBlockCounter, AppsecRuleHits, AppsecOutbandParsingHistogram, AppsecInbandParsingHistogram, AppsecGlobalParsingHistogram} +} + +func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { + err := w.UnmarshalConfig(yamlConfig) + if err != nil { + return fmt.Errorf("unable to parse appsec configuration: %w", err) + } + w.logger = logger + w.metricsLevel = MetricsLevel + w.logger.Tracef("Appsec configuration: %+v", w.config) + + if w.config.AuthCacheDuration == nil { + w.config.AuthCacheDuration = &DefaultAuthCacheDuration + w.logger.Infof("Cache duration for auth not set, using default: %v", *w.config.AuthCacheDuration) + } + + w.mux = http.NewServeMux() + + w.server = &http.Server{ + Addr: w.config.ListenAddr, + Handler: 
w.mux, + } + + w.InChan = make(chan appsec.ParsedRequest) + appsecCfg := appsec.AppsecConfig{Logger: w.logger.WithField("component", "appsec_config")} + + // let's load the associated appsec_config: + if w.config.AppsecConfigPath != "" { + err := appsecCfg.LoadByPath(w.config.AppsecConfigPath) + if err != nil { + return fmt.Errorf("unable to load appsec_config: %w", err) + } + } else if w.config.AppsecConfig != "" { + err := appsecCfg.Load(w.config.AppsecConfig) + if err != nil { + return fmt.Errorf("unable to load appsec_config: %w", err) + } + } else { + return errors.New("no appsec_config provided") + } + + w.AppsecRuntime, err = appsecCfg.Build() + if err != nil { + return fmt.Errorf("unable to build appsec_config: %w", err) + } + + err = w.AppsecRuntime.ProcessOnLoadRules() + if err != nil { + return fmt.Errorf("unable to process on load rules: %w", err) + } + + w.AppsecRunners = make([]AppsecRunner, w.config.Routines) + + for nbRoutine := range w.config.Routines { + appsecRunnerUUID := uuid.New().String() + // we copy AppsecRuntime for each runner + wrt := *w.AppsecRuntime + wrt.Logger = w.logger.Dup().WithField("runner_uuid", appsecRunnerUUID) + runner := AppsecRunner{ + inChan: w.InChan, + UUID: appsecRunnerUUID, + logger: w.logger.WithField("runner_uuid", appsecRunnerUUID), + AppsecRuntime: &wrt, + Labels: w.config.Labels, + } + err := runner.Init(appsecCfg.GetDataDir()) + if err != nil { + return fmt.Errorf("unable to initialize runner: %w", err) + } + w.AppsecRunners[nbRoutine] = runner + } + + w.logger.Infof("Created %d appsec runners", len(w.AppsecRunners)) + + // We don't use the wrapper provided by coraza because we want to fully control what happens when a rule matches, so we can send the information to crowdsec + w.mux.HandleFunc(w.config.Path, w.appsecHandler) + return nil +} + +func (w *AppsecSource) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { + return errors.New("AppSec datasource does not support command line acquisition") +} + +func (w *AppsecSource) GetMode() string { + return w.config.Mode +} + +func (w *AppsecSource) GetName() string { + return "appsec" +} + +func (w *AppsecSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { + return errors.New("AppSec datasource does not support command line acquisition") +} + +func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + w.outChan = out + t.Go(func() error { + defer trace.CatchPanic("crowdsec/acquis/appsec/live") + + w.logger.Infof("%d appsec runner to start", len(w.AppsecRunners)) + for _, runner := range w.AppsecRunners { + runner.outChan = out + t.Go(func() error { + defer trace.CatchPanic("crowdsec/acquis/appsec/live/runner") + return runner.Run(t) + }) + } + t.Go(func() error { + if w.config.ListenSocket != "" { + w.logger.Infof("creating unix socket %s", w.config.ListenSocket) + _ = os.RemoveAll(w.config.ListenSocket) + listener, err := net.Listen("unix", w.config.ListenSocket) + if err != nil { + return fmt.Errorf("appsec server failed: %w", err) + } + defer listener.Close() + if w.config.CertFilePath != "" && w.config.KeyFilePath != "" { + err = w.server.ServeTLS(listener, w.config.CertFilePath, w.config.KeyFilePath) + } else { + err = w.server.Serve(listener) + } + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("appsec server failed: %w", err) + } + } + return nil + }) + t.Go(func() error { + var err error + if w.config.ListenAddr != "" { + w.logger.Infof("creating TCP server on
%s", w.config.ListenAddr) + if w.config.CertFilePath != "" && w.config.KeyFilePath != "" { + err = w.server.ListenAndServeTLS(w.config.CertFilePath, w.config.KeyFilePath) + } else { + err = w.server.ListenAndServe() + } + + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("appsec server failed: %w", err) + } + } + return nil + }) + <-t.Dying() + w.logger.Info("Shutting down Appsec server") + // xx let's clean up the appsec runners :) + appsec.AppsecRulesDetails = make(map[int]appsec.RulesDetails) + w.server.Shutdown(context.TODO()) + return nil + }) + return nil +} + +func (w *AppsecSource) CanRun() error { + return nil +} + +func (w *AppsecSource) GetUuid() string { + return w.config.UniqueId +} + +func (w *AppsecSource) Dump() interface{} { + return w +} + +func (w *AppsecSource) IsAuth(apiKey string) bool { + client := &http.Client{ + Timeout: 200 * time.Millisecond, + } + + req, err := http.NewRequest(http.MethodHead, w.lapiURL, nil) + if err != nil { + log.Errorf("Error creating request: %s", err) + return false + } + + req.Header.Add("X-Api-Key", apiKey) + resp, err := client.Do(req) + if err != nil { + log.Errorf("Error performing request: %s", err) + return false + } + defer resp.Body.Close() + + return resp.StatusCode == http.StatusOK +} + +// should this be in the runner ? +func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) { + w.logger.Debugf("Received request from '%s' on %s", r.RemoteAddr, r.URL.Path) + + apiKey := r.Header.Get(appsec.APIKeyHeaderName) + clientIP := r.Header.Get(appsec.IPHeaderName) + remoteIP := r.RemoteAddr + if apiKey == "" { + w.logger.Errorf("Unauthorized request from '%s' (real IP = %s)", remoteIP, clientIP) + rw.WriteHeader(http.StatusUnauthorized) + return + } + expiration, exists := w.AuthCache.Get(apiKey) + // if the apiKey is not in cache or has expired, just recheck the auth + if !exists || time.Now().After(expiration) { + if !w.IsAuth(apiKey) { + rw.WriteHeader(http.StatusUnauthorized) + w.logger.Errorf("Unauthorized request from '%s' (real IP = %s)", remoteIP, clientIP) + return + } + + // apiKey is valid, store it in cache + w.AuthCache.Set(apiKey, time.Now().Add(*w.config.AuthCacheDuration)) + } + + // parse the request only once + parsedRequest, err := appsec.NewParsedRequestFromRequest(r, w.logger) + if err != nil { + w.logger.Errorf("%s", err) + rw.WriteHeader(http.StatusInternalServerError) + return + } + parsedRequest.AppsecEngine = w.config.Name + + logger := w.logger.WithFields(log.Fields{ + "request_uuid": parsedRequest.UUID, + "client_ip": parsedRequest.ClientIP, + }) + + AppsecReqCounter.With(prometheus.Labels{"source": parsedRequest.RemoteAddrNormalized, "appsec_engine": parsedRequest.AppsecEngine}).Inc() + + w.InChan <- parsedRequest + + /* + response is a copy of w.AppSecRuntime.Response that is safe to use. 
+ As OutOfBand might still be running, the original one can be modified + */ + response := <-parsedRequest.ResponseChannel + + if response.InBandInterrupt { + AppsecBlockCounter.With(prometheus.Labels{"source": parsedRequest.RemoteAddrNormalized, "appsec_engine": parsedRequest.AppsecEngine}).Inc() + } + + statusCode, appsecResponse := w.AppsecRuntime.GenerateResponse(response, logger) + logger.Debugf("Response: %+v", appsecResponse) + + rw.WriteHeader(statusCode) + body, err := json.Marshal(appsecResponse) + if err != nil { + logger.Errorf("unable to serialize response: %s", err) + rw.WriteHeader(http.StatusInternalServerError) + } else { + rw.Write(body) + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_hooks_test.go b/pkg/acquisition/modules/appsec/appsec_hooks_test.go new file mode 100644 index 00000000000..c549d2ef1d1 --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_hooks_test.go @@ -0,0 +1,894 @@ +package appsecacquisition + +import ( + "net/http" + "net/url" + "testing" + + "github.com/davecgh/go-spew/spew" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func TestAppsecOnMatchHooks(t *testing.T) { + tests := []appsecRuleTest{ + { + name: "no rule : check return code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Len(t, responses, 1) + require.Equal(t, 403, responses[0].BouncerHTTPResponseCode) + require.Equal(t, 403, responses[0].UserHTTPResponseCode) + require.Equal(t, appsec.BanRemediation, responses[0].Action) + }, + }, + { + name: "on_match: change return code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetReturnCode(413)"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Len(t, responses, 1) + require.Equal(t, 403, responses[0].BouncerHTTPResponseCode) + require.Equal(t, 413, responses[0].UserHTTPResponseCode) + require.Equal(t, appsec.BanRemediation, responses[0].Action) + }, + }, + { + name: "on_match: change action to a non standard one (log)", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: 
appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('log')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Len(t, responses, 1) + require.Equal(t, "log", responses[0].Action) + require.Equal(t, 403, responses[0].BouncerHTTPResponseCode) + require.Equal(t, 403, responses[0].UserHTTPResponseCode) + }, + }, + { + name: "on_match: change action to another standard one (allow)", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('allow')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Len(t, responses, 1) + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + }, + }, + { + name: "on_match: change action to another standard one (ban)", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('ban')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, responses, 1) + //note: SetAction normalizes deny, ban and block to ban + require.Equal(t, appsec.BanRemediation, responses[0].Action) + }, + }, + { + name: "on_match: change action to another standard one (captcha)", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, responses, 1) + //note: SetAction normalizes deny, ban and block to ban + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + }, + }, + { + name: "on_match: 
change action to a non standard one", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('foobar')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Len(t, responses, 1) + require.Equal(t, "foobar", responses[0].Action) + }, + }, + { + name: "on_match: cancel alert", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true && LogInfo('XX -> %s', evt.Appsec.MatchedRules.GetName())", Apply: []string{"CancelAlert()"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 1) + require.Equal(t, types.LOG, events[0].Type) + require.Len(t, responses, 1) + require.Equal(t, appsec.BanRemediation, responses[0].Action) + }, + }, + { + name: "on_match: cancel event", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"CancelEvent()"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 1) + require.Equal(t, types.APPSEC, events[0].Type) + require.Len(t, responses, 1) + require.Equal(t, appsec.BanRemediation, responses[0].Action) + }, + }, + { + name: "on_match: on_success break", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"CancelEvent()"}, OnSuccess: "break"}, + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 1) + require.Equal(t, types.APPSEC, events[0].Type) + require.Len(t, responses, 1) + 
require.Equal(t, appsec.BanRemediation, responses[0].Action)
+			},
+		},
+		{
+			name:             "on_match: on_success continue",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rule42",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			on_match: []appsec.Hook{
+				{Filter: "IsInBand == true", Apply: []string{"CancelEvent()"}, OnSuccess: "continue"},
+				{Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/urllll",
+				Args:       url.Values{"foo": []string{"toto"}},
+			},
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Len(t, events, 1)
+				require.Equal(t, types.APPSEC, events[0].Type)
+				require.Len(t, responses, 1)
+				require.Equal(t, appsec.CaptchaRemediation, responses[0].Action)
+			},
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			loadAppSecEngine(test, t)
+		})
+	}
+}
+
+func TestAppsecPreEvalHooks(t *testing.T) {
+	tests := []appsecRuleTest{
+		{
+			name:             "Basic pre_eval hook to disable inband rule",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rule1",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			pre_eval: []appsec.Hook{
+				{Filter: "1 == 1", Apply: []string{"RemoveInBandRuleByName('rule1')"}},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/urllll",
+				Args:       url.Values{"foo": []string{"toto"}},
+			},
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Empty(t, events)
+				require.Len(t, responses, 1)
+				require.False(t, responses[0].InBandInterrupt)
+				require.False(t, responses[0].OutOfBandInterrupt)
+			},
+		},
+		{
+			name:             "Basic pre_eval fails to disable rule",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rule1",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			pre_eval: []appsec.Hook{
+				{Filter: "1 == 2", Apply: []string{"RemoveInBandRuleByName('rule1')"}},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/urllll",
+				Args:       url.Values{"foo": []string{"toto"}},
+			},
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Len(t, events, 2)
+				require.Equal(t, types.APPSEC, events[0].Type)
+
+				require.Equal(t, types.LOG, events[1].Type)
+				require.True(t, events[1].Appsec.HasInBandMatches)
+				require.Len(t, events[1].Appsec.MatchedRules, 1)
+				require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"])
+
+				require.Len(t, responses, 1)
+				require.True(t, responses[0].InBandInterrupt)
+			},
+		},
+		{
+			name:             "pre_eval: disable inband by tag",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rulez",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			pre_eval: []appsec.Hook{
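+				// the tag is not declared anywhere in this file: as far as these
+				// fixtures show, the rule compiler derives it from the rule name
+				// ("crowdsec-" + Name), and that derived tag is what
+				// RemoveInBandRuleByTag matches against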
[]string{"RemoveInBandRuleByTag('crowdsec-rulez')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Len(t, responses, 1) + require.False(t, responses[0].InBandInterrupt) + require.False(t, responses[0].OutOfBandInterrupt) + }, + }, + { + name: "pre_eval : disable inband by ID", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Apply: []string{"RemoveInBandRuleByID(1516470898)"}}, //rule ID is generated at runtime. If you change rule, it will break the test (: + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Len(t, responses, 1) + require.False(t, responses[0].InBandInterrupt) + require.False(t, responses[0].OutOfBandInterrupt) + }, + }, + { + name: "pre_eval : disable inband by name", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Apply: []string{"RemoveInBandRuleByName('rulez')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Len(t, responses, 1) + require.False(t, responses[0].InBandInterrupt) + require.False(t, responses[0].OutOfBandInterrupt) + }, + }, + { + name: "pre_eval : outofband default behavior", + expected_load_ok: true, + outofband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 1) + require.Equal(t, types.LOG, events[0].Type) + require.True(t, events[0].Appsec.HasOutBandMatches) + require.False(t, events[0].Appsec.HasInBandMatches) + require.Len(t, events[0].Appsec.MatchedRules, 1) + require.Equal(t, "rulez", events[0].Appsec.MatchedRules[0]["msg"]) + //maybe surprising, but response won't mention OOB event, as it's sent as soon as the inband phase is over. 
+				require.Len(t, responses, 1)
+				require.False(t, responses[0].InBandInterrupt)
+				require.False(t, responses[0].OutOfBandInterrupt)
+			},
+		},
+		{
+			name:             "pre_eval: set remediation by tag",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rulez",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			pre_eval: []appsec.Hook{
+				{Apply: []string{"SetRemediationByTag('crowdsec-rulez', 'foobar')"}},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/urllll",
+				Args:       url.Values{"foo": []string{"toto"}},
+			},
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Len(t, events, 2)
+				require.Len(t, responses, 1)
+				require.Equal(t, "foobar", responses[0].Action)
+			},
+		},
+		{
+			name:             "pre_eval: set remediation by name",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rulez",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			pre_eval: []appsec.Hook{
+				{Apply: []string{"SetRemediationByName('rulez', 'foobar')"}},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/urllll",
+				Args:       url.Values{"foo": []string{"toto"}},
+			},
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Len(t, events, 2)
+				require.Len(t, responses, 1)
+				require.Equal(t, "foobar", responses[0].Action)
+			},
+		},
+		{
+			name:             "pre_eval: set remediation by ID",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rulez",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			pre_eval: []appsec.Hook{
+				{Apply: []string{"SetRemediationByID(1516470898, 'foobar')"}}, // the rule ID is generated at runtime: changing the rule will break this test
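+				// unlike the Remove* hooks above, the Set* hooks leave the rule
+				// active: it still matches and emits events, only the resulting
+				// remediation changes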
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/urllll",
+				Args:       url.Values{"foo": []string{"toto"}},
+			},
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Len(t, events, 2)
+				require.Len(t, responses, 1)
+				require.Equal(t, "foobar", responses[0].Action)
+				require.Equal(t, "foobar", appsecResponse.Action)
+				require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus)
+			},
+		},
+		{
+			name:             "pre_eval: on_success continue",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rulez",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			pre_eval: []appsec.Hook{
+				{Filter: "1==1", Apply: []string{"SetRemediationByName('rulez', 'foobar')"}, OnSuccess: "continue"},
+				{Filter: "1==1", Apply: []string{"SetRemediationByName('rulez', 'foobar2')"}},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/urllll",
+				Args:       url.Values{"foo": []string{"toto"}},
+			},
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Len(t, events, 2)
+				require.Len(t, responses, 1)
+				require.Equal(t, "foobar2", responses[0].Action)
+			},
+		},
+		{
+			name:             "pre_eval: on_success break",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rulez",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			pre_eval: []appsec.Hook{
+				{Filter: "1==1", Apply: []string{"SetRemediationByName('rulez', 'foobar')"}, OnSuccess: "break"},
+				{Filter: "1==1", Apply: []string{"SetRemediationByName('rulez', 'foobar2')"}},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/urllll",
+				Args:       url.Values{"foo": []string{"toto"}},
+			},
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Len(t, events, 2)
+				require.Len(t, responses, 1)
+				require.Equal(t, "foobar", responses[0].Action)
+			},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			loadAppSecEngine(test, t)
+		})
+	}
+}
+
+func TestAppsecRemediationConfigHooks(t *testing.T) {
+	tests := []appsecRuleTest{
+		{
+			name:             "Basic matching rule",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rule1",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/urllll",
+				Args:       url.Values{"foo": []string{"toto"}},
+			},
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Equal(t, appsec.BanRemediation, responses[0].Action)
+				require.Equal(t, http.StatusForbidden, statusCode)
+				require.Equal(t, appsec.BanRemediation, appsecResponse.Action)
+				require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus)
+			},
+		},
+		{
+			name:             "SetRemediation",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + on_match: []appsec.Hook{{Apply: []string{"SetRemediation('captcha')"}}}, //rule ID is generated at runtime. If you change rule, it will break the test (: + + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "SetRemediation", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + on_match: []appsec.Hook{{Apply: []string{"SetReturnCode(418)"}}}, //rule ID is generated at runtime. If you change rule, it will break the test (: + + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} +func TestOnMatchRemediationHooks(t *testing.T) { + tests := []appsecRuleTest{ + { + name: "set remediation to allow with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('allow')"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "set remediation to captcha + custom user code with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: appsec.AllowRemediation, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')", 
"SetReturnCode(418)"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + spew.Dump(responses) + spew.Dump(appsecResponse) + + log.Errorf("http status : %d", statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + require.Equal(t, http.StatusForbidden, statusCode) + }, + }, + { + name: "on_match: on_success break", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: appsec.AllowRemediation, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')", "SetReturnCode(418)"}, OnSuccess: "break"}, + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('ban')"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + spew.Dump(responses) + spew.Dump(appsecResponse) + + log.Errorf("http status : %d", statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + require.Equal(t, http.StatusForbidden, statusCode) + }, + }, + { + name: "on_match: on_success continue", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: appsec.AllowRemediation, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')", "SetReturnCode(418)"}, OnSuccess: "continue"}, + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('ban')"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + spew.Dump(responses) + spew.Dump(appsecResponse) + + log.Errorf("http status : %d", statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + require.Equal(t, http.StatusForbidden, statusCode) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_lnx_test.go b/pkg/acquisition/modules/appsec/appsec_lnx_test.go new file mode 100644 index 00000000000..61dfc536f5e --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_lnx_test.go @@ -0,0 +1,74 @@ +//go:build !windows + +package appsecacquisition + +import ( + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func TestAppsecRuleTransformsOthers(t *testing.T) { + log.SetLevel(log.TraceLevel) + + tests := 
[]appsecRuleTest{ + { + name: "normalizepath", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "equals", Value: "b/c"}, + Transform: []string{"normalizepath"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=a/../b/c", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "normalizepath #2", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "equals", Value: "b/c/"}, + Transform: []string{"normalizepath"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=a/../b/c/////././././", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_remediation_test.go b/pkg/acquisition/modules/appsec/appsec_remediation_test.go new file mode 100644 index 00000000000..06016b6251f --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_remediation_test.go @@ -0,0 +1,319 @@ +package appsecacquisition + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func TestAppsecDefaultPassRemediation(t *testing.T) { + tests := []appsecRuleTest{ + { + name: "Basic non-matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "DefaultPassAction: pass", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + DefaultPassAction: "allow", + 
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Equal(t, appsec.AllowRemediation, responses[0].Action)
+				require.Equal(t, http.StatusOK, statusCode)
+				require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
+				require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus)
+			},
+		},
+		{
+			name:             "DefaultPassAction: captcha",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rule1",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/",
+				Args:       url.Values{"foo": []string{"tutu"}},
+			},
+			DefaultPassAction: "captcha",
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Equal(t, appsec.CaptchaRemediation, responses[0].Action)
+				require.Equal(t, http.StatusOK, statusCode) // @tko: the body is a captcha, but since the status is 200 the captcha won't be shown to the user
+				require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action)
+				require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus)
+			},
+		},
+		{
+			name:             "DefaultPassHTTPCode: 200",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rule1",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/",
+				Args:       url.Values{"foo": []string{"tutu"}},
+			},
+			UserPassedHTTPCode: 200,
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Equal(t, appsec.AllowRemediation, responses[0].Action)
+				require.Equal(t, http.StatusOK, statusCode)
+				require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
+				require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus)
+			},
+		},
+		{
+			name:             "DefaultPassHTTPCode: 418",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rule1",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr: "1.2.3.4",
+				Method:     "GET",
+				URI:        "/",
+				Args:       url.Values{"foo": []string{"tutu"}},
+			},
+			UserPassedHTTPCode: 418,
+			output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+				require.Equal(t, appsec.AllowRemediation, responses[0].Action)
+				require.Equal(t, http.StatusOK, statusCode)
+				require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
+				require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus)
+			},
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			loadAppSecEngine(test, t)
+		})
+	}
+}
+
+func TestAppsecDefaultRemediation(t *testing.T) {
+	tests := []appsecRuleTest{
+		{
+			name:             "Basic matching rule",
+			expected_load_ok: true,
+			inband_rules: []appsec_rule.CustomRule{
+				{
+					Name:      "rule1",
+					Zones:     []string{"ARGS"},
+					Variables: []string{"foo"},
+					Match:     appsec_rule.Match{Type: "regex", Value: "^toto"},
+					Transform: []string{"lowercase"},
+				},
+			},
+			input_request: appsec.ParsedRequest{
+				RemoteAddr:
"1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to ban (default)", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "ban", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to allow", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "allow", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to captcha", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "captcha", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "custom user HTTP code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + 
}, + UserBlockedHTTPCode: 418, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + { + name: "custom remediation + HTTP code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + UserBlockedHTTPCode: 418, + DefaultRemediation: "foobar", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, "foobar", responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, "foobar", appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_rules_test.go b/pkg/acquisition/modules/appsec/appsec_rules_test.go new file mode 100644 index 00000000000..909f16357ed --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_rules_test.go @@ -0,0 +1,859 @@ +package appsecacquisition + +import ( + "net/http" + "net/url" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func TestAppsecRuleMatches(t *testing.T) { + tests := []appsecRuleTest{ + { + name: "Basic matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + { + name: "Basic non-matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"tutu"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, 
appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Len(t, responses, 1) + require.False(t, responses[0].InBandInterrupt) + require.False(t, responses[0].OutOfBandInterrupt) + }, + }, + { + name: "default remediation to allow", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "allow", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to captcha", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "captcha", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "no default remediation / custom user HTTP code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + UserBlockedHTTPCode: 418, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + { + name: "no match but try to set remediation to captcha with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"bla"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse 
appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + }, + }, + { + name: "no match but try to set user HTTP code with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetReturnCode(418)"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"bla"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + }, + }, + { + name: "no match but try to set remediation with pre_eval hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediationByName('rule42', 'captcha')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"bla"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + }, + }, + { + name: "Basic matching in cookies", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"COOKIES"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Headers: http.Header{"Cookie": []string{"foo=toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + { + name: "Basic matching in all cookies", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"COOKIES"}, + Match: appsec_rule.Match{Type: "regex", Value: "^tutu"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Headers: http.Header{"Cookie": []string{"foo=toto; bar=tutu"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, 
types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + { + name: "Basic matching in cookie name", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"COOKIES_NAMES"}, + Match: appsec_rule.Match{Type: "regex", Value: "^tutu"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Headers: http.Header{"Cookie": []string{"bar=tutu; tututata=toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + { + name: "Basic matching in multipart file name", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"FILES"}, + Match: appsec_rule.Match{Type: "regex", Value: "\\.php$"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Headers: http.Header{"Content-Type": []string{"multipart/form-data; boundary=boundary"}}, + Body: []byte(` +--boundary +Content-Disposition: form-data; name="foo"; filename="bar.php" +Content-Type: application/octet-stream + +toto +--boundary--`), + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} + +func TestAppsecRuleTransforms(t *testing.T) { + log.SetLevel(log.TraceLevel) + tests := []appsecRuleTest{ + { + name: "Basic matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"URI"}, + Match: appsec_rule.Match{Type: "equals", Value: "/toto"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/toto", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "lowercase", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"URI"}, + Match: appsec_rule.Match{Type: "equals", Value: 
"/toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/TOTO", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "uppercase", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"URI"}, + Match: appsec_rule.Match{Type: "equals", Value: "/TOTO"}, + Transform: []string{"uppercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/toto", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "b64decode", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + Transform: []string{"b64decode"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=dG90bw", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "b64decode with extra padding", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + Transform: []string{"b64decode"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=dG90bw===", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "length", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "gte", Value: "3"}, + Transform: []string{"length"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=toto", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "urldecode", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: 
appsec_rule.Match{Type: "equals", Value: "BB/A"}, + Transform: []string{"urldecode"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=%42%42%2F%41", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "trim", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "equals", Value: "BB/A"}, + Transform: []string{"urldecode", "trim"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=%20%20%42%42%2F%41%20%20", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} + +func TestAppsecRuleZones(t *testing.T) { + log.SetLevel(log.TraceLevel) + tests := []appsecRuleTest{ + { + name: "rule: ARGS", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/foobar?something=toto&foobar=smth", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: ARGS_NAMES", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"ARGS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/foobar?something=toto&foobar=smth", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule2", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: BODY_ARGS", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"BODY_ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"BODY_ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: 
"/", + Body: []byte("smth=toto&foobar=other"), + Headers: http.Header{"Content-Type": []string{"application/x-www-form-urlencoded"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: BODY_ARGS_NAMES", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"BODY_ARGS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"BODY_ARGS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Body: []byte("smth=toto&foobar=other"), + Headers: http.Header{"Content-Type": []string{"application/x-www-form-urlencoded"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule2", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: HEADERS", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"HEADERS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"HEADERS"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Headers: http.Header{"foobar": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: HEADERS_NAMES", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"HEADERS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"HEADERS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Headers: http.Header{"foobar": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule2", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: METHOD", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"METHOD"}, + Match: appsec_rule.Match{Type: "equals", Value: "GET"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + 
require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: PROTOCOL", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"PROTOCOL"}, + Match: appsec_rule.Match{Type: "contains", Value: "3.1"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Proto: "HTTP/3.1", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: URI", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"URI"}, + Match: appsec_rule.Match{Type: "equals", Value: "/foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/foobar", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: URI_FULL", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"URI_FULL"}, + Match: appsec_rule.Match{Type: "equals", Value: "/foobar?a=b"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/foobar?a=b", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: RAW_BODY", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"RAW_BODY"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar=42421"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Body: []byte("foobar=42421"), + Headers: http.Header{"Content-Type": []string{"application/x-www-form-urlencoded"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_runner.go b/pkg/acquisition/modules/appsec/appsec_runner.go new file mode 100644 index 00000000000..de34b62d704 --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_runner.go @@ -0,0 +1,380 @@ +package appsecacquisition + +import ( + "fmt" + "os" + "slices" + "time" + + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/coraza/v3" + 
corazatypes "github.com/crowdsecurity/coraza/v3/types" + + // load body processors via init() + _ "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/appsec/bodyprocessors" + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +// that's the runtime structure of the Application security engine as seen from the acquis +type AppsecRunner struct { + outChan chan types.Event + inChan chan appsec.ParsedRequest + UUID string + AppsecRuntime *appsec.AppsecRuntimeConfig //this holds the actual appsec runtime config, rules, remediations, hooks etc. + AppsecInbandEngine coraza.WAF + AppsecOutbandEngine coraza.WAF + Labels map[string]string + logger *log.Entry +} + +func (r *AppsecRunner) Init(datadir string) error { + var err error + fs := os.DirFS(datadir) + + inBandRules := "" + outOfBandRules := "" + + for _, collection := range r.AppsecRuntime.InBandRules { + inBandRules += collection.String() + } + + for _, collection := range r.AppsecRuntime.OutOfBandRules { + outOfBandRules += collection.String() + } + inBandLogger := r.logger.Dup().WithField("band", "inband") + outBandLogger := r.logger.Dup().WithField("band", "outband") + + //setting up inband engine + inbandCfg := coraza.NewWAFConfig().WithDirectives(inBandRules).WithRootFS(fs).WithDebugLogger(appsec.NewCrzLogger(inBandLogger)) + if !r.AppsecRuntime.Config.InbandOptions.DisableBodyInspection { + inbandCfg = inbandCfg.WithRequestBodyAccess() + } else { + log.Warningf("Disabling body inspection, Inband rules will not be able to match on body's content.") + } + if r.AppsecRuntime.Config.InbandOptions.RequestBodyInMemoryLimit != nil { + inbandCfg = inbandCfg.WithRequestBodyInMemoryLimit(*r.AppsecRuntime.Config.InbandOptions.RequestBodyInMemoryLimit) + } + r.AppsecInbandEngine, err = coraza.NewWAF(inbandCfg) + if err != nil { + return fmt.Errorf("unable to initialize inband engine : %w", err) + } + + //setting up outband engine + outbandCfg := coraza.NewWAFConfig().WithDirectives(outOfBandRules).WithRootFS(fs).WithDebugLogger(appsec.NewCrzLogger(outBandLogger)) + if !r.AppsecRuntime.Config.OutOfBandOptions.DisableBodyInspection { + outbandCfg = outbandCfg.WithRequestBodyAccess() + } else { + log.Warningf("Disabling body inspection, Out of band rules will not be able to match on body's content.") + } + if r.AppsecRuntime.Config.OutOfBandOptions.RequestBodyInMemoryLimit != nil { + outbandCfg = outbandCfg.WithRequestBodyInMemoryLimit(*r.AppsecRuntime.Config.OutOfBandOptions.RequestBodyInMemoryLimit) + } + r.AppsecOutbandEngine, err = coraza.NewWAF(outbandCfg) + + if r.AppsecRuntime.DisabledInBandRulesTags != nil { + for _, tag := range r.AppsecRuntime.DisabledInBandRulesTags { + r.AppsecInbandEngine.GetRuleGroup().DeleteByTag(tag) + } + } + + if r.AppsecRuntime.DisabledOutOfBandRulesTags != nil { + for _, tag := range r.AppsecRuntime.DisabledOutOfBandRulesTags { + r.AppsecOutbandEngine.GetRuleGroup().DeleteByTag(tag) + } + } + + if r.AppsecRuntime.DisabledInBandRuleIds != nil { + for _, id := range r.AppsecRuntime.DisabledInBandRuleIds { + r.AppsecInbandEngine.GetRuleGroup().DeleteByID(id) + } + } + + if r.AppsecRuntime.DisabledOutOfBandRuleIds != nil { + for _, id := range r.AppsecRuntime.DisabledOutOfBandRuleIds { + r.AppsecOutbandEngine.GetRuleGroup().DeleteByID(id) + } + } + + r.logger.Tracef("Loaded inband rules: %+v", r.AppsecInbandEngine.GetRuleGroup().GetRules()) + r.logger.Tracef("Loaded outband rules: %+v", r.AppsecOutbandEngine.GetRuleGroup().GetRules()) + + if err != nil { + return 
fmt.Errorf("unable to initialize outband engine : %w", err) + } + + return nil +} + +func (r *AppsecRunner) processRequest(tx appsec.ExtendedTransaction, request *appsec.ParsedRequest) error { + var in *corazatypes.Interruption + var err error + + if request.Tx.IsRuleEngineOff() { + r.logger.Debugf("rule engine is off, skipping") + return nil + } + + defer func() { + request.Tx.ProcessLogging() + //We don't close the transaction here, as it will reset coraza internal state and break variable tracking + + err := r.AppsecRuntime.ProcessPostEvalRules(request) + if err != nil { + r.logger.Errorf("unable to process PostEval rules: %s", err) + } + }() + + //pre eval (expr) rules + err = r.AppsecRuntime.ProcessPreEvalRules(request) + if err != nil { + r.logger.Errorf("unable to process PreEval rules: %s", err) + //FIXME: should we abort here ? + } + + request.Tx.ProcessConnection(request.RemoteAddr, 0, "", 0) + + for k, v := range request.Args { + for _, vv := range v { + request.Tx.AddGetRequestArgument(k, vv) + } + } + + request.Tx.ProcessURI(request.URI, request.Method, request.Proto) + + for k, vr := range request.Headers { + for _, v := range vr { + request.Tx.AddRequestHeader(k, v) + } + } + + if request.ClientHost != "" { + request.Tx.AddRequestHeader("Host", request.ClientHost) + request.Tx.SetServerName(request.ClientHost) + } + + if request.TransferEncoding != nil { + request.Tx.AddRequestHeader("Transfer-Encoding", request.TransferEncoding[0]) + } + + in = request.Tx.ProcessRequestHeaders() + + if in != nil { + r.logger.Infof("inband rules matched for headers : %s", in.Action) + return nil + } + + if len(request.Body) > 0 { + in, _, err = request.Tx.WriteRequestBody(request.Body) + if err != nil { + r.logger.Errorf("unable to write request body : %s", err) + return err + } + if in != nil { + return nil + } + } + + in, err = request.Tx.ProcessRequestBody() + if err != nil { + r.logger.Errorf("unable to process request body : %s", err) + return err + } + + if in != nil { + r.logger.Debugf("rules matched for body : %d", in.RuleID) + } + + return nil +} + +func (r *AppsecRunner) ProcessInBandRules(request *appsec.ParsedRequest) error { + tx := appsec.NewExtendedTransaction(r.AppsecInbandEngine, request.UUID) + r.AppsecRuntime.InBandTx = tx + request.Tx = tx + if len(r.AppsecRuntime.InBandRules) == 0 { + return nil + } + err := r.processRequest(tx, request) + return err +} + +func (r *AppsecRunner) ProcessOutOfBandRules(request *appsec.ParsedRequest) error { + tx := appsec.NewExtendedTransaction(r.AppsecOutbandEngine, request.UUID) + r.AppsecRuntime.OutOfBandTx = tx + request.Tx = tx + if len(r.AppsecRuntime.OutOfBandRules) == 0 { + return nil + } + err := r.processRequest(tx, request) + return err +} + +func (r *AppsecRunner) handleInBandInterrupt(request *appsec.ParsedRequest) { + //create the associated event for crowdsec itself + evt, err := EventFromRequest(request, r.Labels) + if err != nil { + //let's not interrupt the pipeline for this + r.logger.Errorf("unable to create event from request : %s", err) + } + err = r.AccumulateTxToEvent(&evt, request) + if err != nil { + r.logger.Errorf("unable to accumulate tx to event : %s", err) + } + if in := request.Tx.Interruption(); in != nil { + r.logger.Debugf("inband rules matched : %d", in.RuleID) + r.AppsecRuntime.Response.InBandInterrupt = true + r.AppsecRuntime.Response.BouncerHTTPResponseCode = r.AppsecRuntime.Config.BouncerBlockedHTTPCode + r.AppsecRuntime.Response.UserHTTPResponseCode = r.AppsecRuntime.Config.UserBlockedHTTPCode + 
r.AppsecRuntime.Response.Action = r.AppsecRuntime.DefaultRemediation + + if _, ok := r.AppsecRuntime.RemediationById[in.RuleID]; ok { + r.AppsecRuntime.Response.Action = r.AppsecRuntime.RemediationById[in.RuleID] + } + + for tag, remediation := range r.AppsecRuntime.RemediationByTag { + if slices.Contains[[]string, string](in.Tags, tag) { + r.AppsecRuntime.Response.Action = remediation + } + } + + err = r.AppsecRuntime.ProcessOnMatchRules(request, evt) + if err != nil { + r.logger.Errorf("unable to process OnMatch rules: %s", err) + return + } + + // Should the in band match trigger an overflow ? + if r.AppsecRuntime.Response.SendAlert { + appsecOvlfw, err := AppsecEventGeneration(evt) + if err != nil { + r.logger.Errorf("unable to generate appsec event : %s", err) + return + } + if appsecOvlfw != nil { + r.outChan <- *appsecOvlfw + } + } + + // Should the in band match trigger an event ? + if r.AppsecRuntime.Response.SendEvent { + r.outChan <- evt + } + + } +} + +func (r *AppsecRunner) handleOutBandInterrupt(request *appsec.ParsedRequest) { + evt, err := EventFromRequest(request, r.Labels) + if err != nil { + //let's not interrupt the pipeline for this + r.logger.Errorf("unable to create event from request : %s", err) + } + err = r.AccumulateTxToEvent(&evt, request) + if err != nil { + r.logger.Errorf("unable to accumulate tx to event : %s", err) + } + if in := request.Tx.Interruption(); in != nil { + r.logger.Debugf("outband rules matched : %d", in.RuleID) + r.AppsecRuntime.Response.OutOfBandInterrupt = true + + err = r.AppsecRuntime.ProcessOnMatchRules(request, evt) + if err != nil { + r.logger.Errorf("unable to process OnMatch rules: %s", err) + return + } + // Should the match trigger an event ? + if r.AppsecRuntime.Response.SendEvent { + r.outChan <- evt + } + + // Should the match trigger an overflow ? + if r.AppsecRuntime.Response.SendAlert { + appsecOvlfw, err := AppsecEventGeneration(evt) + if err != nil { + r.logger.Errorf("unable to generate appsec event : %s", err) + return + } + r.outChan <- *appsecOvlfw + } + } +} + +func (r *AppsecRunner) handleRequest(request *appsec.ParsedRequest) { + r.AppsecRuntime.Logger = r.AppsecRuntime.Logger.WithField("request_uuid", request.UUID) + logger := r.logger.WithField("request_uuid", request.UUID) + logger.Debug("Request received in runner") + r.AppsecRuntime.ClearResponse() + + request.IsInBand = true + request.IsOutBand = false + + //to measure the time spent in the Application Security Engine for InBand rules + startInBandParsing := time.Now() + startGlobalParsing := time.Now() + + //inband appsec rules + err := r.ProcessInBandRules(request) + if err != nil { + logger.Errorf("unable to process InBand rules: %s", err) + return + } + + // time spent to process in band rules + inBandParsingElapsed := time.Since(startInBandParsing) + AppsecInbandParsingHistogram.With(prometheus.Labels{"source": request.RemoteAddrNormalized, "appsec_engine": request.AppsecEngine}).Observe(inBandParsingElapsed.Seconds()) + + if request.Tx.IsInterrupted() { + r.handleInBandInterrupt(request) + } + + // send back the result to the HTTP handler for the InBand part + request.ResponseChannel <- r.AppsecRuntime.Response + + //Now let's process the out of band rules + + request.IsInBand = false + request.IsOutBand = true + r.AppsecRuntime.Response.SendAlert = false + r.AppsecRuntime.Response.SendEvent = true + + //FIXME: This is a bit of a hack to avoid confusion with the transaction if we do not have any inband rules. 
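+ //(as written, ProcessOutOfBandRules swaps request.Tx before checking whether any out-of-band rules are loaded, which would clobber the in-band transaction)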
+ //We should probably have different transaction (or even different request object) for inband and out of band rules + if len(r.AppsecRuntime.OutOfBandRules) > 0 { + //to measure the time spent in the Application Security Engine for OutOfBand rules + startOutOfBandParsing := time.Now() + + err = r.ProcessOutOfBandRules(request) + if err != nil { + logger.Errorf("unable to process OutOfBand rules: %s", err) + return + } + + // time spent to process out of band rules + outOfBandParsingElapsed := time.Since(startOutOfBandParsing) + AppsecOutbandParsingHistogram.With(prometheus.Labels{"source": request.RemoteAddrNormalized, "appsec_engine": request.AppsecEngine}).Observe(outOfBandParsingElapsed.Seconds()) + if request.Tx.IsInterrupted() { + r.handleOutBandInterrupt(request) + } + } + // time spent to process inband AND out of band rules + globalParsingElapsed := time.Since(startGlobalParsing) + AppsecGlobalParsingHistogram.With(prometheus.Labels{"source": request.RemoteAddrNormalized, "appsec_engine": request.AppsecEngine}).Observe(globalParsingElapsed.Seconds()) + +} + +func (r *AppsecRunner) Run(t *tomb.Tomb) error { + r.logger.Infof("Appsec Runner ready to process event") + for { + select { + case <-t.Dying(): + r.logger.Infof("Appsec Runner is dying") + return nil + case request := <-r.inChan: + r.handleRequest(&request) + } + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_test.go b/pkg/acquisition/modules/appsec/appsec_test.go new file mode 100644 index 00000000000..d2079b43726 --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_test.go @@ -0,0 +1,124 @@ +package appsecacquisition + +import ( + "testing" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type appsecRuleTest struct { + name string + expected_load_ok bool + inband_rules []appsec_rule.CustomRule + outofband_rules []appsec_rule.CustomRule + on_load []appsec.Hook + pre_eval []appsec.Hook + post_eval []appsec.Hook + on_match []appsec.Hook + BouncerBlockedHTTPCode int + UserBlockedHTTPCode int + UserPassedHTTPCode int + DefaultRemediation string + DefaultPassAction string + input_request appsec.ParsedRequest + output_asserts func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) +} + +func loadAppSecEngine(test appsecRuleTest, t *testing.T) { + if testing.Verbose() { + log.SetLevel(log.TraceLevel) + } else { + log.SetLevel(log.WarnLevel) + } + inbandRules := []string{} + outofbandRules := []string{} + InChan := make(chan appsec.ParsedRequest) + OutChan := make(chan types.Event) + + logger := log.WithField("test", test.name) + + //build rules + for ridx, rule := range test.inband_rules { + strRule, _, err := rule.Convert(appsec_rule.ModsecurityRuleType, rule.Name) + if err != nil { + t.Fatalf("failed compilation of rule %d/%d of %s : %s", ridx, len(test.inband_rules), test.name, err) + } + inbandRules = append(inbandRules, strRule) + + } + for ridx, rule := range test.outofband_rules { + strRule, _, err := rule.Convert(appsec_rule.ModsecurityRuleType, rule.Name) + if err != nil { + t.Fatalf("failed compilation of rule %d/%d of %s : %s", ridx, len(test.outofband_rules), test.name, err) + } + outofbandRules = append(outofbandRules, strRule) + } + + appsecCfg := appsec.AppsecConfig{Logger: logger, + OnLoad: test.on_load, + PreEval: 
test.pre_eval, + PostEval: test.post_eval, + OnMatch: test.on_match, + BouncerBlockedHTTPCode: test.BouncerBlockedHTTPCode, + UserBlockedHTTPCode: test.UserBlockedHTTPCode, + UserPassedHTTPCode: test.UserPassedHTTPCode, + DefaultRemediation: test.DefaultRemediation, + DefaultPassAction: test.DefaultPassAction} + AppsecRuntime, err := appsecCfg.Build() + if err != nil { + t.Fatalf("unable to build appsec runtime : %s", err) + } + AppsecRuntime.InBandRules = []appsec.AppsecCollection{{Rules: inbandRules}} + AppsecRuntime.OutOfBandRules = []appsec.AppsecCollection{{Rules: outofbandRules}} + appsecRunnerUUID := uuid.New().String() + //we copy AppsecRuntime for each runner + wrt := *AppsecRuntime + wrt.Logger = logger + runner := AppsecRunner{ + inChan: InChan, + UUID: appsecRunnerUUID, + logger: logger, + AppsecRuntime: &wrt, + Labels: map[string]string{"foo": "bar"}, + outChan: OutChan, + } + err = runner.Init("/tmp/") + if err != nil { + t.Fatalf("unable to initialize runner : %s", err) + } + + input := test.input_request + input.ResponseChannel = make(chan appsec.AppsecTempResponse) + OutputEvents := make([]types.Event, 0) + OutputResponses := make([]appsec.AppsecTempResponse, 0) + go func() { + for { + //log.Printf("reading from %p", input.ResponseChannel) + out := <-input.ResponseChannel + OutputResponses = append(OutputResponses, out) + //log.Errorf("response -> %s", spew.Sdump(out)) + } + }() + go func() { + for { + out := <-OutChan + OutputEvents = append(OutputEvents, out) + //log.Errorf("outchan -> %s", spew.Sdump(out)) + } + }() + + runner.handleRequest(&input) + time.Sleep(50 * time.Millisecond) // give the collector goroutines time to drain both channels + + http_status, appsecResponse := AppsecRuntime.GenerateResponse(OutputResponses[0], logger) + log.Infof("events : %s", spew.Sdump(OutputEvents)) + log.Infof("responses : %s", spew.Sdump(OutputResponses)) + test.output_asserts(OutputEvents, OutputResponses, appsecResponse, http_status) +} diff --git a/pkg/acquisition/modules/appsec/appsec_win_test.go b/pkg/acquisition/modules/appsec/appsec_win_test.go new file mode 100644 index 00000000000..a6b8f3a0340 --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_win_test.go @@ -0,0 +1,45 @@ +//go:build windows + +package appsecacquisition + +import ( + "testing" + + log "github.com/sirupsen/logrus" +) + +func TestAppsecRuleTransformsWindows(t *testing.T) { + + log.SetLevel(log.TraceLevel) + tests := []appsecRuleTest{ + // { + // name: "normalizepath", + // expected_load_ok: true, + // inband_rules: []appsec_rule.CustomRule{ + // { + // Name: "rule1", + // Zones: []string{"ARGS"}, + // Variables: []string{"foo"}, + // Match: appsec_rule.Match{Type: "equals", Value: "b/c"}, + // Transform: []string{"normalizepath"}, + // }, + // }, + // input_request: appsec.ParsedRequest{ + // RemoteAddr: "1.2.3.4", + // Method: "GET", + // URI: "/?foo=a/../b/c", + // }, + // output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + // require.Len(t, events, 2) + // require.Equal(t, types.APPSEC, events[0].Type) + // require.Equal(t, types.LOG, events[1].Type) + // require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + // }, + // }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/bodyprocessors/raw.go b/pkg/acquisition/modules/appsec/bodyprocessors/raw.go new file mode 100644 index 00000000000..e2e23eb57ae --- /dev/null +++ 
b/pkg/acquisition/modules/appsec/bodyprocessors/raw.go @@ -0,0 +1,45 @@ +package bodyprocessors + +import ( + "io" + "strconv" + "strings" + + "github.com/crowdsecurity/coraza/v3/experimental/plugins" + "github.com/crowdsecurity/coraza/v3/experimental/plugins/plugintypes" +) + +type rawBodyProcessor struct { +} + +type setterInterface interface { + Set(string) +} + +func (*rawBodyProcessor) ProcessRequest(reader io.Reader, v plugintypes.TransactionVariables, options plugintypes.BodyProcessorOptions) error { + buf := new(strings.Builder) + if _, err := io.Copy(buf, reader); err != nil { + return err + } + + b := buf.String() + + v.RequestBody().(setterInterface).Set(b) + v.RequestBodyLength().(setterInterface).Set(strconv.Itoa(len(b))) + return nil +} + +func (*rawBodyProcessor) ProcessResponse(reader io.Reader, v plugintypes.TransactionVariables, options plugintypes.BodyProcessorOptions) error { + return nil +} + +var ( + _ plugintypes.BodyProcessor = &rawBodyProcessor{} +) + +//nolint:gochecknoinits //Coraza recommends using init() for registering plugins +func init() { + plugins.RegisterBodyProcessor("raw", func() plugintypes.BodyProcessor { + return &rawBodyProcessor{} + }) +} diff --git a/pkg/acquisition/modules/appsec/metrics.go b/pkg/acquisition/modules/appsec/metrics.go new file mode 100644 index 00000000000..13275933836 --- /dev/null +++ b/pkg/acquisition/modules/appsec/metrics.go @@ -0,0 +1,54 @@ +package appsecacquisition + +import "github.com/prometheus/client_golang/prometheus" + +var AppsecGlobalParsingHistogram = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Help: "Time spent processing a request by the Application Security Engine.", + Name: "cs_appsec_parsing_time_seconds", + Buckets: []float64{0.0001, 0.00025, 0.0005, 0.001, 0.0025, 0.0050, 0.01, 0.025, 0.05, 0.1, 0.25}, + }, + []string{"source", "appsec_engine"}, +) + +var AppsecInbandParsingHistogram = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Help: "Time spent processing a request by the inband Application Security Engine.", + Name: "cs_appsec_inband_parsing_time_seconds", + Buckets: []float64{0.0001, 0.00025, 0.0005, 0.001, 0.0025, 0.0050, 0.01, 0.025, 0.05, 0.1, 0.25}, + }, + []string{"source", "appsec_engine"}, +) + +var AppsecOutbandParsingHistogram = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Help: "Time spent processing a request by the out-of-band Application Security Engine.", + Name: "cs_appsec_outband_parsing_time_seconds", + Buckets: []float64{0.0001, 0.00025, 0.0005, 0.001, 0.0025, 0.0050, 0.01, 0.025, 0.05, 0.1, 0.25}, + }, + []string{"source", "appsec_engine"}, +) + +var AppsecReqCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_appsec_reqs_total", + Help: "Total events processed by the Application Security Engine.", + }, + []string{"source", "appsec_engine"}, +) + +var AppsecBlockCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_appsec_block_total", + Help: "Total events blocked by the Application Security Engine.", + }, + []string{"source", "appsec_engine"}, +) + +var AppsecRuleHits = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_appsec_rule_hits", + Help: "Count of triggered rules, by rule_name, type (inband/outofband), appsec_engine and source", + }, + []string{"rule_name", "type", "appsec_engine", "source"}, +) diff --git a/pkg/acquisition/modules/appsec/rx_operator.go b/pkg/acquisition/modules/appsec/rx_operator.go new file mode 100644 index 00000000000..4b16296fd40 --- /dev/null +++ 
b/pkg/acquisition/modules/appsec/rx_operator.go @@ -0,0 +1,95 @@ +package appsecacquisition + +import ( + "fmt" + "strconv" + "unicode/utf8" + + "github.com/wasilibs/go-re2" + "github.com/wasilibs/go-re2/experimental" + + "github.com/crowdsecurity/coraza/v3/experimental/plugins" + "github.com/crowdsecurity/coraza/v3/experimental/plugins/plugintypes" +) + +type rx struct { + re *re2.Regexp +} + +var _ plugintypes.Operator = (*rx)(nil) + +func newRX(options plugintypes.OperatorOptions) (plugintypes.Operator, error) { + // (?sm) enables multiline mode which makes 942522-7 work, see + // - https://stackoverflow.com/a/27680233 + // - https://groups.google.com/g/golang-nuts/c/jiVdamGFU9E + data := fmt.Sprintf("(?sm)%s", options.Arguments) + + var re *re2.Regexp + var err error + + if matchesArbitraryBytes(data) { + re, err = experimental.CompileLatin1(data) + } else { + re, err = re2.Compile(data) + } + if err != nil { + return nil, err + } + return &rx{re: re}, nil +} + +func (o *rx) Evaluate(tx plugintypes.TransactionState, value string) bool { + if tx.Capturing() { + match := o.re.FindStringSubmatch(value) + if len(match) == 0 { + return false + } + for i, c := range match { + if i == 9 { + return true + } + tx.CaptureField(i, c) + } + return true + } + + return o.re.MatchString(value) +} + +// RegisterRX registers the rx operator using a WASI implementation instead of Go. +func RegisterRX() { + plugins.RegisterOperator("rx", newRX) +} + +// matchesArbitraryBytes checks for control sequences for byte matches in the expression. +// If the sequences are not valid utf8, it returns true. +func matchesArbitraryBytes(expr string) bool { + decoded := make([]byte, 0, len(expr)) + for i := 0; i < len(expr); i++ { + c := expr[i] + if c != '\\' { + decoded = append(decoded, c) + continue + } + if i+3 >= len(expr) { + decoded = append(decoded, expr[i:]...) + break + } + if expr[i+1] != 'x' { + decoded = append(decoded, expr[i]) + continue + } + + v, mb, _, err := strconv.UnquoteChar(expr[i:], 0) + if err != nil || mb { + // Wasn't a byte escape sequence, shouldn't happen in practice. 
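+ // (e.g. a malformed "\xZZ"): keep the raw backslash byte and keep scanning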
+ decoded = append(decoded, expr[i]) + continue + } + + decoded = append(decoded, byte(v)) + i += 3 + } + + return !utf8.Valid(decoded) +} diff --git a/pkg/acquisition/modules/appsec/utils.go b/pkg/acquisition/modules/appsec/utils.go new file mode 100644 index 00000000000..4fb1a979d14 --- /dev/null +++ b/pkg/acquisition/modules/appsec/utils.go @@ -0,0 +1,386 @@ +package appsecacquisition + +import ( + "fmt" + "net" + "slices" + "strconv" + "time" + + "github.com/oschwald/geoip2-golang" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/coraza/v3/collection" + "github.com/crowdsecurity/coraza/v3/types/variables" + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/pkg/alertcontext" + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +var appsecMetaKeys = []string{ + "id", + "name", + "method", + "uri", + "matched_zones", + "msg", +} + +func appendMeta(meta models.Meta, key string, value string) models.Meta { + if value == "" { + return meta + } + + meta = append(meta, &models.MetaItems0{ + Key: key, + Value: value, + }) + + return meta +} + +func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) { + // if the request didn't trigger inband rules, we don't want to generate an event to LAPI/CAPI + if !inEvt.Appsec.HasInBandMatches { + return nil, nil + } + + evt := types.Event{} + evt.Type = types.APPSEC + evt.Process = true + sourceIP := inEvt.Parsed["source_ip"] + source := models.Source{ + Value: &sourceIP, + IP: sourceIP, + Scope: ptr.Of(types.Ip), + } + + asndata, err := exprhelpers.GeoIPASNEnrich(sourceIP) + + if err != nil { + log.Errorf("Unable to enrich ip '%s' for ASN: %s", sourceIP, err) + } else if asndata != nil { + record := asndata.(*geoip2.ASN) + source.AsName = record.AutonomousSystemOrganization + source.AsNumber = fmt.Sprintf("%d", record.AutonomousSystemNumber) + } + + cityData, err := exprhelpers.GeoIPEnrich(sourceIP) + if err != nil { + log.Errorf("Unable to enrich ip '%s' for geo data: %s", sourceIP, err) + } else if cityData != nil { + record := cityData.(*geoip2.City) + source.Cn = record.Country.IsoCode + source.Latitude = float32(record.Location.Latitude) + source.Longitude = float32(record.Location.Longitude) + } + + rangeData, err := exprhelpers.GeoIPRangeEnrich(sourceIP) + if err != nil { + log.Errorf("Unable to enrich ip '%s' for range: %s", sourceIP, err) + } else if rangeData != nil { + record := rangeData.(*net.IPNet) + source.Range = record.String() + } + + evt.Overflow.Sources = make(map[string]models.Source) + evt.Overflow.Sources[sourceIP] = source + + alert := models.Alert{} + alert.Capacity = ptr.Of(int32(1)) + alert.Events = make([]*models.Event, len(evt.Appsec.GetRuleIDs())) + + now := ptr.Of(time.Now().UTC().Format(time.RFC3339)) + + tmpAppsecContext := make(map[string][]string) + + for _, matched_rule := range inEvt.Appsec.MatchedRules { + evtRule := models.Event{} + + evtRule.Timestamp = now + + evtRule.Meta = make(models.Meta, 0) + + for _, key := range appsecMetaKeys { + if tmpAppsecContext[key] == nil { + tmpAppsecContext[key] = make([]string, 0) + } + + switch value := matched_rule[key].(type) { + case string: + evtRule.Meta = appendMeta(evtRule.Meta, key, value) + + if value != "" && !slices.Contains(tmpAppsecContext[key], value) { + tmpAppsecContext[key] = append(tmpAppsecContext[key], 
value) + } + case int: + val := strconv.Itoa(value) + evtRule.Meta = appendMeta(evtRule.Meta, key, val) + + if val != "" && !slices.Contains(tmpAppsecContext[key], val) { + tmpAppsecContext[key] = append(tmpAppsecContext[key], val) + } + case []string: + for _, v := range value { + evtRule.Meta = appendMeta(evtRule.Meta, key, v) + + if v != "" && !slices.Contains(tmpAppsecContext[key], v) { + tmpAppsecContext[key] = append(tmpAppsecContext[key], v) + } + } + case []int: + for _, v := range value { + val := strconv.Itoa(v) + evtRule.Meta = appendMeta(evtRule.Meta, key, val) + + if val != "" && !slices.Contains(tmpAppsecContext[key], val) { + tmpAppsecContext[key] = append(tmpAppsecContext[key], val) + } + } + default: + val := fmt.Sprintf("%v", value) + evtRule.Meta = appendMeta(evtRule.Meta, key, val) + + if val != "" && !slices.Contains(tmpAppsecContext[key], val) { + tmpAppsecContext[key] = append(tmpAppsecContext[key], val) + } + } + } + + alert.Events = append(alert.Events, &evtRule) + } + + metas := make([]*models.MetaItems0, 0) + + for key, values := range tmpAppsecContext { + if len(values) == 0 { + continue + } + + valueStr, err := alertcontext.TruncateContext(values, alertcontext.MaxContextValueLen) + if err != nil { + log.Warning(err.Error()) + } + + meta := models.MetaItems0{ + Key: key, + Value: valueStr, + } + metas = append(metas, &meta) + } + + alert.Meta = metas + + alert.EventsCount = ptr.Of(int32(len(alert.Events))) + alert.Leakspeed = ptr.Of("") + alert.Scenario = ptr.Of(inEvt.Appsec.MatchedRules.GetName()) + alert.ScenarioHash = ptr.Of(inEvt.Appsec.MatchedRules.GetHash()) + alert.ScenarioVersion = ptr.Of(inEvt.Appsec.MatchedRules.GetVersion()) + alert.Simulated = ptr.Of(false) + alert.Source = &source + msg := fmt.Sprintf("AppSec block: %s from %s (%s)", inEvt.Appsec.MatchedRules.GetName(), + alert.Source.IP, inEvt.Parsed["remediation_cmpt_ip"]) + alert.Message = &msg + alert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) + alert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) + evt.Overflow.APIAlerts = []models.Alert{alert} + evt.Overflow.Alert = &alert + + return &evt, nil +} + +func EventFromRequest(r *appsec.ParsedRequest, labels map[string]string) (types.Event, error) { + evt := types.Event{} + // we might want to change this based on in-band vs out-of-band ? + evt.Type = types.LOG + evt.ExpectMode = types.LIVE + // def needs fixing + evt.Stage = "s00-raw" + evt.Parsed = map[string]string{ + "source_ip": r.ClientIP, + "target_host": r.Host, + "target_uri": r.URI, + "method": r.Method, + "req_uuid": r.Tx.ID(), + "source": "crowdsec-appsec", + "remediation_cmpt_ip": r.RemoteAddrNormalized, + // TBD: + // http_status + // user_agent + + } + evt.Line = types.Line{ + Time: time.Now(), + // should we add some info like listen addr/port/path ? + Labels: labels, + Process: true, + Module: "appsec", + Src: "appsec", + Raw: "dummy-appsec-data", // we discard empty Line.Raw items :) + } + evt.Appsec = types.AppsecEvent{} + + return evt, nil +} + +func LogAppsecEvent(evt *types.Event, logger *log.Entry) { + req := evt.Parsed["target_uri"] + if len(req) > 12 { + req = req[:10] + ".." 
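+ // e.g. "/admin/config.php" is logged as "/admin/con.."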
+ } + + if evt.Meta["appsec_interrupted"] == "true" { + logger.WithFields(log.Fields{ + "module": "appsec", + "source": evt.Parsed["source_ip"], + "target_uri": req, + }).Infof("%s blocked on %s (%d rules) [%v]", evt.Parsed["source_ip"], req, len(evt.Appsec.MatchedRules), evt.Appsec.GetRuleIDs()) + } else if evt.Parsed["outofband_interrupted"] == "true" { + logger.WithFields(log.Fields{ + "module": "appsec", + "source": evt.Parsed["source_ip"], + "target_uri": req, + }).Infof("%s triggered out-of-band blocking rules on %s (%d rules) [%v]", evt.Parsed["source_ip"], req, len(evt.Appsec.MatchedRules), evt.Appsec.GetRuleIDs()) + } else { + logger.WithFields(log.Fields{ + "module": "appsec", + "source": evt.Parsed["source_ip"], + "target_uri": req, + }).Debugf("%s triggered non-blocking rules on %s (%d rules) [%v]", evt.Parsed["source_ip"], req, len(evt.Appsec.MatchedRules), evt.Appsec.GetRuleIDs()) + } +} + +func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedRequest) error { + if evt == nil { + // an error was already emitted, let's not spam the logs + return nil + } + + if !req.Tx.IsInterrupted() { + // if the phase didn't generate an interruption, we don't have anything to add to the event + return nil + } + // if one interruption was generated, event is good for processing :) + evt.Process = true + + if evt.Meta == nil { + evt.Meta = map[string]string{} + } + + if evt.Parsed == nil { + evt.Parsed = map[string]string{} + } + + if req.IsInBand { + evt.Meta["appsec_interrupted"] = "true" + evt.Meta["appsec_action"] = req.Tx.Interruption().Action + evt.Parsed["inband_interrupted"] = "true" + evt.Parsed["inband_action"] = req.Tx.Interruption().Action + } else { + evt.Parsed["outofband_interrupted"] = "true" + evt.Parsed["outofband_action"] = req.Tx.Interruption().Action + } + + if evt.Appsec.Vars == nil { + evt.Appsec.Vars = map[string]string{} + } + + req.Tx.Variables().All(func(v variables.RuleVariable, col collection.Collection) bool { + for _, variable := range col.FindAll() { + key := variable.Variable().Name() + if variable.Key() != "" { + key += "." 
+ variable.Key() + } + + if variable.Value() == "" { + continue + } + + for _, collectionToKeep := range r.AppsecRuntime.CompiledVariablesTracking { + match := collectionToKeep.MatchString(key) + if match { + evt.Appsec.Vars[key] = variable.Value() + r.logger.Debugf("%s.%s = %s", variable.Variable().Name(), variable.Key(), variable.Value()) + } else { + r.logger.Debugf("%s.%s != %s (%s) (not kept)", variable.Variable().Name(), variable.Key(), collectionToKeep, variable.Value()) + } + } + } + + return true + }) + + for _, rule := range req.Tx.MatchedRules() { + if rule.Message() == "" { + r.logger.Tracef("discarding rule %d (action: %s)", rule.Rule().ID(), rule.DisruptiveAction()) + continue + } + kind := "outofband" + if req.IsInBand { + kind = "inband" + evt.Appsec.HasInBandMatches = true + } else { + evt.Appsec.HasOutBandMatches = true + } + + var name string + version := "" + hash := "" + ruleNameProm := fmt.Sprintf("%d", rule.Rule().ID()) + + if details, ok := appsec.AppsecRulesDetails[rule.Rule().ID()]; ok { + // Only set them for custom rules, not for rules written in seclang + name = details.Name + version = details.Version + hash = details.Hash + ruleNameProm = details.Name + + r.logger.Debugf("custom rule for event, setting name: %s, version: %s, hash: %s", name, version, hash) + } else { + name = fmt.Sprintf("native_rule:%d", rule.Rule().ID()) + } + + AppsecRuleHits.With(prometheus.Labels{"rule_name": ruleNameProm, "type": kind, "source": req.RemoteAddrNormalized, "appsec_engine": req.AppsecEngine}).Inc() + + matchedZones := make([]string, 0) + + for _, matchData := range rule.MatchedDatas() { + zone := matchData.Variable().Name() + + varName := matchData.Key() + if varName != "" { + zone += "." + varName + } + + matchedZones = append(matchedZones, zone) + } + + corazaRule := map[string]interface{}{ + "id": rule.Rule().ID(), + "uri": evt.Parsed["target_uri"], + "rule_type": kind, + "method": evt.Parsed["method"], + "disruptive": rule.Disruptive(), + "tags": rule.Rule().Tags(), + "file": rule.Rule().File(), + "file_line": rule.Rule().Line(), + "revision": rule.Rule().Revision(), + "secmark": rule.Rule().SecMark(), + "accuracy": rule.Rule().Accuracy(), + "msg": rule.Message(), + "severity": rule.Rule().Severity().String(), + "name": name, + "hash": hash, + "version": version, + "matched_zones": matchedZones, + } + evt.Appsec.MatchedRules = append(evt.Appsec.MatchedRules, corazaRule) + } + + return nil +} diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch.go b/pkg/acquisition/modules/cloudwatch/cloudwatch.go index 48bbe421753..e4b6c95d77f 100644 --- a/pkg/acquisition/modules/cloudwatch/cloudwatch.go +++ b/pkg/acquisition/modules/cloudwatch/cloudwatch.go @@ -2,6 +2,7 @@ package cloudwatchacquisition import ( "context" + "errors" "fmt" "net/url" "os" @@ -43,7 +44,8 @@ var linesRead = prometheus.NewCounterVec( // CloudwatchSource is the runtime instance keeping track of N streams within 1 cloudwatch group type CloudwatchSource struct { - Config CloudwatchSourceConfiguration + metricsLevel int + Config CloudwatchSourceConfiguration /*runtime stuff*/ logger *log.Entry t *tomb.Tomb @@ -55,16 +57,16 @@ type CloudwatchSource struct { // CloudwatchSourceConfiguration allows user to define one or more streams to monitor within a cloudwatch log group type CloudwatchSourceConfiguration struct { configuration.DataSourceCommonCfg `yaml:",inline"` - GroupName string `yaml:"group_name"` //the group name to be monitored - StreamRegexp *string `yaml:"stream_regexp,omitempty"` //allow to 
filter specific streams + GroupName string `yaml:"group_name"` // the group name to be monitored + StreamRegexp *string `yaml:"stream_regexp,omitempty"` // allow to filter specific streams StreamName *string `yaml:"stream_name,omitempty"` StartTime, EndTime *time.Time `yaml:"-"` - DescribeLogStreamsLimit *int64 `yaml:"describelogstreams_limit,omitempty"` //batch size for DescribeLogStreamsPagesWithContext + DescribeLogStreamsLimit *int64 `yaml:"describelogstreams_limit,omitempty"` // batch size for DescribeLogStreamsPagesWithContext GetLogEventsPagesLimit *int64 `yaml:"getlogeventspages_limit,omitempty"` - PollNewStreamInterval *time.Duration `yaml:"poll_new_stream_interval,omitempty"` //frequency at which we poll for new streams within the log group - MaxStreamAge *time.Duration `yaml:"max_stream_age,omitempty"` //monitor only streams that have been updated within $duration - PollStreamInterval *time.Duration `yaml:"poll_stream_interval,omitempty"` //frequency at which we poll each stream - StreamReadTimeout *time.Duration `yaml:"stream_read_timeout,omitempty"` //stop monitoring streams that haven't been updated within $duration, might be reopened later tho + PollNewStreamInterval *time.Duration `yaml:"poll_new_stream_interval,omitempty"` // frequency at which we poll for new streams within the log group + MaxStreamAge *time.Duration `yaml:"max_stream_age,omitempty"` // monitor only streams that have been updated within $duration + PollStreamInterval *time.Duration `yaml:"poll_stream_interval,omitempty"` // frequency at which we poll each stream + StreamReadTimeout *time.Duration `yaml:"stream_read_timeout,omitempty"` // stop monitoring streams that haven't been updated within $duration, might be reopened later tho AwsApiCallTimeout *time.Duration `yaml:"aws_api_timeout,omitempty"` AwsProfile *string `yaml:"aws_profile,omitempty"` PrependCloudwatchTimestamp *bool `yaml:"prepend_cloudwatch_timestamp,omitempty"` @@ -84,7 +86,7 @@ type LogStreamTailConfig struct { logger *log.Entry ExpectMode int t tomb.Tomb - StartTime, EndTime time.Time //only used for CatMode + StartTime, EndTime time.Time // only used for CatMode } var ( @@ -109,8 +111,8 @@ func (cw *CloudwatchSource) UnmarshalConfig(yamlConfig []byte) error { return fmt.Errorf("cannot parse CloudwatchSource configuration: %w", err) } - if len(cw.Config.GroupName) == 0 { - return fmt.Errorf("group_name is mandatory for CloudwatchSource") + if cw.Config.GroupName == "" { + return errors.New("group_name is mandatory for CloudwatchSource") } if cw.Config.Mode == "" { @@ -152,12 +154,14 @@ func (cw *CloudwatchSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { err := cw.UnmarshalConfig(yamlConfig) if err != nil { return err } + cw.metricsLevel = MetricsLevel + cw.logger = logger.WithField("group", cw.Config.GroupName) cw.logger.Debugf("Starting configuration for Cloudwatch group %s", cw.Config.GroupName) @@ -172,42 +176,49 @@ func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry) erro if *cw.Config.MaxStreamAge > *cw.Config.StreamReadTimeout { cw.logger.Warningf("max_stream_age > stream_read_timeout, stream might keep being opened/closed") } + cw.logger.Tracef("aws_config_dir set to %s", *cw.Config.AwsConfigDir) if *cw.Config.AwsConfigDir != "" { _, err := os.Stat(*cw.Config.AwsConfigDir) if err != nil { 
cw.logger.Errorf("can't read aws_config_dir '%s' got err %s", *cw.Config.AwsConfigDir, err) - return fmt.Errorf("can't read aws_config_dir %s got err %s ", *cw.Config.AwsConfigDir, err) + return fmt.Errorf("can't read aws_config_dir %s got err %w ", *cw.Config.AwsConfigDir, err) } + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - //as aws sdk relies on $HOME, let's allow the user to override it :) + // as aws sdk relies on $HOME, let's allow the user to override it :) os.Setenv("AWS_CONFIG_FILE", fmt.Sprintf("%s/config", *cw.Config.AwsConfigDir)) os.Setenv("AWS_SHARED_CREDENTIALS_FILE", fmt.Sprintf("%s/credentials", *cw.Config.AwsConfigDir)) } else { if cw.Config.AwsRegion == nil { cw.logger.Errorf("aws_region is not specified, specify it or aws_config_dir") - return fmt.Errorf("aws_region is not specified, specify it or aws_config_dir") + return errors.New("aws_region is not specified, specify it or aws_config_dir") } + os.Setenv("AWS_REGION", *cw.Config.AwsRegion) } if err := cw.newClient(); err != nil { return err } + cw.streamIndexes = make(map[string]string) targetStream := "*" + if cw.Config.StreamRegexp != nil { if _, err := regexp.Compile(*cw.Config.StreamRegexp); err != nil { return fmt.Errorf("while compiling regexp '%s': %w", *cw.Config.StreamRegexp, err) } + targetStream = *cw.Config.StreamRegexp } else if cw.Config.StreamName != nil { targetStream = *cw.Config.StreamName } cw.logger.Infof("Adding cloudwatch group '%s' (stream:%s) to datasources", cw.Config.GroupName, targetStream) + return nil } @@ -226,26 +237,31 @@ func (cw *CloudwatchSource) newClient() error { } if sess == nil { - return fmt.Errorf("failed to create aws session") + return errors.New("failed to create aws session") } + if v := os.Getenv("AWS_ENDPOINT_FORCE"); v != "" { cw.logger.Debugf("[testing] overloading endpoint with %s", v) cw.cwClient = cloudwatchlogs.New(sess, aws.NewConfig().WithEndpoint(v)) } else { cw.cwClient = cloudwatchlogs.New(sess) } + if cw.cwClient == nil { - return fmt.Errorf("failed to create cloudwatch client") + return errors.New("failed to create cloudwatch client") } + return nil } -func (cw *CloudwatchSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (cw *CloudwatchSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { cw.t = t monitChan := make(chan LogStreamTailConfig) + t.Go(func() error { return cw.LogStreamManager(monitChan, out) }) + return cw.WatchLogGroupForStreams(monitChan) } @@ -276,6 +292,7 @@ func (cw *CloudwatchSource) Dump() interface{} { func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig) error { cw.logger.Debugf("Starting to watch group (interval:%s)", cw.Config.PollNewStreamInterval) ticker := time.NewTicker(*cw.Config.PollNewStreamInterval) + var startFrom *string for { @@ -286,11 +303,12 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig case <-ticker.C: hasMoreStreams := true startFrom = nil + for hasMoreStreams { cw.logger.Tracef("doing the call to DescribeLogStreamsPagesWithContext") ctx := context.Background() - //there can be a lot of streams in a group, and we're only interested in those recently written to, so we sort by LastEventTime + // there can be a lot of streams in a group, and we're only interested in those recently written to, so we sort by LastEventTime err := cw.cwClient.DescribeLogStreamsPagesWithContext( ctx, &cloudwatchlogs.DescribeLogStreamsInput{ @@ -302,13 +320,14 @@ func (cw *CloudwatchSource) 
WatchLogGroupForStreams(out chan LogStreamTailConfig }, func(page *cloudwatchlogs.DescribeLogStreamsOutput, lastPage bool) bool { cw.logger.Tracef("in helper of DescribeLogStreamsPagesWithContext") + for _, event := range page.LogStreams { startFrom = page.NextToken - //we check if the stream has been written to recently enough to be monitored + // we check if the stream has been written to recently enough to be monitored if event.LastIngestionTime != nil { - //aws uses millisecond since the epoch + // aws uses millisecond since the epoch oldest := time.Now().UTC().Add(-*cw.Config.MaxStreamAge) - //TBD : verify that this is correct : Unix 2nd arg expects Nanoseconds, and have a code that is more explicit. + // TBD: verify that this is correct: Unix's 2nd arg expects nanoseconds; make the code more explicit. LastIngestionTime := time.Unix(0, *event.LastIngestionTime*int64(time.Millisecond)) if LastIngestionTime.Before(oldest) { cw.logger.Tracef("stop iteration, %s reached oldest age, stop (%s < %s)", *event.LogStreamName, LastIngestionTime, time.Now().UTC().Add(-*cw.Config.MaxStreamAge)) @@ -316,7 +335,7 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig return false } cw.logger.Tracef("stream %s is eligible for monitoring", *event.LogStreamName) - //the stream has been updated recently, check if we should monitor it + // the stream has been updated recently, check if we should monitor it var expectMode int if !cw.Config.UseTimeMachine { expectMode = types.LIVE @@ -354,7 +373,6 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig // LogStreamManager receives the potential streams to monitor, and starts a goroutine when needed func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outChan chan types.Event) error { - cw.logger.Debugf("starting to monitor streams for %s", cw.Config.GroupName) pollDeadStreamInterval := time.NewTicker(def_PollDeadStreamInterval) @@ -370,7 +388,7 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha } if cw.Config.StreamRegexp != nil { - match, err := regexp.Match(*cw.Config.StreamRegexp, []byte(newStream.StreamName)) + match, err := regexp.MatchString(*cw.Config.StreamRegexp, newStream.StreamName) if err != nil { cw.logger.Warningf("invalid regexp : %s", err) } else if !match { @@ -381,11 +399,13 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha for idx, stream := range cw.monitoredStreams { if newStream.GroupName == stream.GroupName && newStream.StreamName == stream.StreamName { - //stream exists, but is dead, remove it from list + // stream exists, but is dead, remove it from list if !stream.t.Alive() { cw.logger.Debugf("stream %s already exists, but is dead", newStream.StreamName) cw.monitoredStreams = append(cw.monitoredStreams[:idx], cw.monitoredStreams[idx+1:]...) 
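+ // removal uses the append(s[:idx], s[idx+1:]...) idiom; the break just below keeps the loop from walking the shifted slice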
- openedStreams.With(prometheus.Labels{"group": newStream.GroupName}).Dec() + if cw.metricsLevel != configuration.METRICS_NONE { + openedStreams.With(prometheus.Labels{"group": newStream.GroupName}).Dec() + } break } shouldCreate = false @@ -393,11 +413,13 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha } } - //let's start watching this stream + // let's start watching this stream if shouldCreate { - openedStreams.With(prometheus.Labels{"group": newStream.GroupName}).Inc() + if cw.metricsLevel != configuration.METRICS_NONE { + openedStreams.With(prometheus.Labels{"group": newStream.GroupName}).Inc() + } newStream.t = tomb.Tomb{} - newStream.logger = cw.logger.WithFields(log.Fields{"stream": newStream.StreamName}) + newStream.logger = cw.logger.WithField("stream", newStream.StreamName) cw.logger.Debugf("starting tail of stream %s", newStream.StreamName) newStream.t.Go(func() error { return cw.TailLogStream(&newStream, outChan) @@ -409,7 +431,9 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha for idx, stream := range cw.monitoredStreams { if !cw.monitoredStreams[idx].t.Alive() { cw.logger.Debugf("remove dead stream %s", stream.StreamName) - openedStreams.With(prometheus.Labels{"group": cw.monitoredStreams[idx].GroupName}).Dec() + if cw.metricsLevel != configuration.METRICS_NONE { + openedStreams.With(prometheus.Labels{"group": cw.monitoredStreams[idx].GroupName}).Dec() + } } else { newMonitoredStreams = append(newMonitoredStreams, stream) } @@ -437,7 +461,7 @@ func (cw *CloudwatchSource) TailLogStream(cfg *LogStreamTailConfig, outChan chan var startFrom *string lastReadMessage := time.Now().UTC() ticker := time.NewTicker(cfg.PollStreamInterval) - //resume at existing index if we already had + // resume at existing index if we already had streamIndexMutex.Lock() v := cw.streamIndexes[cfg.GroupName+"+"+cfg.StreamName] streamIndexMutex.Unlock() @@ -485,7 +509,9 @@ func (cw *CloudwatchSource) TailLogStream(cfg *LogStreamTailConfig, outChan chan cfg.logger.Warningf("cwLogToEvent error, discarded event : %s", err) } else { cfg.logger.Debugf("pushing message : %s", evt.Line.Raw) - linesRead.With(prometheus.Labels{"group": cfg.GroupName, "stream": cfg.StreamName}).Inc() + if cw.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"group": cfg.GroupName, "stream": cfg.StreamName}).Inc() + } outChan <- evt } } @@ -506,7 +532,7 @@ func (cw *CloudwatchSource) TailLogStream(cfg *LogStreamTailConfig, outChan chan } case <-cfg.t.Dying(): cfg.logger.Infof("logstream tail stopping") - return fmt.Errorf("killed") + return errors.New("killed") } } } @@ -517,11 +543,11 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, dsn = strings.TrimPrefix(dsn, cw.GetName()+"://") args := strings.Split(dsn, "?") if len(args) != 2 { - return fmt.Errorf("query is mandatory (at least start_date and end_date or backlog)") + return errors.New("query is mandatory (at least start_date and end_date or backlog)") } frags := strings.Split(args[0], ":") if len(frags) != 2 { - return fmt.Errorf("cloudwatch path must contain group and stream : /my/group/name:stream/name") + return errors.New("cloudwatch path must contain group and stream : /my/group/name:stream/name") } cw.Config.GroupName = frags[0] cw.Config.StreamName = &frags[1] @@ -537,7 +563,7 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, switch k { case "log_level": if len(v) != 1 { - return fmt.Errorf("expected 
zero or one value for 'log_level'") + return errors.New("expected zero or one value for 'log_level'") } lvl, err := log.ParseLevel(v[0]) if err != nil { @@ -547,32 +573,32 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, case "profile": if len(v) != 1 { - return fmt.Errorf("expected zero or one value for 'profile'") + return errors.New("expected zero or one value for 'profile'") } awsprof := v[0] cw.Config.AwsProfile = &awsprof cw.logger.Debugf("profile set to '%s'", *cw.Config.AwsProfile) case "start_date": if len(v) != 1 { - return fmt.Errorf("expected zero or one argument for 'start_date'") + return errors.New("expected zero or one argument for 'start_date'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported strdate, startDate := parser.GenDateParse(v[0]) cw.logger.Debugf("parsed '%s' as '%s'", v[0], strdate) cw.Config.StartTime = &startDate case "end_date": if len(v) != 1 { - return fmt.Errorf("expected zero or one argument for 'end_date'") + return errors.New("expected zero or one argument for 'end_date'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported strdate, endDate := parser.GenDateParse(v[0]) cw.logger.Debugf("parsed '%s' as '%s'", v[0], strdate) cw.Config.EndTime = &endDate case "backlog": if len(v) != 1 { - return fmt.Errorf("expected zero or one argument for 'backlog'") + return errors.New("expected zero or one argument for 'backlog'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported duration, err := time.ParseDuration(v[0]) if err != nil { return fmt.Errorf("unable to parse '%s' as duration: %w", v[0], err) @@ -595,10 +621,10 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, } if cw.Config.StreamName == nil || cw.Config.GroupName == "" { - return fmt.Errorf("missing stream or group name") + return errors.New("missing stream or group name") } if cw.Config.StartTime == nil || cw.Config.EndTime == nil { - return fmt.Errorf("start_date and end_date or backlog are mandatory in one-shot mode") + return errors.New("start_date and end_date or backlog are mandatory in one-shot mode") } cw.Config.Mode = configuration.CAT_MODE @@ -608,7 +634,7 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, } func (cw *CloudwatchSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { - //StreamName string, Start time.Time, End time.Time + // StreamName string, Start time.Time, End time.Time config := LogStreamTailConfig{ GroupName: cw.Config.GroupName, StreamName: *cw.Config.StreamName, @@ -627,7 +653,7 @@ func (cw *CloudwatchSource) OneShotAcquisition(out chan types.Event, t *tomb.Tom func (cw *CloudwatchSource) CatLogStream(cfg *LogStreamTailConfig, outChan chan types.Event) error { var startFrom *string - var head = true + head := true /*convert the times*/ startTime := cfg.StartTime.UTC().Unix() * 1000 endTime := cfg.EndTime.UTC().Unix() * 1000 @@ -689,7 +715,7 @@ func cwLogToEvent(log *cloudwatchlogs.OutputLogEvent, cfg *LogStreamTailConfig) l := types.Line{} evt := types.Event{} if log.Message == nil { - return evt, fmt.Errorf("nil message") + return evt, errors.New("nil message") } msg := *log.Message if cfg.PrependCloudwatchTimestamp != nil && 
*cfg.PrependCloudwatchTimestamp { diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go index 7cdfefca6be..d62c3f6e3dd 100644 --- a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go +++ b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go @@ -1,6 +1,8 @@ package cloudwatchacquisition import ( + "context" + "errors" "fmt" "net" "os" @@ -9,14 +11,16 @@ import ( "testing" "time" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/cloudwatchlogs" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) /* @@ -31,6 +35,7 @@ func deleteAllLogGroups(t *testing.T, cw *CloudwatchSource) { input := &cloudwatchlogs.DescribeLogGroupsInput{} result, err := cw.cwClient.DescribeLogGroups(input) require.NoError(t, err) + for _, group := range result.LogGroups { _, err := cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{ LogGroupName: group.LogGroupName, @@ -42,14 +47,14 @@ func deleteAllLogGroups(t *testing.T, cw *CloudwatchSource) { func checkForLocalStackAvailability() error { v := os.Getenv("AWS_ENDPOINT_FORCE") if v == "" { - return fmt.Errorf("missing aws endpoint for tests : AWS_ENDPOINT_FORCE") + return errors.New("missing aws endpoint for tests : AWS_ENDPOINT_FORCE") } v = strings.TrimPrefix(v, "http://") _, err := net.Dial("tcp", v) if err != nil { - return fmt.Errorf("while dialing %s : %s : aws endpoint isn't available", v, err) + return fmt.Errorf("while dialing %s: %w: aws endpoint isn't available", v, err) } return nil @@ -59,18 +64,22 @@ func TestMain(m *testing.M) { if runtime.GOOS == "windows" { os.Exit(0) } + if err := checkForLocalStackAvailability(); err != nil { log.Fatalf("local stack error : %s", err) } + def_PollNewStreamInterval = 1 * time.Second def_PollStreamInterval = 1 * time.Second def_StreamReadTimeout = 10 * time.Second def_MaxStreamAge = 5 * time.Second def_PollDeadStreamInterval = 5 * time.Second + os.Exit(m.Run()) } func TestWatchLogGroupForStreams(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -199,7 +208,7 @@ stream_regexp: test_bad[0-9]+`), }, expectedResLen: 0, }, - // require a group name that does exist and contains a stream in which we gonna put events + // require a group name that does exist and contains a stream in which we are going to put events { name: "group_exists_stream_exists_has_events", config: []byte(` @@ -421,13 +430,12 @@ stream_name: test_stream`), } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { dbgLogger := log.New().WithField("test", tc.name) dbgLogger.Logger.SetLevel(log.DebugLevel) dbgLogger.Infof("starting test") cw := CloudwatchSource{} - err := cw.Configure(tc.config, dbgLogger) + err := cw.Configure(tc.config, dbgLogger, configuration.METRICS_NONE) cstest.RequireErrorContains(t, err, tc.expectedCfgErr) if tc.expectedCfgErr != "" { @@ -445,7 +453,7 @@ stream_name: test_stream`), dbgLogger.Infof("running StreamingAcquisition") actmb := tomb.Tomb{} actmb.Go(func() error { - err := cw.StreamingAcquisition(out, &actmb) + err := cw.StreamingAcquisition(ctx, out, &actmb) dbgLogger.Infof("acquis done") cstest.RequireErrorContains(t, err, 
tc.expectedStartErr) return nil @@ -501,7 +509,6 @@ stream_name: test_stream`), if len(res) != 0 { t.Fatalf("leftover unmatched results : %v", res) } - } if tc.teardown != nil { tc.teardown(t, &cw) @@ -511,6 +518,7 @@ stream_name: test_stream`), } func TestConfiguration(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -554,12 +562,11 @@ stream_name: test_stream`), } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { dbgLogger := log.New().WithField("test", tc.name) dbgLogger.Logger.SetLevel(log.DebugLevel) cw := CloudwatchSource{} - err := cw.Configure(tc.config, dbgLogger) + err := cw.Configure(tc.config, dbgLogger, configuration.METRICS_NONE) cstest.RequireErrorContains(t, err, tc.expectedCfgErr) if tc.expectedCfgErr != "" { return @@ -570,7 +577,7 @@ stream_name: test_stream`), switch cw.GetMode() { case "tail": - err = cw.StreamingAcquisition(out, &tmb) + err = cw.StreamingAcquisition(ctx, out, &tmb) case "cat": err = cw.OneShotAcquisition(out, &tmb) } @@ -619,7 +626,6 @@ func TestConfigureByDSN(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { dbgLogger := log.New().WithField("test", tc.name) dbgLogger.Logger.SetLevel(log.DebugLevel) @@ -741,7 +747,6 @@ func TestOneShotAcquisition(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { dbgLogger := log.New().WithField("test", tc.name) dbgLogger.Logger.SetLevel(log.DebugLevel) @@ -799,7 +804,6 @@ func TestOneShotAcquisition(t *testing.T) { if len(res) != 0 { t.Fatalf("leftover unmatched results : %v", res) } - } if tc.teardown != nil { tc.teardown(t, &cw) diff --git a/pkg/acquisition/modules/docker/docker.go b/pkg/acquisition/modules/docker/docker.go index 92962697437..874b1556fd5 100644 --- a/pkg/acquisition/modules/docker/docker.go +++ b/pkg/acquisition/modules/docker/docker.go @@ -3,6 +3,7 @@ package dockeracquisition import ( "bufio" "context" + "errors" "fmt" "net/url" "regexp" @@ -41,11 +42,12 @@ type DockerConfiguration struct { ContainerID []string `yaml:"container_id"` ContainerNameRegexp []string `yaml:"container_name_regexp"` ContainerIDRegexp []string `yaml:"container_id_regexp"` - ForceInotify bool `yaml:"force_inotify"` + UseContainerLabels bool `yaml:"use_container_labels"` configuration.DataSourceCommonCfg `yaml:",inline"` } type DockerSource struct { + metricsLevel int Config DockerConfiguration runningContainerState map[string]*ContainerConfig compiledContainerName []*regexp.Regexp @@ -86,8 +88,12 @@ func (d *DockerSource) UnmarshalConfig(yamlConfig []byte) error { d.logger.Tracef("DockerAcquisition configuration: %+v", d.Config) } - if len(d.Config.ContainerName) == 0 && len(d.Config.ContainerID) == 0 && len(d.Config.ContainerIDRegexp) == 0 && len(d.Config.ContainerNameRegexp) == 0 { - return fmt.Errorf("no containers names or containers ID configuration provided") + if len(d.Config.ContainerName) == 0 && len(d.Config.ContainerID) == 0 && len(d.Config.ContainerIDRegexp) == 0 && len(d.Config.ContainerNameRegexp) == 0 && !d.Config.UseContainerLabels { + return errors.New("no containers names or containers ID configuration provided") + } + + if d.Config.UseContainerLabels && (len(d.Config.ContainerName) > 0 || len(d.Config.ContainerID) > 0 || len(d.Config.ContainerIDRegexp) > 0 || len(d.Config.ContainerNameRegexp) > 0) { + return errors.New("use_container_labels and container_name, container_id, container_id_regexp, container_name_regexp are mutually 
exclusive") } d.CheckIntervalDuration, err = time.ParseDuration(d.Config.CheckInterval) @@ -128,9 +134,9 @@ func (d *DockerSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (d *DockerSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (d *DockerSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { d.logger = logger - + d.metricsLevel = MetricsLevel err := d.UnmarshalConfig(yamlConfig) if err != nil { return err @@ -220,7 +226,7 @@ func (d *DockerSource) ConfigureByDSN(dsn string, labels map[string]string, logg switch k { case "log_level": if len(v) != 1 { - return fmt.Errorf("only one 'log_level' parameters is required, not many") + return errors.New("only one 'log_level' parameters is required, not many") } lvl, err := log.ParseLevel(v[0]) if err != nil { @@ -229,17 +235,17 @@ func (d *DockerSource) ConfigureByDSN(dsn string, labels map[string]string, logg d.logger.Logger.SetLevel(lvl) case "until": if len(v) != 1 { - return fmt.Errorf("only one 'until' parameters is required, not many") + return errors.New("only one 'until' parameters is required, not many") } d.containerLogsOptions.Until = v[0] case "since": if len(v) != 1 { - return fmt.Errorf("only one 'since' parameters is required, not many") + return errors.New("only one 'since' parameters is required, not many") } d.containerLogsOptions.Since = v[0] case "follow_stdout": if len(v) != 1 { - return fmt.Errorf("only one 'follow_stdout' parameters is required, not many") + return errors.New("only one 'follow_stdout' parameters is required, not many") } followStdout, err := strconv.ParseBool(v[0]) if err != nil { @@ -249,7 +255,7 @@ func (d *DockerSource) ConfigureByDSN(dsn string, labels map[string]string, logg d.containerLogsOptions.ShowStdout = followStdout case "follow_stderr": if len(v) != 1 { - return fmt.Errorf("only one 'follow_stderr' parameters is required, not many") + return errors.New("only one 'follow_stderr' parameters is required, not many") } followStdErr, err := strconv.ParseBool(v[0]) if err != nil { @@ -259,7 +265,7 @@ func (d *DockerSource) ConfigureByDSN(dsn string, labels map[string]string, logg d.containerLogsOptions.ShowStderr = followStdErr case "docker_host": if len(v) != 1 { - return fmt.Errorf("only one 'docker_host' parameters is required, not many") + return errors.New("only one 'docker_host' parameters is required, not many") } if err := client.WithHost(v[0])(dockerClient); err != nil { return err @@ -292,7 +298,7 @@ func (d *DockerSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) er d.logger.Debugf("container with id %s is already being read from", container.ID) continue } - if containerConfig, ok := d.EvalContainer(container); ok { + if containerConfig := d.EvalContainer(container); containerConfig != nil { d.logger.Infof("reading logs from container %s", containerConfig.Name) d.logger.Debugf("logs options: %+v", *d.containerLogsOptions) dockerReader, err := d.Client.ContainerLogs(context.Background(), containerConfig.ID, *d.containerLogsOptions) @@ -325,7 +331,9 @@ func (d *DockerSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) er l.Src = containerConfig.Name l.Process = true l.Module = d.GetName() - linesRead.With(prometheus.Labels{"source": containerConfig.Name}).Inc() + if d.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": containerConfig.Name}).Inc() + } evt := types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} out <- evt 
d.logger.Debugf("Sent line to parsing: %+v", evt.Line.Raw) @@ -372,41 +380,88 @@ func (d *DockerSource) getContainerTTY(containerId string) bool { return containerDetails.Config.Tty } -func (d *DockerSource) EvalContainer(container dockerTypes.Container) (*ContainerConfig, bool) { +func (d *DockerSource) getContainerLabels(containerId string) map[string]interface{} { + containerDetails, err := d.Client.ContainerInspect(context.Background(), containerId) + if err != nil { + return map[string]interface{}{} + } + return parseLabels(containerDetails.Config.Labels) +} + +func (d *DockerSource) EvalContainer(container dockerTypes.Container) *ContainerConfig { for _, containerID := range d.Config.ContainerID { if containerID == container.ID { - return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}, true + return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} } } for _, containerName := range d.Config.ContainerName { for _, name := range container.Names { - if strings.HasPrefix(name, "/") && len(name) > 0 { + if strings.HasPrefix(name, "/") && name != "" { name = name[1:] } if name == containerName { - return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}, true + return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} } } - } for _, cont := range d.compiledContainerID { - if matched := cont.Match([]byte(container.ID)); matched { - return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}, true + if matched := cont.MatchString(container.ID); matched { + return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} } } for _, cont := range d.compiledContainerName { for _, name := range container.Names { - if matched := cont.Match([]byte(name)); matched { - return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}, true + if matched := cont.MatchString(name); matched { + return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} } } + } + if d.Config.UseContainerLabels { + parsedLabels := d.getContainerLabels(container.ID) + if len(parsedLabels) == 0 { + d.logger.Tracef("container has no 'crowdsec' labels set, ignoring container: %s", container.ID) + return nil + } + if _, ok := parsedLabels["enable"]; !ok { + d.logger.Errorf("container has 'crowdsec' labels set but no 'crowdsec.enable' key found") + return nil + } + enable, ok := parsedLabels["enable"].(string) + if !ok { + d.logger.Error("container has 'crowdsec.enable' label set but it's not a string") + return nil + } + if strings.ToLower(enable) != "true" { + d.logger.Debugf("container has 'crowdsec.enable' label not set to true ignoring container: %s", container.ID) + return nil + } + if _, ok = parsedLabels["labels"]; !ok { + d.logger.Error("container has 'crowdsec.enable' label set to true but no 'labels' keys found") + return nil + } + labelsTypeCast, ok := parsedLabels["labels"].(map[string]interface{}) + if !ok { + d.logger.Error("container has 'crowdsec.enable' label set to true but 'labels' is not a map") + return nil + } + d.logger.Debugf("container labels %+v", labelsTypeCast) + labels := 
make(map[string]string) + for k, v := range labelsTypeCast { + if v, ok := v.(string); ok { + log.Debugf("label %s is a string with value %s", k, v) + labels[k] = v + continue + } + d.logger.Errorf("label %s is not a string", k) + } + return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: labels, Tty: d.getContainerTTY(container.ID)} } - return &ContainerConfig{}, false + return nil } func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteChan chan *ContainerConfig) error { @@ -446,7 +501,7 @@ func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteCha if _, ok := d.runningContainerState[container.ID]; ok { continue } - if containerConfig, ok := d.EvalContainer(container); ok { + if containerConfig := d.EvalContainer(container); containerConfig != nil { monitChan <- containerConfig } } @@ -463,7 +518,7 @@ func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteCha } } -func (d *DockerSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (d *DockerSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { d.t = t monitChan := make(chan *ContainerConfig) deleteChan := make(chan *ContainerConfig) @@ -519,7 +574,7 @@ func (d *DockerSource) TailDocker(container *ContainerConfig, outChan chan types } l := types.Line{} l.Raw = line - l.Labels = d.Config.Labels + l.Labels = container.Labels l.Time = time.Now().UTC() l.Src = container.Name l.Process = true @@ -534,11 +589,11 @@ func (d *DockerSource) TailDocker(container *ContainerConfig, outChan chan types outChan <- evt d.logger.Debugf("Sent line to parsing: %+v", evt.Line.Raw) case <-readerTomb.Dying(): - //This case is to handle temporarily losing the connection to the docker socket - //The only known case currently is when using docker-socket-proxy (and maybe a docker daemon restart) + // This case is to handle temporarily losing the connection to the docker socket + // The only known case currently is when using docker-socket-proxy (and maybe a docker daemon restart) d.logger.Debugf("readerTomb dying for container %s, removing it from runningContainerState", container.Name) deleteChan <- container - //Also reset the Since to avoid re-reading logs + // Also reset the Since to avoid re-reading logs d.Config.Since = time.Now().UTC().Format(time.RFC3339) d.containerLogsOptions.Since = d.Config.Since return nil @@ -553,7 +608,7 @@ func (d *DockerSource) DockerManager(in chan *ContainerConfig, deleteChan chan * case newContainer := <-in: if _, ok := d.runningContainerState[newContainer.ID]; !ok { newContainer.t = &tomb.Tomb{} - newContainer.logger = d.logger.WithFields(log.Fields{"container_name": newContainer.Name}) + newContainer.logger = d.logger.WithField("container_name", newContainer.Name) newContainer.t.Go(func() error { return d.TailDocker(newContainer, outChan, deleteChan) }) diff --git a/pkg/acquisition/modules/docker/docker_test.go b/pkg/acquisition/modules/docker/docker_test.go index 65c9927263a..e394c9cbe79 100644 --- a/pkg/acquisition/modules/docker/docker_test.go +++ b/pkg/acquisition/modules/docker/docker_test.go @@ -11,16 +11,17 @@ import ( "testing" "time" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" - - "github.com/crowdsecurity/crowdsec/pkg/types" dockerTypes "github.com/docker/docker/api/types" dockerContainer "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "gopkg.in/tomb.v2" - 
"github.com/stretchr/testify/assert" + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) const testContainerName = "docker_test" @@ -54,24 +55,26 @@ container_name: }, } - subLogger := log.WithFields(log.Fields{ - "type": "docker", - }) + subLogger := log.WithField("type", "docker") + for _, test := range tests { f := DockerSource{} - err := f.Configure([]byte(test.config), subLogger) + err := f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) cstest.AssertErrorContains(t, err, test.expectedErr) } } func TestConfigureDSN(t *testing.T) { log.Infof("Test 'TestConfigureDSN'") + var dockerHost string + if runtime.GOOS == "windows" { dockerHost = "npipe:////./pipe/docker_engine" } else { dockerHost = "unix:///var/run/podman/podman.sock" } + tests := []struct { name string dsn string @@ -103,9 +106,8 @@ func TestConfigureDSN(t *testing.T) { expectedErr: "", }, } - subLogger := log.WithFields(log.Fields{ - "type": "docker", - }) + subLogger := log.WithField("type", "docker") + for _, test := range tests { f := DockerSource{} err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "") @@ -118,6 +120,7 @@ type mockDockerCli struct { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() log.SetOutput(os.Stdout) log.SetLevel(log.InfoLevel) log.Info("Test 'TestStreamingAcquisition'") @@ -156,33 +159,34 @@ container_name_regexp: } for _, ts := range tests { - var logger *log.Logger - var subLogger *log.Entry + var ( + logger *log.Logger + subLogger *log.Entry + ) + if ts.expectedOutput != "" { logger.SetLevel(ts.logLevel) - subLogger = logger.WithFields(log.Fields{ - "type": "docker", - }) + subLogger = logger.WithField("type", "docker") } else { - subLogger = log.WithFields(log.Fields{ - "type": "docker", - }) + subLogger = log.WithField("type", "docker") } readLogs = false dockerTomb := tomb.Tomb{} out := make(chan types.Event) dockerSource := DockerSource{} - err := dockerSource.Configure([]byte(ts.config), subLogger) + + err := dockerSource.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("Unexpected error : %s", err) } + dockerSource.Client = new(mockDockerCli) actualLines := 0 readerTomb := &tomb.Tomb{} streamTomb := tomb.Tomb{} streamTomb.Go(func() error { - return dockerSource.StreamingAcquisition(out, &dockerTomb) + return dockerSource.StreamingAcquisition(ctx, out, &dockerTomb) }) readerTomb.Go(func() error { time.Sleep(1 * time.Second) @@ -193,7 +197,7 @@ container_name_regexp: actualLines++ ticker.Reset(1 * time.Second) case <-ticker.C: - log.Infof("no more line to read") + log.Infof("no more lines to read") dockerSource.t.Kill(nil) return nil } @@ -204,21 +208,23 @@ container_name_regexp: if err := readerTomb.Wait(); err != nil { t.Fatal(err) } + if ts.expectedLines != 0 { assert.Equal(t, ts.expectedLines, actualLines) } + err = streamTomb.Wait() if err != nil { t.Fatalf("docker acquisition error: %s", err) } } - } func (cli *mockDockerCli) ContainerList(ctx context.Context, options dockerTypes.ContainerListOptions) ([]dockerTypes.Container, error) { - if readLogs == true { + if readLogs { return []dockerTypes.Container{}, nil } + containers := make([]dockerTypes.Container, 0) container := &dockerTypes.Container{ ID: "12456", @@ -230,19 +236,23 @@ func (cli *mockDockerCli) ContainerList(ctx context.Context, options dockerTypes } func (cli *mockDockerCli) 
ContainerLogs(ctx context.Context, container string, options dockerTypes.ContainerLogsOptions) (io.ReadCloser, error) { - if readLogs == true { + if readLogs { return io.NopCloser(strings.NewReader("")), nil } + readLogs = true data := []string{"docker\n", "test\n", "1234\n"} ret := "" + for _, line := range data { startLineByte := make([]byte, 8) - binary.LittleEndian.PutUint32(startLineByte, 1) //stdout stream + binary.LittleEndian.PutUint32(startLineByte, 1) // stdout stream binary.BigEndian.PutUint32(startLineByte[4:], uint32(len(line))) ret += fmt.Sprintf("%s%s", startLineByte, line) } + r := io.NopCloser(strings.NewReader(ret)) // r type is io.ReadCloser + return r, nil } @@ -252,6 +262,7 @@ func (cli *mockDockerCli) ContainerInspect(ctx context.Context, c string) (docke Tty: false, }, } + return r, nil } @@ -285,18 +296,17 @@ func TestOneShot(t *testing.T) { } for _, ts := range tests { - var subLogger *log.Entry - var logger *log.Logger + var ( + subLogger *log.Entry + logger *log.Logger + ) + if ts.expectedOutput != "" { logger.SetLevel(ts.logLevel) - subLogger = logger.WithFields(log.Fields{ - "type": "docker", - }) + subLogger = logger.WithField("type", "docker") } else { log.SetLevel(ts.logLevel) - subLogger = log.WithFields(log.Fields{ - "type": "docker", - }) + subLogger = log.WithField("type", "docker") } readLogs = false @@ -307,6 +317,7 @@ func TestOneShot(t *testing.T) { if err := dockerClient.ConfigureByDSN(ts.dsn, labels, subLogger, ""); err != nil { t.Fatalf("unable to configure dsn '%s': %s", ts.dsn, err) } + dockerClient.Client = new(mockDockerCli) out := make(chan types.Event, 100) tomb := tomb.Tomb{} @@ -315,8 +326,58 @@ func TestOneShot(t *testing.T) { // else we do the check before actualLines is incremented ... if ts.expectedLines != 0 { - assert.Equal(t, ts.expectedLines, len(out)) + assert.Len(t, out, ts.expectedLines) } } +} +func TestParseLabels(t *testing.T) { + tests := []struct { + name string + labels map[string]string + expected map[string]interface{} + }{ + { + name: "bad label", + labels: map[string]string{"crowdsecfoo": "bar"}, + expected: map[string]interface{}{}, + }, + { + name: "simple label", + labels: map[string]string{"crowdsec.bar": "baz"}, + expected: map[string]interface{}{"bar": "baz"}, + }, + { + name: "multiple simple labels", + labels: map[string]string{"crowdsec.bar": "baz", "crowdsec.foo": "bar"}, + expected: map[string]interface{}{"bar": "baz", "foo": "bar"}, + }, + { + name: "multiple simple labels 2", + labels: map[string]string{"crowdsec.bar": "baz", "bla": "foo"}, + expected: map[string]interface{}{"bar": "baz"}, + }, + { + name: "end with dot", + labels: map[string]string{"crowdsec.bar.": "baz"}, + expected: map[string]interface{}{}, + }, + { + name: "consecutive dots", + labels: map[string]string{"crowdsec......bar": "baz"}, + expected: map[string]interface{}{}, + }, + { + name: "crowdsec labels", + labels: map[string]string{"crowdsec.labels.type": "nginx"}, + expected: map[string]interface{}{"labels": map[string]interface{}{"type": "nginx"}}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + labels := parseLabels(test.labels) + assert.Equal(t, test.expected, labels) + }) + } } diff --git a/pkg/acquisition/modules/docker/utils.go b/pkg/acquisition/modules/docker/utils.go new file mode 100644 index 00000000000..6a0d494097f --- /dev/null +++ b/pkg/acquisition/modules/docker/utils.go @@ -0,0 +1,38 @@ +package dockeracquisition + +import ( + "strings" +) + +func parseLabels(labels 
map[string]string) map[string]interface{} { + result := make(map[string]interface{}) + for key, value := range labels { + parseKeyToMap(result, key, value) + } + return result +} + +func parseKeyToMap(m map[string]interface{}, key string, value string) { + if !strings.HasPrefix(key, "crowdsec") { + return + } + parts := strings.Split(key, ".") + + if len(parts) < 2 || parts[0] != "crowdsec" { + return + } + + for i := range parts { + if parts[i] == "" { + return + } + } + + for i := 1; i < len(parts)-1; i++ { + if _, ok := m[parts[i]]; !ok { + m[parts[i]] = make(map[string]interface{}) + } + m = m[parts[i]].(map[string]interface{}) + } + m[parts[len(parts)-1]] = value +} diff --git a/pkg/acquisition/modules/file/file.go b/pkg/acquisition/modules/file/file.go index 0aa1f6d92c0..2d2df3ff4d4 100644 --- a/pkg/acquisition/modules/file/file.go +++ b/pkg/acquisition/modules/file/file.go @@ -3,6 +3,8 @@ package fileacquisition import ( "bufio" "compress/gzip" + "context" + "errors" "fmt" "io" "net/url" @@ -11,17 +13,17 @@ import ( "regexp" "strconv" "strings" + "sync" "time" "github.com/fsnotify/fsnotify" "github.com/nxadm/tail" - "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -38,13 +40,14 @@ type FileConfiguration struct { Filenames []string ExcludeRegexps []string `yaml:"exclude_regexps"` Filename string - ForceInotify bool `yaml:"force_inotify"` - MaxBufferSize int `yaml:"max_buffer_size"` - PollWithoutInotify bool `yaml:"poll_without_inotify"` + ForceInotify bool `yaml:"force_inotify"` + MaxBufferSize int `yaml:"max_buffer_size"` + PollWithoutInotify *bool `yaml:"poll_without_inotify"` configuration.DataSourceCommonCfg `yaml:",inline"` } type FileSource struct { + metricsLevel int config FileConfiguration watcher *fsnotify.Watcher watchedDirectories map[string]bool @@ -52,6 +55,7 @@ type FileSource struct { logger *log.Entry files []string exclude_regexps []*regexp.Regexp + tailMapMutex *sync.RWMutex } func (f *FileSource) GetUuid() string { @@ -60,6 +64,7 @@ func (f *FileSource) GetUuid() string { func (f *FileSource) UnmarshalConfig(yamlConfig []byte) error { f.config = FileConfiguration{} + err := yaml.UnmarshalStrict(yamlConfig, &f.config) if err != nil { return fmt.Errorf("cannot parse FileAcquisition configuration: %w", err) @@ -69,12 +74,12 @@ func (f *FileSource) UnmarshalConfig(yamlConfig []byte) error { f.logger.Tracef("FileAcquisition configuration: %+v", f.config) } - if len(f.config.Filename) != 0 { + if f.config.Filename != "" { f.config.Filenames = append(f.config.Filenames, f.config.Filename) } if len(f.config.Filenames) == 0 { - return fmt.Errorf("no filename or filenames configuration provided") + return errors.New("no filename or filenames configuration provided") } if f.config.Mode == "" { @@ -90,14 +95,16 @@ func (f *FileSource) UnmarshalConfig(yamlConfig []byte) error { if err != nil { return fmt.Errorf("could not compile regexp %s: %w", exclude, err) } + f.exclude_regexps = append(f.exclude_regexps, re) } return nil } -func (f *FileSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (f *FileSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { f.logger = logger + f.metricsLevel = MetricsLevel err := 
f.UnmarshalConfig(yamlConfig) if err != nil { @@ -105,6 +112,7 @@ func (f *FileSource) Configure(yamlConfig []byte, logger *log.Entry) error { } f.watchedDirectories = make(map[string]bool) + f.tailMapMutex = &sync.RWMutex{} f.tails = make(map[string]bool) f.watcher, err = fsnotify.NewWatcher() @@ -118,56 +126,68 @@ func (f *FileSource) Configure(yamlConfig []byte, logger *log.Entry) error { if f.config.ForceInotify { directory := filepath.Dir(pattern) f.logger.Infof("Force add watch on %s", directory) + if !f.watchedDirectories[directory] { err = f.watcher.Add(directory) if err != nil { f.logger.Errorf("Could not create watch on directory %s : %s", directory, err) continue } + f.watchedDirectories[directory] = true } } + files, err := filepath.Glob(pattern) if err != nil { return fmt.Errorf("glob failure: %w", err) } + if len(files) == 0 { f.logger.Warnf("No matching files for pattern %s", pattern) continue } - for _, file := range files { - //check if file is excluded + for _, file := range files { + // check if file is excluded excluded := false + for _, pattern := range f.exclude_regexps { if pattern.MatchString(file) { excluded = true + f.logger.Infof("Skipping file %s as it matches exclude pattern %s", file, pattern) + break } } + if excluded { continue } - if files[0] != pattern && f.config.Mode == configuration.TAIL_MODE { //we have a glob pattern + + if files[0] != pattern && f.config.Mode == configuration.TAIL_MODE { // we have a glob pattern directory := filepath.Dir(file) f.logger.Debugf("Will add watch to directory: %s", directory) - if !f.watchedDirectories[directory] { + if !f.watchedDirectories[directory] { err = f.watcher.Add(directory) if err != nil { f.logger.Errorf("Could not create watch on directory %s : %s", directory, err) continue } + f.watchedDirectories[directory] = true } else { f.logger.Debugf("Watch for directory %s already exists", directory) } } + f.logger.Infof("Adding file %s to datasources", file) f.files = append(f.files, file) } } + return nil } @@ -183,34 +203,39 @@ func (f *FileSource) ConfigureByDSN(dsn string, labels map[string]string, logger args := strings.Split(dsn, "?") - if len(args[0]) == 0 { - return fmt.Errorf("empty file:// DSN") + if args[0] == "" { + return errors.New("empty file:// DSN") } - if len(args) == 2 && len(args[1]) != 0 { + if len(args) == 2 && args[1] != "" { params, err := url.ParseQuery(args[1]) if err != nil { return fmt.Errorf("could not parse file args: %w", err) } + for key, value := range params { switch key { case "log_level": if len(value) != 1 { return errors.New("expected zero or one value for 'log_level'") } + lvl, err := log.ParseLevel(value[0]) if err != nil { return fmt.Errorf("unknown level %s: %w", value[0], err) } + f.logger.Logger.SetLevel(lvl) case "max_buffer_size": if len(value) != 1 { return errors.New("expected zero or one value for 'max_buffer_size'") } + maxBufferSize, err := strconv.Atoi(value[0]) if err != nil { return fmt.Errorf("could not parse max_buffer_size %s: %w", value[0], err) } + f.config.MaxBufferSize = maxBufferSize default: return fmt.Errorf("unknown parameter %s", key) @@ -223,6 +248,7 @@ func (f *FileSource) ConfigureByDSN(dsn string, labels map[string]string, logger f.config.UniqueId = uuid f.logger.Debugf("Will try pattern %s", args[0]) + files, err := filepath.Glob(args[0]) if err != nil { return fmt.Errorf("glob failure: %w", err) @@ -240,6 +266,7 @@ func (f *FileSource) ConfigureByDSN(dsn string, labels map[string]string, logger f.logger.Infof("Adding file %s to filelist", file) 
file)
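// For reference, a DSN sketch exercising the parameters handled above (the
// path and values are illustrative):
//
//	file:///var/log/nginx/*.log?log_level=debug&max_buffer_size=2097152
//
// args[0] is globbed into f.files; the query string tunes the log verbosity
// and the scanner buffer size used when reading the files.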
f.files = append(f.files, file) } + return nil } @@ -255,22 +282,26 @@ func (f *FileSource) SupportedModes() []string { // OneShotAcquisition reads a set of file and returns when done func (f *FileSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { f.logger.Debug("In oneshot") + for _, file := range f.files { fi, err := os.Stat(file) if err != nil { return fmt.Errorf("could not stat file %s : %w", file, err) } + if fi.IsDir() { f.logger.Warnf("%s is a directory, ignoring it.", file) continue } + f.logger.Infof("reading %s at once", file) + err = f.readFile(file, out, t) if err != nil { return err } - } + return nil } @@ -290,32 +321,38 @@ func (f *FileSource) CanRun() error { return nil } -func (f *FileSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (f *FileSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { f.logger.Debug("Starting live acquisition") t.Go(func() error { return f.monitorNewFiles(out, t) }) + for _, file := range f.files { - //before opening the file, check if we need to specifically avoid it. (XXX) + // before opening the file, check if we need to specifically avoid it. (XXX) skip := false + for _, pattern := range f.exclude_regexps { if pattern.MatchString(file) { f.logger.Infof("file %s matches exclusion pattern %s, skipping", file, pattern.String()) + skip = true + break } } + if skip { continue } - //cf. https://github.com/crowdsecurity/crowdsec/issues/1168 - //do not rely on stat, reclose file immediately as it's opened by Tail + // cf. https://github.com/crowdsecurity/crowdsec/issues/1168 + // do not rely on stat, reclose file immediately as it's opened by Tail fd, err := os.Open(file) if err != nil { f.logger.Errorf("unable to read %s : %s", file, err) continue } + if err := fd.Close(); err != nil { f.logger.Errorf("unable to close %s : %s", file, err) continue @@ -325,22 +362,54 @@ func (f *FileSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) er if err != nil { return fmt.Errorf("could not stat file %s : %w", file, err) } + if fi.IsDir() { f.logger.Warnf("%s is a directory, ignoring it.", file) continue } - tail, err := tail.TailFile(file, tail.Config{ReOpen: true, Follow: true, Poll: f.config.PollWithoutInotify, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekEnd}, Logger: log.NewEntry(log.StandardLogger())}) + pollFile := false + if f.config.PollWithoutInotify != nil { + pollFile = *f.config.PollWithoutInotify + } else { + networkFS, fsType, err := types.IsNetworkFS(file) + if err != nil { + f.logger.Warningf("Could not get fs type for %s : %s", file, err) + } + + f.logger.Debugf("fs for %s is network: %t (%s)", file, networkFS, fsType) + + if networkFS { + f.logger.Warnf("Disabling inotify polling on %s as it is on a network share. You can manually set poll_without_inotify to true to make this message disappear, or to false to enforce inotify poll", file) + pollFile = true + } + } + + filink, err := os.Lstat(file) + if err != nil { + f.logger.Errorf("Could not lstat() new file %s, ignoring it : %s", file, err) + continue + } + + if filink.Mode()&os.ModeSymlink == os.ModeSymlink && !pollFile { + f.logger.Warnf("File %s is a symlink, but inotify polling is enabled. Crowdsec will not be able to detect rotation. 
Consider setting poll_without_inotify to true in your configuration", file) + } + + tail, err := tail.TailFile(file, tail.Config{ReOpen: true, Follow: true, Poll: pollFile, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekEnd}, Logger: log.NewEntry(log.StandardLogger())}) if err != nil { f.logger.Errorf("Could not start tailing file %s : %s", file, err) continue } + + f.tailMapMutex.Lock() f.tails[file] = true + f.tailMapMutex.Unlock() t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/file/live/fsnotify") return f.tailFile(out, t, tail) }) } + return nil } @@ -350,6 +419,7 @@ func (f *FileSource) Dump() interface{} { func (f *FileSource) monitorNewFiles(out chan types.Event, t *tomb.Tomb) error { logger := f.logger.WithField("goroutine", "inotify") + for { select { case event, ok := <-f.watcher.Events: @@ -357,84 +427,134 @@ func (f *FileSource) monitorNewFiles(out chan types.Event, t *tomb.Tomb) error { return nil } - if event.Op&fsnotify.Create == fsnotify.Create { - fi, err := os.Stat(event.Name) + if event.Op&fsnotify.Create != fsnotify.Create { + continue + } + + fi, err := os.Stat(event.Name) + if err != nil { + logger.Errorf("Could not stat() new file %s, ignoring it : %s", event.Name, err) + continue + } + + if fi.IsDir() { + continue + } + + logger.Debugf("Detected new file %s", event.Name) + + matched := false + + for _, pattern := range f.config.Filenames { + logger.Debugf("Matching %s with %s", pattern, event.Name) + + matched, err = filepath.Match(pattern, event.Name) if err != nil { - logger.Errorf("Could not stat() new file %s, ignoring it : %s", event.Name, err) - continue - } - if fi.IsDir() { - continue - } - logger.Debugf("Detected new file %s", event.Name) - matched := false - for _, pattern := range f.config.Filenames { - logger.Debugf("Matching %s with %s", pattern, event.Name) - matched, err = filepath.Match(pattern, event.Name) - if err != nil { - logger.Errorf("Could not match pattern : %s", err) - continue - } - if matched { - logger.Debugf("Matched %s with %s", pattern, event.Name) - break - } - } - if !matched { + logger.Errorf("Could not match pattern : %s", err) continue } - //before opening the file, check if we need to specifically avoid it. (XXX) - skip := false - for _, pattern := range f.exclude_regexps { - if pattern.MatchString(event.Name) { - f.logger.Infof("file %s matches exclusion pattern %s, skipping", event.Name, pattern.String()) - skip = true - break - } - } - if skip { - continue + if matched { + logger.Debugf("Matched %s with %s", pattern, event.Name) + break } + } + + if !matched { + continue + } + + // before opening the file, check if we need to specifically avoid it. (XXX) + skip := false + + for _, pattern := range f.exclude_regexps { + if pattern.MatchString(event.Name) { + f.logger.Infof("file %s matches exclusion pattern %s, skipping", event.Name, pattern.String()) + + skip = true - if f.tails[event.Name] { - //we already have a tail on it, do not start a new one - logger.Debugf("Already tailing file %s, not creating a new tail", event.Name) break } - //cf. https://github.com/crowdsecurity/crowdsec/issues/1168 - //do not rely on stat, reclose file immediately as it's opened by Tail - fd, err := os.Open(event.Name) + } + + if skip { + continue + } + + f.tailMapMutex.RLock() + if f.tails[event.Name] { + f.tailMapMutex.RUnlock() + // we already have a tail on it, do not start a new one + logger.Debugf("Already tailing file %s, not creating a new tail", event.Name) + + break + } + f.tailMapMutex.RUnlock() + // cf. 
https://github.com/crowdsecurity/crowdsec/issues/1168 + // do not rely on stat, reclose file immediately as it's opened by Tail + fd, err := os.Open(event.Name) + if err != nil { + f.logger.Errorf("unable to read %s : %s", event.Name, err) + continue + } + + if err = fd.Close(); err != nil { + f.logger.Errorf("unable to close %s : %s", event.Name, err) + continue + } + + pollFile := false + if f.config.PollWithoutInotify != nil { + pollFile = *f.config.PollWithoutInotify + } else { + networkFS, fsType, err := types.IsNetworkFS(event.Name) if err != nil { - f.logger.Errorf("unable to read %s : %s", event.Name, err) - continue - } - if err := fd.Close(); err != nil { - f.logger.Errorf("unable to close %s : %s", event.Name, err) - continue + f.logger.Warningf("Could not get fs type for %s : %s", event.Name, err) } - //Slightly different parameters for Location, as we want to read the first lines of the newly created file - tail, err := tail.TailFile(event.Name, tail.Config{ReOpen: true, Follow: true, Poll: f.config.PollWithoutInotify, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekStart}}) - if err != nil { - logger.Errorf("Could not start tailing file %s : %s", event.Name, err) - break + + f.logger.Debugf("fs for %s is network: %t (%s)", event.Name, networkFS, fsType) + + if networkFS { + pollFile = true } - f.tails[event.Name] = true - t.Go(func() error { - defer trace.CatchPanic("crowdsec/acquis/tailfile") - return f.tailFile(out, t, tail) - }) } + + filink, err := os.Lstat(event.Name) + if err != nil { + logger.Errorf("Could not lstat() new file %s, ignoring it : %s", event.Name, err) + continue + } + + if filink.Mode()&os.ModeSymlink == os.ModeSymlink && !pollFile { + logger.Warnf("File %s is a symlink, but inotify polling is enabled. Crowdsec will not be able to detect rotation. 
Consider setting poll_without_inotify to true in your configuration", event.Name) + } + + // Slightly different parameters for Location, as we want to read the first lines of the newly created file + tail, err := tail.TailFile(event.Name, tail.Config{ReOpen: true, Follow: true, Poll: pollFile, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekStart}}) + if err != nil { + logger.Errorf("Could not start tailing file %s : %s", event.Name, err) + break + } + + f.tailMapMutex.Lock() + f.tails[event.Name] = true + f.tailMapMutex.Unlock() + t.Go(func() error { + defer trace.CatchPanic("crowdsec/acquis/tailfile") + return f.tailFile(out, t, tail) + }) case err, ok := <-f.watcher.Errors: if !ok { return nil } + logger.Errorf("Error while monitoring folder: %s", err) case <-t.Dying(): err := f.watcher.Close() if err != nil { return fmt.Errorf("could not remove all inotify watches: %w", err) } + return nil } } @@ -443,41 +563,62 @@ func (f *FileSource) monitorNewFiles(out chan types.Event, t *tomb.Tomb) error { func (f *FileSource) tailFile(out chan types.Event, t *tomb.Tomb, tail *tail.Tail) error { logger := f.logger.WithField("tail", tail.Filename) logger.Debugf("-> Starting tail of %s", tail.Filename) + for { select { case <-t.Dying(): logger.Infof("File datasource %s stopping", tail.Filename) + if err := tail.Stop(); err != nil { f.logger.Errorf("error in stop : %s", err) return err } + + return nil + case <-tail.Dying(): // our tailer is dying + errMsg := fmt.Sprintf("file reader of %s died", tail.Filename) + + err := tail.Err() + if err != nil { + errMsg = fmt.Sprintf(errMsg+" : %s", err) + } + + logger.Warning(errMsg) + return nil - case <-tail.Dying(): //our tailer is dying - logger.Warningf("File reader of %s died", tail.Filename) - t.Kill(fmt.Errorf("dead reader for %s", tail.Filename)) - return fmt.Errorf("reader for %s is dead", tail.Filename) case line := <-tail.Lines: if line == nil { logger.Warningf("tail for %s is empty", tail.Filename) continue } + if line.Err != nil { logger.Warningf("fetch error : %v", line.Err) return line.Err } - if line.Text == "" { //skip empty lines + + if line.Text == "" { // skip empty lines continue } - linesRead.With(prometheus.Labels{"source": tail.Filename}).Inc() + + if f.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": tail.Filename}).Inc() + } + + src := tail.Filename + if f.metricsLevel == configuration.METRICS_AGGREGATE { + src = filepath.Base(tail.Filename) + } + l := types.Line{ Raw: trimLine(line.Text), Labels: f.config.Labels, Time: line.Time, - Src: tail.Filename, + Src: src, Process: true, Module: f.GetName(), } - //we're tailing, it must be real time logs + // we're tailing, it must be real time logs logger.Debugf("pushing %+v", l) expectMode := types.LIVE @@ -491,12 +632,14 @@ func (f *FileSource) tailFile(out chan types.Event, t *tomb.Tomb, tail *tail.Tai func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tomb) error { var scanner *bufio.Scanner + logger := f.logger.WithField("oneshot", filename) - fd, err := os.Open(filename) + fd, err := os.Open(filename) if err != nil { return fmt.Errorf("failed opening %s: %w", filename, err) } + defer fd.Close() if strings.HasSuffix(filename, ".gz") { @@ -505,17 +648,20 @@ func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tom logger.Errorf("Failed to read gz file: %s", err) return fmt.Errorf("failed to read gz %s: %w", filename, err) } + defer gz.Close() scanner = bufio.NewScanner(gz) - } else { scanner = 
bufio.NewScanner(fd) } + scanner.Split(bufio.ScanLines) + if f.config.MaxBufferSize > 0 { buf := make([]byte, 0, 64*1024) scanner.Buffer(buf, f.config.MaxBufferSize) } + for scanner.Scan() { select { case <-t.Dying(): @@ -525,6 +671,7 @@ func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tom if scanner.Text() == "" { continue } + l := types.Line{ Raw: scanner.Text(), Time: time.Now().UTC(), @@ -536,15 +683,19 @@ func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tom logger.Debugf("line %s", l.Raw) linesRead.With(prometheus.Labels{"source": filename}).Inc() - //we're reading logs at once, it must be time-machine buckets + // we're reading logs at once, it must be time-machine buckets out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} } } + if err := scanner.Err(); err != nil { logger.Errorf("Error while reading file: %s", err) t.Kill(err) + return err } + t.Kill(nil) + return nil } diff --git a/pkg/acquisition/modules/file/file_test.go b/pkg/acquisition/modules/file/file_test.go index ba33d9bf9e6..3db0042ba2f 100644 --- a/pkg/acquisition/modules/file/file_test.go +++ b/pkg/acquisition/modules/file/file_test.go @@ -1,6 +1,7 @@ package fileacquisition_test import ( + "context" "fmt" "os" "runtime" @@ -13,8 +14,9 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -48,15 +50,12 @@ exclude_regexps: ["as[a-$d"]`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "file", - }) + subLogger := log.WithField("type", "file") for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { f := fileacquisition.FileSource{} - err := f.Configure([]byte(tc.config), subLogger) + err := f.Configure([]byte(tc.config), subLogger, configuration.METRICS_NONE) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } @@ -90,12 +89,9 @@ func TestConfigureDSN(t *testing.T) { }, } - subLogger := log.WithFields(log.Fields{ - "type": "file", - }) + subLogger := log.WithField("type", "file") for _, tc := range tests { - tc := tc t.Run(tc.dsn, func(t *testing.T) { f := fileacquisition.FileSource{} err := f.ConfigureByDSN(tc.dsn, map[string]string{"type": "testtype"}, subLogger, "") @@ -205,14 +201,11 @@ filename: test_files/test_delete.log`, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { logger, hook := test.NewNullLogger() logger.SetLevel(tc.logLevel) - subLogger := logger.WithFields(log.Fields{ - "type": "file", - }) + subLogger := logger.WithField("type", "file") tomb := tomb.Tomb{} out := make(chan types.Event, 100) @@ -222,7 +215,7 @@ filename: test_files/test_delete.log`, tc.setup() } - err := f.Configure([]byte(tc.config), subLogger) + err := f.Configure([]byte(tc.config), subLogger, configuration.METRICS_NONE) cstest.RequireErrorContains(t, err, tc.expectedConfigErr) if tc.expectedConfigErr != "" { return @@ -251,6 +244,7 @@ filename: test_files/test_delete.log`, } func TestLiveAcquisition(t *testing.T) { + ctx := context.Background() permDeniedFile := "/etc/shadow" permDeniedError := "unable to read /etc/shadow : open /etc/shadow: permission denied" testPattern := "test_files/*.log" @@ -366,14 +360,11 @@ force_inotify: true`, testPattern), } for _, tc := range tests { - 
tc := tc t.Run(tc.name, func(t *testing.T) { logger, hook := test.NewNullLogger() logger.SetLevel(tc.logLevel) - subLogger := logger.WithFields(log.Fields{ - "type": "file", - }) + subLogger := logger.WithField("type", "file") tomb := tomb.Tomb{} out := make(chan types.Event) @@ -384,7 +375,7 @@ force_inotify: true`, testPattern), tc.setup() } - err := f.Configure([]byte(tc.config), subLogger) + err := f.Configure([]byte(tc.config), subLogger, configuration.METRICS_NONE) require.NoError(t, err) if tc.afterConfigure != nil { @@ -405,26 +396,24 @@ force_inotify: true`, testPattern), }() } - err = f.StreamingAcquisition(out, &tomb) + err = f.StreamingAcquisition(ctx, out, &tomb) cstest.RequireErrorContains(t, err, tc.expectedErr) if tc.expectedLines != 0 { fd, err := os.Create("test_files/stream.log") - if err != nil { - t.Fatalf("could not create test file : %s", err) - } + require.NoError(t, err, "could not create test file") - for i := 0; i < 5; i++ { + for i := range 5 { _, err = fmt.Fprintf(fd, "%d\n", i) if err != nil { - t.Fatalf("could not write test file : %s", err) os.Remove("test_files/stream.log") + t.Fatalf("could not write test file : %s", err) } } fd.Close() // we sleep to make sure we detect the new file - time.Sleep(1 * time.Second) + time.Sleep(3 * time.Second) os.Remove("test_files/stream.log") assert.Equal(t, tc.expectedLines, actualLines) } @@ -452,12 +441,10 @@ func TestExclusion(t *testing.T) { exclude_regexps: ["\\.gz$"]` logger, hook := test.NewNullLogger() // logger.SetLevel(ts.logLevel) - subLogger := logger.WithFields(log.Fields{ - "type": "file", - }) + subLogger := logger.WithField("type", "file") f := fileacquisition.FileSource{} - if err := f.Configure([]byte(config), subLogger); err != nil { + if err := f.Configure([]byte(config), subLogger, configuration.METRICS_NONE); err != nil { subLogger.Fatalf("unexpected error: %s", err) } diff --git a/pkg/acquisition/modules/file/tailline.go b/pkg/acquisition/modules/file/tailline.go index ac377b6636e..0de95e4a95c 100644 --- a/pkg/acquisition/modules/file/tailline.go +++ b/pkg/acquisition/modules/file/tailline.go @@ -1,4 +1,4 @@ -// +build linux freebsd netbsd openbsd solaris !windows +//go:build !windows package fileacquisition diff --git a/pkg/acquisition/modules/file/tailline_windows.go b/pkg/acquisition/modules/file/tailline_windows.go index 0c853c6e9c2..2c382b9b342 100644 --- a/pkg/acquisition/modules/file/tailline_windows.go +++ b/pkg/acquisition/modules/file/tailline_windows.go @@ -1,4 +1,4 @@ -// +build windows +//go:build windows package fileacquisition diff --git a/pkg/acquisition/modules/journalctl/journalctl.go b/pkg/acquisition/modules/journalctl/journalctl.go index b060ac364c4..b9cda54a472 100644 --- a/pkg/acquisition/modules/journalctl/journalctl.go +++ b/pkg/acquisition/modules/journalctl/journalctl.go @@ -3,6 +3,7 @@ package journalctlacquisition import ( "bufio" "context" + "errors" "fmt" "net/url" "os/exec" @@ -14,7 +15,7 @@ import ( "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -26,10 +27,11 @@ type JournalCtlConfiguration struct { } type JournalCtlSource struct { - config JournalCtlConfiguration - logger *log.Entry - src string - args []string + metricsLevel int + config JournalCtlConfiguration + logger *log.Entry + src string + args []string } const journalctlCmd string = "journalctl" @@ -97,7 +99,7 
@@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err if stdoutscanner == nil { cancel() cmd.Wait() - return fmt.Errorf("failed to create stdout scanner") + return errors.New("failed to create stdout scanner") } stderrScanner := bufio.NewScanner(stderr) @@ -105,13 +107,13 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err if stderrScanner == nil { cancel() cmd.Wait() - return fmt.Errorf("failed to create stderr scanner") + return errors.New("failed to create stderr scanner") } t.Go(func() error { return readLine(stdoutscanner, stdoutChan, errChan) }) t.Go(func() error { - //looks like journalctl closes stderr quite early, so ignore its status (but not its output) + // looks like journalctl closes stderr quite early, so ignore its status (but not its output) return readLine(stderrScanner, stderrChan, nil) }) @@ -120,7 +122,7 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err case <-t.Dying(): logger.Infof("journalctl datasource %s stopping", j.src) cancel() - cmd.Wait() //avoid zombie process + cmd.Wait() // avoid zombie process return nil case stdoutLine := <-stdoutChan: l := types.Line{} @@ -131,7 +133,9 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err l.Src = j.src l.Process = true l.Module = j.GetName() - linesRead.With(prometheus.Labels{"source": j.src}).Inc() + if j.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": j.src}).Inc() + } var evt types.Event if !j.config.UseTimeMachine { evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} @@ -186,7 +190,7 @@ func (j *JournalCtlSource) UnmarshalConfig(yamlConfig []byte) error { } if len(j.config.Filters) == 0 { - return fmt.Errorf("journalctl_filter is required") + return errors.New("journalctl_filter is required") } j.args = append(args, j.config.Filters...) j.src = fmt.Sprintf("journalctl-%s", strings.Join(j.config.Filters, ".")) @@ -194,8 +198,9 @@ func (j *JournalCtlSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (j *JournalCtlSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (j *JournalCtlSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { j.logger = logger + j.metricsLevel = MetricsLevel err := j.UnmarshalConfig(yamlConfig) if err != nil { @@ -212,14 +217,14 @@ func (j *JournalCtlSource) ConfigureByDSN(dsn string, labels map[string]string, j.config.Labels = labels j.config.UniqueId = uuid - //format for the DSN is : journalctl://filters=FILTER1&filters=FILTER2 + // format for the DSN is : journalctl://filters=FILTER1&filters=FILTER2 if !strings.HasPrefix(dsn, "journalctl://") { return fmt.Errorf("invalid DSN %s for journalctl source, must start with journalctl://", dsn) } qs := strings.TrimPrefix(dsn, "journalctl://") - if len(qs) == 0 { - return fmt.Errorf("empty journalctl:// DSN") + if qs == "" { + return errors.New("empty journalctl:// DSN") } params, err := url.ParseQuery(qs) @@ -232,7 +237,7 @@ func (j *JournalCtlSource) ConfigureByDSN(dsn string, labels map[string]string, j.config.Filters = append(j.config.Filters, value...) 
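// A DSN sketch combining the two keys recognized by this switch (the unit
// and priority filters are illustrative):
//
//	journalctl://filters=_SYSTEMD_UNIT=ssh.service&filters=PRIORITY=3&log_level=debug
//
// Every "filters" value is forwarded verbatim to the journalctl argument list.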
case "log_level": if len(value) != 1 { - return fmt.Errorf("expected zero or one value for 'log_level'") + return errors.New("expected zero or one value for 'log_level'") } lvl, err := log.ParseLevel(value[0]) if err != nil { @@ -262,21 +267,22 @@ func (j *JournalCtlSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb err := j.runJournalCtl(out, t) j.logger.Debug("Oneshot journalctl acquisition is done") return err - } -func (j *JournalCtlSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (j *JournalCtlSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/journalctl/streaming") return j.runJournalCtl(out, t) }) return nil } + func (j *JournalCtlSource) CanRun() error { - //TODO: add a more precise check on version or something ? + // TODO: add a more precise check on version or something ? _, err := exec.LookPath(journalctlCmd) return err } + func (j *JournalCtlSource) Dump() interface{} { return j } diff --git a/pkg/acquisition/modules/journalctl/journalctl_test.go b/pkg/acquisition/modules/journalctl/journalctl_test.go index 2c04c902820..c416bb5d23e 100644 --- a/pkg/acquisition/modules/journalctl/journalctl_test.go +++ b/pkg/acquisition/modules/journalctl/journalctl_test.go @@ -1,6 +1,7 @@ package journalctlacquisition import ( + "context" "os" "os/exec" "path/filepath" @@ -8,19 +9,22 @@ import ( "testing" "time" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" - - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func TestBadConfiguration(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string expectedErr string @@ -45,12 +49,11 @@ journalctl_filter: }, } - subLogger := log.WithFields(log.Fields{ - "type": "journalctl", - }) + subLogger := log.WithField("type", "journalctl") + for _, test := range tests { f := JournalCtlSource{} - err := f.Configure([]byte(test.config), subLogger) + err := f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) cstest.AssertErrorContains(t, err, test.expectedErr) } } @@ -59,6 +62,7 @@ func TestConfigureDSN(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { dsn string expectedErr string @@ -92,9 +96,9 @@ func TestConfigureDSN(t *testing.T) { expectedErr: "", }, } - subLogger := log.WithFields(log.Fields{ - "type": "journalctl", - }) + + subLogger := log.WithField("type", "journalctl") + for _, test := range tests { f := JournalCtlSource{} err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "") @@ -106,6 +110,7 @@ func TestOneShot(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string expectedErr string @@ -137,41 +142,45 @@ journalctl_filter: }, } for _, ts := range tests { - var logger *log.Logger - var subLogger *log.Entry - var hook *test.Hook + var ( + logger *log.Logger + subLogger *log.Entry + hook *test.Hook + ) + if ts.expectedOutput != "" { logger, hook = test.NewNullLogger() logger.SetLevel(ts.logLevel) - subLogger = logger.WithFields(log.Fields{ - "type": "journalctl", - }) + subLogger = 
logger.WithField("type", "journalctl") } else { - subLogger = log.WithFields(log.Fields{ - "type": "journalctl", - }) + subLogger = log.WithField("type", "journalctl") } + tomb := tomb.Tomb{} out := make(chan types.Event, 100) j := JournalCtlSource{} - err := j.Configure([]byte(ts.config), subLogger) + + err := j.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("Unexpected error : %s", err) } + err = j.OneShotAcquisition(out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) + if err != nil { continue } if ts.expectedLines != 0 { - assert.Equal(t, ts.expectedLines, len(out)) + assert.Len(t, out, ts.expectedLines) } if ts.expectedOutput != "" { if hook.LastEntry() == nil { t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput) } + assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput) hook.Reset() } @@ -179,9 +188,11 @@ journalctl_filter: } func TestStreaming(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string expectedErr string @@ -202,28 +213,31 @@ journalctl_filter: }, } for _, ts := range tests { - var logger *log.Logger - var subLogger *log.Entry - var hook *test.Hook + var ( + logger *log.Logger + subLogger *log.Entry + hook *test.Hook + ) + if ts.expectedOutput != "" { logger, hook = test.NewNullLogger() logger.SetLevel(ts.logLevel) - subLogger = logger.WithFields(log.Fields{ - "type": "journalctl", - }) + subLogger = logger.WithField("type", "journalctl") } else { - subLogger = log.WithFields(log.Fields{ - "type": "journalctl", - }) + subLogger = log.WithField("type", "journalctl") } + tomb := tomb.Tomb{} out := make(chan types.Event) j := JournalCtlSource{} - err := j.Configure([]byte(ts.config), subLogger) + + err := j.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("Unexpected error : %s", err) } + actualLines := 0 + if ts.expectedLines != 0 { go func() { READLOOP: @@ -238,8 +252,9 @@ journalctl_filter: }() } - err = j.StreamingAcquisition(out, &tomb) + err = j.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) + if err != nil { continue } @@ -248,16 +263,20 @@ journalctl_filter: time.Sleep(1 * time.Second) assert.Equal(t, ts.expectedLines, actualLines) } + tomb.Kill(nil) tomb.Wait() + output, _ := exec.Command("pgrep", "-x", "journalctl").CombinedOutput() if string(output) != "" { t.Fatalf("Found a journalctl process after killing the tomb !") } + if ts.expectedOutput != "" { if hook.LastEntry() == nil { t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput) } + assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput) hook.Reset() } @@ -270,5 +289,6 @@ func TestMain(m *testing.M) { fullPath := filepath.Join(currentDir, "test_files") os.Setenv("PATH", fullPath+":"+os.Getenv("PATH")) } + os.Exit(m.Run()) } diff --git a/pkg/acquisition/modules/kafka/kafka.go b/pkg/acquisition/modules/kafka/kafka.go index dba8daf7592..9fd5fc2a035 100644 --- a/pkg/acquisition/modules/kafka/kafka.go +++ b/pkg/acquisition/modules/kafka/kafka.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "io" "os" @@ -16,15 +17,13 @@ import ( "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/types" ) -var ( - 
dataSourceName = "kafka" -) +var dataSourceName = "kafka" var linesRead = prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -37,6 +36,7 @@ type KafkaConfiguration struct { Brokers []string `yaml:"brokers"` Topic string `yaml:"topic"` GroupID string `yaml:"group_id"` + Partition int `yaml:"partition"` Timeout string `yaml:"timeout"` TLS *TLSConfig `yaml:"tls"` configuration.DataSourceCommonCfg `yaml:",inline"` @@ -50,9 +50,10 @@ type TLSConfig struct { } type KafkaSource struct { - Config KafkaConfiguration - logger *log.Entry - Reader *kafka.Reader + metricsLevel int + Config KafkaConfiguration + logger *log.Entry + Reader *kafka.Reader } func (k *KafkaSource) GetUuid() string { @@ -79,11 +80,16 @@ func (k *KafkaSource) UnmarshalConfig(yamlConfig []byte) error { k.Config.Mode = configuration.TAIL_MODE } + k.logger.Debugf("successfully parsed kafka configuration : %+v", k.Config) + return err } -func (k *KafkaSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (k *KafkaSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { k.logger = logger + k.metricsLevel = MetricsLevel + + k.logger.Debugf("start configuring %s source", dataSourceName) err := k.UnmarshalConfig(yamlConfig) if err != nil { @@ -95,7 +101,7 @@ func (k *KafkaSource) Configure(yamlConfig []byte, logger *log.Entry) error { return fmt.Errorf("cannot create %s dialer: %w", dataSourceName, err) } - k.Reader, err = k.Config.NewReader(dialer) + k.Reader, err = k.Config.NewReader(dialer, k.logger) if err != nil { return fmt.Errorf("cannote create %s reader: %w", dataSourceName, err) } @@ -104,6 +110,8 @@ func (k *KafkaSource) Configure(yamlConfig []byte, logger *log.Entry) error { return fmt.Errorf("cannot create %s reader", dataSourceName) } + k.logger.Debugf("successfully configured %s source", dataSourceName) + return nil } @@ -143,13 +151,16 @@ func (k *KafkaSource) ReadMessage(out chan types.Event) error { // Start processing from latest Offset k.Reader.SetOffsetAt(context.Background(), time.Now()) for { + k.logger.Tracef("reading message from topic '%s'", k.Config.Topic) m, err := k.Reader.ReadMessage(context.Background()) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil } k.logger.Errorln(fmt.Errorf("while reading %s message: %w", dataSourceName, err)) + continue } + k.logger.Tracef("got message: %s", string(m.Value)) l := types.Line{ Raw: string(m.Value), Labels: k.Config.Labels, @@ -158,7 +169,10 @@ func (k *KafkaSource) ReadMessage(out chan types.Event) error { Process: true, Module: k.GetName(), } - linesRead.With(prometheus.Labels{"topic": k.Config.Topic}).Inc() + k.logger.Tracef("line with message read from topic '%s': %+v", k.Config.Topic, l) + if k.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"topic": k.Config.Topic}).Inc() + } var evt types.Event if !k.Config.UseTimeMachine { @@ -171,6 +185,7 @@ func (k *KafkaSource) ReadMessage(out chan types.Event) error { } func (k *KafkaSource) RunReader(out chan types.Event, t *tomb.Tomb) error { + k.logger.Debugf("starting %s datasource reader goroutine with configuration %+v", dataSourceName, k.Config) t.Go(func() error { return k.ReadMessage(out) }) @@ -187,8 +202,8 @@ func (k *KafkaSource) RunReader(out chan types.Event, t *tomb.Tomb) error { } } -func (k *KafkaSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { - k.logger.Infof("start reader on topic '%s'", k.Config.Topic) +func (k *KafkaSource) StreamingAcquisition(ctx context.Context, out chan 
types.Event, t *tomb.Tomb) error { + k.logger.Infof("start reader on brokers '%+v' with topic '%s'", k.Config.Brokers, k.Config.Topic) t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/kafka/live") @@ -223,7 +238,6 @@ func (kc *KafkaConfiguration) NewTLSConfig() (*tls.Config, error) { caCertPool.AppendCertsFromPEM(caCert) tlsConfig.RootCAs = caCertPool - tlsConfig.BuildNameToCertificate() return &tlsConfig, err } @@ -253,14 +267,23 @@ func (kc *KafkaConfiguration) NewDialer() (*kafka.Dialer, error) { return dialer, nil } -func (kc *KafkaConfiguration) NewReader(dialer *kafka.Dialer) (*kafka.Reader, error) { +func (kc *KafkaConfiguration) NewReader(dialer *kafka.Dialer, logger *log.Entry) (*kafka.Reader, error) { rConf := kafka.ReaderConfig{ - Brokers: kc.Brokers, - Topic: kc.Topic, - Dialer: dialer, + Brokers: kc.Brokers, + Topic: kc.Topic, + Dialer: dialer, + Logger: kafka.LoggerFunc(logger.Debugf), + ErrorLogger: kafka.LoggerFunc(logger.Errorf), + } + if kc.GroupID != "" && kc.Partition != 0 { + return &kafka.Reader{}, errors.New("cannot specify both group_id and partition") } if kc.GroupID != "" { rConf.GroupID = kc.GroupID + } else if kc.Partition != 0 { + rConf.Partition = kc.Partition + } else { + logger.Warnf("no group_id specified, crowdsec will only read from the 1st partition of the topic") } if err := rConf.Validate(); err != nil { return &kafka.Reader{}, fmt.Errorf("while validating reader configuration: %w", err) diff --git a/pkg/acquisition/modules/kafka/kafka_test.go b/pkg/acquisition/modules/kafka/kafka_test.go index 950d33a62e2..d796166a6ca 100644 --- a/pkg/acquisition/modules/kafka/kafka_test.go +++ b/pkg/acquisition/modules/kafka/kafka_test.go @@ -13,8 +13,9 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -58,22 +59,30 @@ brokers: topic: crowdsec`, expectedErr: "", }, + { + config: ` +source: kafka +brokers: + - localhost:9092 +topic: crowdsec +partition: 1 +group_id: crowdsec`, + expectedErr: "cannote create kafka reader: cannot specify both group_id and partition", + }, } - subLogger := log.WithFields(log.Fields{ - "type": "kafka", - }) + subLogger := log.WithField("type", "kafka") + for _, test := range tests { k := KafkaSource{} - err := k.Configure([]byte(test.config), subLogger) + err := k.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) cstest.AssertErrorContains(t, err, test.expectedErr) } } -func writeToKafka(w *kafka.Writer, logs []string) { - +func writeToKafka(ctx context.Context, w *kafka.Writer, logs []string) { for idx, log := range logs { - err := w.WriteMessages(context.Background(), kafka.Message{ + err := w.WriteMessages(ctx, kafka.Message{ Key: []byte(strconv.Itoa(idx)), // create an arbitrary message payload for the value Value: []byte(log), @@ -95,7 +104,9 @@ func createTopic(topic string, broker string) { if err != nil { panic(err) } + var controllerConn *kafka.Conn + controllerConn, err = kafka.Dial("tcp", net.JoinHostPort(controller.Host, strconv.Itoa(controller.Port))) if err != nil { panic(err) @@ -117,9 +128,11 @@ func createTopic(topic string, broker string) { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { name string logs []string @@ -137,9 +150,7 @@ func 
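The new partition option and group_id are mutually exclusive because a segmentio/kafka-go Reader is either a member of a consumer group or pinned to a single partition, never both; the two valid shapes, sketched with illustrative broker and topic names:

// consumer-group mode: the group coordinator balances partitions and tracks offsets
groupReader := kafka.NewReader(kafka.ReaderConfig{
	Brokers: []string{"localhost:9092"},
	Topic:   "crowdsec",
	GroupID: "crowdsec",
})
defer groupReader.Close()

// single-partition mode: read exactly one partition of the topic
partitionReader := kafka.NewReader(kafka.ReaderConfig{
	Brokers:   []string{"localhost:9092"},
	Topic:     "crowdsec",
	Partition: 1,
})
defer partitionReader.Close()

Rejecting the combination up front in NewReader gives a clearer error message than whatever kafka-go's own validation would surface later.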
TestStreamingAcquisition(t *testing.T) { }, } - subLogger := log.WithFields(log.Fields{ - "type": "kafka", - }) + subLogger := log.WithField("type", "kafka") createTopic("crowdsecplaintext", "localhost:9092") @@ -148,28 +159,30 @@ func TestStreamingAcquisition(t *testing.T) { Topic: "crowdsecplaintext", }) if w == nil { - log.Fatalf("Unable to setup a kafka producer") + t.Fatal("Unable to setup a kafka producer") } for _, ts := range tests { - ts := ts t.Run(ts.name, func(t *testing.T) { k := KafkaSource{} + err := k.Configure([]byte(` source: kafka brokers: - localhost:9092 -topic: crowdsecplaintext`), subLogger) +topic: crowdsecplaintext`), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("could not configure kafka source : %s", err) } + tomb := tomb.Tomb{} out := make(chan types.Event) - err = k.StreamingAcquisition(out, &tomb) + err = k.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) actualLines := 0 - go writeToKafka(w, ts.logs) + + go writeToKafka(ctx, w, ts.logs) READLOOP: for { select { @@ -184,13 +197,14 @@ topic: crowdsecplaintext`), subLogger) tomb.Wait() }) } - } func TestStreamingAcquisitionWithSSL(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { name string logs []string @@ -207,9 +221,7 @@ func TestStreamingAcquisitionWithSSL(t *testing.T) { }, } - subLogger := log.WithFields(log.Fields{ - "type": "kafka", - }) + subLogger := log.WithField("type", "kafka") createTopic("crowdsecssl", "localhost:9092") @@ -218,13 +230,13 @@ func TestStreamingAcquisitionWithSSL(t *testing.T) { Topic: "crowdsecssl", }) if w2 == nil { - log.Fatalf("Unable to setup a kafka producer") + t.Fatal("Unable to setup a kafka producer") } for _, ts := range tests { - ts := ts t.Run(ts.name, func(t *testing.T) { k := KafkaSource{} + err := k.Configure([]byte(` source: kafka brokers: @@ -235,17 +247,19 @@ tls: client_cert: ./testdata/kafkaClient.certificate.pem client_key: ./testdata/kafkaClient.key ca_cert: ./testdata/snakeoil-ca-1.crt - `), subLogger) + `), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("could not configure kafka source : %s", err) } + tomb := tomb.Tomb{} out := make(chan types.Event) - err = k.StreamingAcquisition(out, &tomb) + err = k.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) actualLines := 0 - go writeToKafka(w2, ts.logs) + + go writeToKafka(ctx, w2, ts.logs) READLOOP: for { select { @@ -260,5 +274,4 @@ tls: tomb.Wait() }) } - } diff --git a/pkg/acquisition/modules/kinesis/kinesis.go b/pkg/acquisition/modules/kinesis/kinesis.go index cc263da4f7d..ca3a847dbfb 100644 --- a/pkg/acquisition/modules/kinesis/kinesis.go +++ b/pkg/acquisition/modules/kinesis/kinesis.go @@ -3,7 +3,9 @@ package kinesisacquisition import ( "bytes" "compress/gzip" + "context" "encoding/json" + "errors" "fmt" "io" "strings" @@ -18,7 +20,7 @@ import ( "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -28,7 +30,7 @@ type KinesisConfiguration struct { configuration.DataSourceCommonCfg `yaml:",inline"` StreamName string `yaml:"stream_name"` StreamARN string `yaml:"stream_arn"` - UseEnhancedFanOut bool `yaml:"use_enhanced_fanout"` //Use RegisterStreamConsumer and SubscribeToShard instead of GetRecords + UseEnhancedFanOut 
bool `yaml:"use_enhanced_fanout"` // Use RegisterStreamConsumer and SubscribeToShard instead of GetRecords AwsProfile *string `yaml:"aws_profile"` AwsRegion string `yaml:"aws_region"` AwsEndpoint string `yaml:"aws_endpoint"` @@ -38,6 +40,7 @@ type KinesisConfiguration struct { } type KinesisSource struct { + metricsLevel int Config KinesisConfiguration logger *log.Entry kClient *kinesis.Kinesis @@ -94,7 +97,7 @@ func (k *KinesisSource) newClient() error { } if sess == nil { - return fmt.Errorf("failed to create aws session") + return errors.New("failed to create aws session") } config := aws.NewConfig() if k.Config.AwsRegion != "" { @@ -105,15 +108,15 @@ func (k *KinesisSource) newClient() error { } k.kClient = kinesis.New(sess, config) if k.kClient == nil { - return fmt.Errorf("failed to create kinesis client") + return errors.New("failed to create kinesis client") } return nil } func (k *KinesisSource) GetMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, linesReadShards} - } + func (k *KinesisSource) GetAggregMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, linesReadShards} } @@ -123,7 +126,7 @@ func (k *KinesisSource) UnmarshalConfig(yamlConfig []byte) error { err := yaml.UnmarshalStrict(yamlConfig, &k.Config) if err != nil { - return fmt.Errorf("Cannot parse kinesis datasource configuration: %w", err) + return fmt.Errorf("cannot parse kinesis datasource configuration: %w", err) } if k.Config.Mode == "" { @@ -131,16 +134,16 @@ func (k *KinesisSource) UnmarshalConfig(yamlConfig []byte) error { } if k.Config.StreamName == "" && !k.Config.UseEnhancedFanOut { - return fmt.Errorf("stream_name is mandatory when use_enhanced_fanout is false") + return errors.New("stream_name is mandatory when use_enhanced_fanout is false") } if k.Config.StreamARN == "" && k.Config.UseEnhancedFanOut { - return fmt.Errorf("stream_arn is mandatory when use_enhanced_fanout is true") + return errors.New("stream_arn is mandatory when use_enhanced_fanout is true") } if k.Config.ConsumerName == "" && k.Config.UseEnhancedFanOut { - return fmt.Errorf("consumer_name is mandatory when use_enhanced_fanout is true") + return errors.New("consumer_name is mandatory when use_enhanced_fanout is true") } if k.Config.StreamARN != "" && k.Config.StreamName != "" { - return fmt.Errorf("stream_arn and stream_name are mutually exclusive") + return errors.New("stream_arn and stream_name are mutually exclusive") } if k.Config.MaxRetries <= 0 { k.Config.MaxRetries = 10 @@ -149,8 +152,9 @@ func (k *KinesisSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (k *KinesisSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (k *KinesisSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { k.logger = logger + k.metricsLevel = MetricsLevel err := k.UnmarshalConfig(yamlConfig) if err != nil { @@ -167,7 +171,7 @@ func (k *KinesisSource) Configure(yamlConfig []byte, logger *log.Entry) error { } func (k *KinesisSource) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { - return fmt.Errorf("kinesis datasource does not support command-line acquisition") + return errors.New("kinesis datasource does not support command-line acquisition") } func (k *KinesisSource) GetMode() string { @@ -179,13 +183,12 @@ func (k *KinesisSource) GetName() string { } func (k *KinesisSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { - return fmt.Errorf("kinesis datasource does not support one-shot acquisition") + return 
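Spelled out, the validation above admits exactly one addressing style per mode; a configuration that passes with enhanced fan-out enabled (the ARN matches the test fixture used later, the consumer name is illustrative):

cfg := KinesisConfiguration{
	UseEnhancedFanOut: true, // RegisterStreamConsumer + SubscribeToShard
	StreamARN:         "arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream",
	ConsumerName:      "crowdsec-consumer", // illustrative
	MaxRetries:        10,                  // also the default applied when <= 0
}

With classic GetRecords polling (UseEnhancedFanOut: false), only StreamName may be set instead.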
errors.New("kinesis datasource does not support one-shot acquisition") } func (k *KinesisSource) decodeFromSubscription(record []byte) ([]CloudwatchSubscriptionLogEvent, error) { b := bytes.NewBuffer(record) r, err := gzip.NewReader(b) - if err != nil { k.logger.Error(err) return nil, err @@ -206,7 +209,7 @@ func (k *KinesisSource) decodeFromSubscription(record []byte) ([]CloudwatchSubsc func (k *KinesisSource) WaitForConsumerDeregistration(consumerName string, streamARN string) error { maxTries := k.Config.MaxRetries - for i := 0; i < maxTries; i++ { + for i := range maxTries { _, err := k.kClient.DescribeStreamConsumer(&kinesis.DescribeStreamConsumerInput{ ConsumerName: aws.String(consumerName), StreamARN: aws.String(streamARN), @@ -247,7 +250,7 @@ func (k *KinesisSource) DeregisterConsumer() error { func (k *KinesisSource) WaitForConsumerRegistration(consumerARN string) error { maxTries := k.Config.MaxRetries - for i := 0; i < maxTries; i++ { + for i := range maxTries { describeOutput, err := k.kClient.DescribeStreamConsumer(&kinesis.DescribeStreamConsumerInput{ ConsumerARN: aws.String(consumerARN), }) @@ -283,17 +286,21 @@ func (k *KinesisSource) RegisterConsumer() (*kinesis.RegisterStreamConsumerOutpu func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan types.Event, logger *log.Entry, shardId string) { for _, record := range records { if k.Config.StreamARN != "" { - linesReadShards.With(prometheus.Labels{"stream": k.Config.StreamARN, "shard": shardId}).Inc() - linesRead.With(prometheus.Labels{"stream": k.Config.StreamARN}).Inc() + if k.metricsLevel != configuration.METRICS_NONE { + linesReadShards.With(prometheus.Labels{"stream": k.Config.StreamARN, "shard": shardId}).Inc() + linesRead.With(prometheus.Labels{"stream": k.Config.StreamARN}).Inc() + } } else { - linesReadShards.With(prometheus.Labels{"stream": k.Config.StreamName, "shard": shardId}).Inc() - linesRead.With(prometheus.Labels{"stream": k.Config.StreamName}).Inc() + if k.metricsLevel != configuration.METRICS_NONE { + linesReadShards.With(prometheus.Labels{"stream": k.Config.StreamName, "shard": shardId}).Inc() + linesRead.With(prometheus.Labels{"stream": k.Config.StreamName}).Inc() + } } var data []CloudwatchSubscriptionLogEvent var err error if k.Config.FromSubscription { - //The AWS docs says that the data is base64 encoded - //but apparently GetRecords decodes it for us ? + // The AWS docs says that the data is base64 encoded + // but apparently GetRecords decodes it for us ? data, err = k.decodeFromSubscription(record.Data) if err != nil { logger.Errorf("Cannot decode data: %s", err) @@ -327,10 +334,10 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan } func (k *KinesisSource) ReadFromSubscription(reader kinesis.SubscribeToShardEventStreamReader, out chan types.Event, shardId string, streamName string) error { - logger := k.logger.WithFields(log.Fields{"shard_id": shardId}) - //ghetto sync, kinesis allows to subscribe to a closed shard, which will make the goroutine exit immediately - //and we won't be able to start a new one if this is the first one started by the tomb - //TODO: look into parent shards to see if a shard is closed before starting to read it ? 
+ logger := k.logger.WithField("shard_id", shardId) + // ghetto sync, kinesis allows to subscribe to a closed shard, which will make the goroutine exit immediately + // and we won't be able to start a new one if this is the first one started by the tomb + // TODO: look into parent shards to see if a shard is closed before starting to read it ? time.Sleep(time.Second) for { select { @@ -390,7 +397,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { return fmt.Errorf("resource part of stream ARN %s does not start with stream/", k.Config.StreamARN) } - k.logger = k.logger.WithFields(log.Fields{"stream": parsedARN.Resource[7:]}) + k.logger = k.logger.WithField("stream", parsedARN.Resource[7:]) k.logger.Info("starting kinesis acquisition with enhanced fan-out") err = k.DeregisterConsumer() if err != nil { @@ -413,7 +420,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { case <-t.Dying(): k.logger.Infof("Kinesis source is dying") k.shardReaderTomb.Kill(nil) - _ = k.shardReaderTomb.Wait() //we don't care about the error as we kill the tomb ourselves + _ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves err = k.DeregisterConsumer() if err != nil { return fmt.Errorf("cannot deregister consumer: %w", err) @@ -424,7 +431,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { if k.shardReaderTomb.Err() != nil { return k.shardReaderTomb.Err() } - //All goroutines have exited without error, so a resharding event, start again + // All goroutines have exited without error, so a resharding event, start again k.logger.Debugf("All reader goroutines have exited, resharding event or periodic resubscribe") continue } @@ -432,17 +439,19 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { } func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) error { - logger := k.logger.WithFields(log.Fields{"shard": shardId}) + logger := k.logger.WithField("shard", shardId) logger.Debugf("Starting to read shard") - sharIt, err := k.kClient.GetShardIterator(&kinesis.GetShardIteratorInput{ShardId: aws.String(shardId), + sharIt, err := k.kClient.GetShardIterator(&kinesis.GetShardIteratorInput{ + ShardId: aws.String(shardId), StreamName: &k.Config.StreamName, - ShardIteratorType: aws.String(kinesis.ShardIteratorTypeLatest)}) + ShardIteratorType: aws.String(kinesis.ShardIteratorTypeLatest), + }) if err != nil { logger.Errorf("Cannot get shard iterator: %s", err) return fmt.Errorf("cannot get shard iterator: %w", err) } it := sharIt.ShardIterator - //AWS recommends to wait for a second between calls to GetRecords for a given shard + // AWS recommends to wait for a second between calls to GetRecords for a given shard ticker := time.NewTicker(time.Second) for { select { @@ -453,7 +462,7 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) erro switch err.(type) { case *kinesis.ProvisionedThroughputExceededException: logger.Warn("Provisioned throughput exceeded") - //TODO: implement exponential backoff + // TODO: implement exponential backoff continue case *kinesis.ExpiredIteratorException: logger.Warn("Expired iterator") @@ -478,7 +487,7 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) erro } func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error { - k.logger = k.logger.WithFields(log.Fields{"stream": k.Config.StreamName}) + k.logger = k.logger.WithField("stream", 
k.Config.StreamName) k.logger.Info("starting kinesis acquisition from shards") for { shards, err := k.kClient.ListShards(&kinesis.ListShardsInput{ @@ -499,7 +508,7 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error case <-t.Dying(): k.logger.Info("kinesis source is dying") k.shardReaderTomb.Kill(nil) - _ = k.shardReaderTomb.Wait() //we don't care about the error as we kill the tomb ourselves + _ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves return nil case <-k.shardReaderTomb.Dying(): reason := k.shardReaderTomb.Err() @@ -513,14 +522,13 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error } } -func (k *KinesisSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (k *KinesisSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/kinesis/streaming") if k.Config.UseEnhancedFanOut { return k.EnhancedRead(out, t) - } else { - return k.ReadFromStream(out, t) } + return k.ReadFromStream(out, t) }) return nil } diff --git a/pkg/acquisition/modules/kinesis/kinesis_test.go b/pkg/acquisition/modules/kinesis/kinesis_test.go index 25941e20d5c..027cbde9240 100644 --- a/pkg/acquisition/modules/kinesis/kinesis_test.go +++ b/pkg/acquisition/modules/kinesis/kinesis_test.go @@ -3,6 +3,7 @@ package kinesisacquisition import ( "bytes" "compress/gzip" + "context" "encoding/json" "fmt" "net" @@ -12,15 +13,17 @@ import ( "testing" "time" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kinesis" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func getLocalStackEndpoint() (string, error) { @@ -29,7 +32,7 @@ func getLocalStackEndpoint() (string, error) { v = strings.TrimPrefix(v, "http://") _, err := net.Dial("tcp", v) if err != nil { - return "", fmt.Errorf("while dialing %s : %s : aws endpoint isn't available", v, err) + return "", fmt.Errorf("while dialing %s: %w: aws endpoint isn't available", v, err) } } return endpoint, nil @@ -58,8 +61,8 @@ func GenSubObject(i int) []byte { gz := gzip.NewWriter(&b) gz.Write(body) gz.Close() - //AWS actually base64 encodes the data, but it looks like kinesis automatically decodes it at some point - //localstack does not do it, so let's just write a raw gzipped stream + // AWS actually base64 encodes the data, but it looks like kinesis automatically decodes it at some point + // localstack does not do it, so let's just write a raw gzipped stream return b.Bytes() } @@ -70,7 +73,7 @@ func WriteToStream(streamName string, count int, shards int, sub bool) { } sess := session.Must(session.NewSession()) kinesisClient := kinesis.New(sess, aws.NewConfig().WithEndpoint(endpoint).WithRegion("us-east-1")) - for i := 0; i < count; i++ { + for i := range count { partition := "partition" if shards != 1 { partition = fmt.Sprintf("partition-%d", i%shards) @@ -97,10 +100,10 @@ func TestMain(m *testing.M) { os.Setenv("AWS_ACCESS_KEY_ID", "foobar") os.Setenv("AWS_SECRET_ACCESS_KEY", "foobar") - //delete_streams() - //create_streams() + // delete_streams() + // create_streams() code := m.Run() - 
//delete_streams() + // delete_streams() os.Exit(code) } @@ -138,17 +141,16 @@ stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "kinesis", - }) + subLogger := log.WithField("type", "kinesis") for _, test := range tests { f := KinesisSource{} - err := f.Configure([]byte(test.config), subLogger) + err := f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) cstest.AssertErrorContains(t, err, test.expectedErr) } } func TestReadFromStream(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -170,22 +172,20 @@ stream_name: stream-1-shard`, for _, test := range tests { f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) - err := f.Configure([]byte(config), log.WithFields(log.Fields{ - "type": "kinesis", - })) + err := f.Configure([]byte(config), log.WithField("type", "kinesis"), configuration.METRICS_NONE) if err != nil { t.Fatalf("Error configuring source: %s", err) } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, false) - for i := 0; i < test.count; i++ { + for i := range test.count { e := <-out assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw) } @@ -195,6 +195,7 @@ stream_name: stream-1-shard`, } func TestReadFromMultipleShards(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -216,23 +217,21 @@ stream_name: stream-2-shards`, for _, test := range tests { f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) - err := f.Configure([]byte(config), log.WithFields(log.Fields{ - "type": "kinesis", - })) + err := f.Configure([]byte(config), log.WithField("type", "kinesis"), configuration.METRICS_NONE) if err != nil { t.Fatalf("Error configuring source: %s", err) } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, false) c := 0 - for i := 0; i < test.count; i++ { + for range test.count { <-out c += 1 } @@ -243,6 +242,7 @@ stream_name: stream-2-shards`, } func TestFromSubscription(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -265,22 +265,20 @@ from_subscription: true`, for _, test := range tests { f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) - err := f.Configure([]byte(config), log.WithFields(log.Fields{ - "type": "kinesis", - })) + err := f.Configure([]byte(config), log.WithField("type", "kinesis"), configuration.METRICS_NONE) if err != nil { t.Fatalf("Error configuring source: %s", err) } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the 
datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, true) - for i := 0; i < test.count; i++ { + for i := range test.count { e := <-out assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw) } @@ -311,9 +309,7 @@ use_enhanced_fanout: true`, for _, test := range tests { f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) - err := f.Configure([]byte(config), log.WithFields(log.Fields{ - "type": "kinesis", - })) + err := f.Configure([]byte(config), log.WithField("type", "kinesis")) if err != nil { t.Fatalf("Error configuring source: %s", err) } diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go index 24354738114..f979b044dcc 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go @@ -3,6 +3,7 @@ package kubernetesauditacquisition import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -14,7 +15,7 @@ import ( "gopkg.in/yaml.v2" "k8s.io/apiserver/pkg/apis/audit" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -28,12 +29,13 @@ type KubernetesAuditConfiguration struct { } type KubernetesAuditSource struct { - config KubernetesAuditConfiguration - logger *log.Entry - mux *http.ServeMux - server *http.Server - outChan chan types.Event - addr string + metricsLevel int + config KubernetesAuditConfiguration + logger *log.Entry + mux *http.ServeMux + server *http.Server + outChan chan types.Event + addr string } var eventCount = prometheus.NewCounterVec( @@ -72,15 +74,15 @@ func (ka *KubernetesAuditSource) UnmarshalConfig(yamlConfig []byte) error { ka.config = k8sConfig if ka.config.ListenAddr == "" { - return fmt.Errorf("listen_addr cannot be empty") + return errors.New("listen_addr cannot be empty") } if ka.config.ListenPort == 0 { - return fmt.Errorf("listen_port cannot be empty") + return errors.New("listen_port cannot be empty") } if ka.config.WebhookPath == "" { - return fmt.Errorf("webhook_path cannot be empty") + return errors.New("webhook_path cannot be empty") } if ka.config.WebhookPath[0] != '/' { @@ -93,8 +95,9 @@ func (ka *KubernetesAuditSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (ka *KubernetesAuditSource) Configure(config []byte, logger *log.Entry) error { +func (ka *KubernetesAuditSource) Configure(config []byte, logger *log.Entry, MetricsLevel int) error { ka.logger = logger + ka.metricsLevel = MetricsLevel err := ka.UnmarshalConfig(config) if err != nil { @@ -117,7 +120,7 @@ func (ka *KubernetesAuditSource) Configure(config []byte, logger *log.Entry) err } func (ka *KubernetesAuditSource) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { - return fmt.Errorf("k8s-audit datasource does not support command-line acquisition") + return errors.New("k8s-audit datasource does not support command-line acquisition") } func (ka *KubernetesAuditSource) GetMode() string { @@ -129,10 +132,10 @@ func (ka *KubernetesAuditSource) GetName() string { } func (ka *KubernetesAuditSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { - return fmt.Errorf("k8s-audit datasource does not support one-shot acquisition") + return errors.New("k8s-audit datasource does not support one-shot acquisition") } -func (ka 
*KubernetesAuditSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (ka *KubernetesAuditSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { ka.outChan = out t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/k8s-audit/live") @@ -161,7 +164,9 @@ func (ka *KubernetesAuditSource) Dump() interface{} { } func (ka *KubernetesAuditSource) webhookHandler(w http.ResponseWriter, r *http.Request) { - requestCount.WithLabelValues(ka.addr).Inc() + if ka.metricsLevel != configuration.METRICS_NONE { + requestCount.WithLabelValues(ka.addr).Inc() + } if r.Method != http.MethodPost { w.WriteHeader(http.StatusMethodNotAllowed) return @@ -185,10 +190,12 @@ func (ka *KubernetesAuditSource) webhookHandler(w http.ResponseWriter, r *http.R remoteIP := strings.Split(r.RemoteAddr, ":")[0] for _, auditEvent := range auditEvents.Items { - eventCount.WithLabelValues(ka.addr).Inc() + if ka.metricsLevel != configuration.METRICS_NONE { + eventCount.WithLabelValues(ka.addr).Inc() + } bytesEvent, err := json.Marshal(auditEvent) if err != nil { - ka.logger.Errorf("Error marshaling audit event: %s", err) + ka.logger.Errorf("Error serializing audit event: %s", err) continue } ka.logger.Tracef("Got audit event: %s", string(bytesEvent)) diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go index 799868dc811..a086a756e4a 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go @@ -1,15 +1,19 @@ package kubernetesauditacquisition import ( + "context" "net/http/httptest" "strings" "testing" "time" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func TestBadConfiguration(t *testing.T) { @@ -44,12 +48,12 @@ listen_addr: 0.0.0.0`, err := f.UnmarshalConfig([]byte(test.config)) assert.Contains(t, err.Error(), test.expectedErr) - }) } } func TestInvalidConfig(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -65,9 +69,7 @@ webhook_path: /k8s-audit`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "k8s-audit", - }) + subLogger := log.WithField("type", "k8s-audit") for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -78,27 +80,27 @@ webhook_path: /k8s-audit`, err := f.UnmarshalConfig([]byte(test.config)) - assert.NoError(t, err) + require.NoError(t, err) - err = f.Configure([]byte(test.config), subLogger) + err = f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) - assert.NoError(t, err) - f.StreamingAcquisition(out, tb) + require.NoError(t, err) + f.StreamingAcquisition(ctx, out, tb) time.Sleep(1 * time.Second) tb.Kill(nil) err = tb.Wait() if test.expectedErr != "" { - assert.ErrorContains(t, err, test.expectedErr) + require.ErrorContains(t, err, test.expectedErr) return } - assert.NoError(t, err) + require.NoError(t, err) }) } - } func TestHandler(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -229,9 +231,7 @@ webhook_path: /k8s-audit`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "k8s-audit", - }) + subLogger := log.WithField("type", "k8s-audit") for _, test := range tests { t.Run(test.name, func(t *testing.T) { 
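Each handler test below follows the httptest pattern; for a quick sense of it, a minimal request against the webhook (the empty audit.EventList body is illustrative, and a 200 is what the handler should answer when decoding succeeds and there is nothing to forward):

req := httptest.NewRequest(http.MethodPost, "/k8s-audit",
	strings.NewReader(`{"kind":"EventList","apiVersion":"audit.k8s.io/v1","items":[]}`))
w := httptest.NewRecorder()

ka.webhookHandler(w, req) // ka: a configured KubernetesAuditSource, as in the tests
fmt.Println(w.Result().StatusCode) // 200 expected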
@@ -252,27 +252,27 @@ webhook_path: /k8s-audit`, f := KubernetesAuditSource{} err := f.UnmarshalConfig([]byte(test.config)) - assert.NoError(t, err) - err = f.Configure([]byte(test.config), subLogger) + require.NoError(t, err) + err = f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body)) w := httptest.NewRecorder() - f.StreamingAcquisition(out, tb) + f.StreamingAcquisition(ctx, out, tb) f.webhookHandler(w, req) res := w.Result() assert.Equal(t, test.expectedStatusCode, res.StatusCode) - //time.Sleep(1 * time.Second) - assert.NoError(t, err) + // time.Sleep(1 * time.Second) + require.NoError(t, err) tb.Kill(nil) err = tb.Wait() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.eventCount, eventCount) }) diff --git a/pkg/acquisition/modules/loki/entry.go b/pkg/acquisition/modules/loki/entry.go new file mode 100644 index 00000000000..c0ff857ea6b --- /dev/null +++ b/pkg/acquisition/modules/loki/entry.go @@ -0,0 +1,60 @@ +package loki + +import ( + "encoding/json" + "strconv" + "time" +) + +type Entry struct { + Timestamp time.Time + Line string +} + +func (e *Entry) UnmarshalJSON(b []byte) error { + var values []string + err := json.Unmarshal(b, &values) + if err != nil { + return err + } + t, err := strconv.Atoi(values[0]) + if err != nil { + return err + } + e.Timestamp = time.Unix(int64(t), 0) + e.Line = values[1] + return nil +} + +type Stream struct { + Stream map[string]string `json:"stream"` + Entries []Entry `json:"values"` +} + +type DroppedEntry struct { + Labels map[string]string `json:"labels"` + Timestamp time.Time `json:"timestamp"` +} + +type Tail struct { + Streams []Stream `json:"streams"` + DroppedEntries []DroppedEntry `json:"dropped_entries"` +} + +// LokiQuery GET response. 
+// See https://grafana.com/docs/loki/latest/api/#get-lokiapiv1query +type LokiQuery struct { + Status string `json:"status"` + Data Data `json:"data"` +} + +type Data struct { + ResultType string `json:"resultType"` + Result []StreamResult `json:"result"` // Warning, just stream value is handled + Stats interface{} `json:"stats"` // Stats is boring, just ignore it +} + +type StreamResult struct { + Stream map[string]string `json:"stream"` + Values []Entry `json:"values"` +} diff --git a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go new file mode 100644 index 00000000000..846e833abea --- /dev/null +++ b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go @@ -0,0 +1,324 @@ +package lokiclient + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "time" + + "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" + "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" +) + +type LokiClient struct { + Logger *log.Entry + + config Config + t *tomb.Tomb + fail_start time.Time + currentTickerInterval time.Duration + requestHeaders map[string]string +} + +type Config struct { + LokiURL string + LokiPrefix string + Query string + Headers map[string]string + + Username string + Password string + + Since time.Duration + Until time.Duration + + FailMaxDuration time.Duration + + DelayFor int + Limit int +} + +func updateURI(uri string, lq LokiQueryRangeResponse, infinite bool) string { + u, _ := url.Parse(uri) + queryParams := u.Query() + + if len(lq.Data.Result) > 0 { + lastTs := lq.Data.Result[0].Entries[len(lq.Data.Result[0].Entries)-1].Timestamp + // +1 the last timestamp to avoid getting the same result again. 
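Loki's query_range treats start as inclusive, so resuming at the newest delivered timestamp would re-fetch that entry on every poll; the strict +1 below sidesteps that. Worked through with an arbitrary timestamp:

// newest entry of the page:       ts = 1700000000000000000 ns
// next request therefore uses: start = 1700000000000000001 ns
next := lastTs.UnixNano() + 1 // lastTs: last entry of Result[0], as computed above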
+ queryParams.Set("start", strconv.Itoa(int(lastTs.UnixNano()+1))) + } + + if infinite { + queryParams.Set("end", strconv.Itoa(int(time.Now().UnixNano()))) + } + + u.RawQuery = queryParams.Encode() + return u.String() +} + +func (lc *LokiClient) SetTomb(t *tomb.Tomb) { + lc.t = t +} + +func (lc *LokiClient) resetFailStart() { + if !lc.fail_start.IsZero() { + log.Infof("loki is back after %s", time.Since(lc.fail_start)) + } + lc.fail_start = time.Time{} +} + +func (lc *LokiClient) shouldRetry() bool { + if lc.fail_start.IsZero() { + lc.Logger.Warningf("loki is not available, will retry for %s", lc.config.FailMaxDuration) + lc.fail_start = time.Now() + return true + } + if time.Since(lc.fail_start) > lc.config.FailMaxDuration { + lc.Logger.Errorf("loki didn't manage to recover after %s, giving up", lc.config.FailMaxDuration) + return false + } + return true +} + +func (lc *LokiClient) increaseTicker(ticker *time.Ticker) { + maxTicker := 10 * time.Second + if lc.currentTickerInterval < maxTicker { + lc.currentTickerInterval *= 2 + if lc.currentTickerInterval > maxTicker { + lc.currentTickerInterval = maxTicker + } + ticker.Reset(lc.currentTickerInterval) + } +} + +func (lc *LokiClient) decreaseTicker(ticker *time.Ticker) { + minTicker := 100 * time.Millisecond + if lc.currentTickerInterval != minTicker { + lc.currentTickerInterval = minTicker + ticker.Reset(lc.currentTickerInterval) + } +} + +func (lc *LokiClient) queryRange(ctx context.Context, uri string, c chan *LokiQueryRangeResponse, infinite bool) error { + lc.currentTickerInterval = 100 * time.Millisecond + ticker := time.NewTicker(lc.currentTickerInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-lc.t.Dying(): + return lc.t.Err() + case <-ticker.C: + resp, err := lc.Get(uri) + if err != nil { + if ok := lc.shouldRetry(); !ok { + return fmt.Errorf("error querying range: %w", err) + } + lc.increaseTicker(ticker) + continue + } + + if resp.StatusCode != http.StatusOK { + lc.Logger.Warnf("bad HTTP response code for query range: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + if ok := lc.shouldRetry(); !ok { + return fmt.Errorf("bad HTTP response code: %d: %s: %w", resp.StatusCode, string(body), err) + } + lc.increaseTicker(ticker) + continue + } + + var lq LokiQueryRangeResponse + if err := json.NewDecoder(resp.Body).Decode(&lq); err != nil { + resp.Body.Close() + if ok := lc.shouldRetry(); !ok { + return fmt.Errorf("error decoding Loki response: %w", err) + } + lc.increaseTicker(ticker) + continue + } + resp.Body.Close() + lc.Logger.Tracef("Got response: %+v", lq) + c <- &lq + lc.resetFailStart() + if !infinite && (len(lq.Data.Result) == 0 || len(lq.Data.Result[0].Entries) < lc.config.Limit) { + lc.Logger.Infof("Got less than %d results (%d), stopping", lc.config.Limit, len(lq.Data.Result)) + close(c) + return nil + } + if len(lq.Data.Result) > 0 { + lc.Logger.Debugf("(timer:%v) %d results / %d entries result[0] (uri:%s)", lc.currentTickerInterval, len(lq.Data.Result), len(lq.Data.Result[0].Entries), uri) + } else { + lc.Logger.Debugf("(timer:%v) no results (uri:%s)", lc.currentTickerInterval, uri) + } + if infinite { + if len(lq.Data.Result) > 0 { //as long as we get results, we keep lowest ticker + lc.decreaseTicker(ticker) + } else { + lc.increaseTicker(ticker) + } + } + + uri = updateURI(uri, lq, infinite) + } + } +} + +func (lc *LokiClient) getURLFor(endpoint string, params map[string]string) string { + u, err := url.Parse(lc.config.LokiURL) + if err != 
nil { + return "" + } + queryParams := u.Query() + for k, v := range params { + queryParams.Set(k, v) + } + u.RawQuery = queryParams.Encode() + + u.Path, err = url.JoinPath(lc.config.LokiPrefix, u.Path, endpoint) + if err != nil { + return "" + } + + if endpoint == "loki/api/v1/tail" { + if u.Scheme == "http" { + u.Scheme = "ws" + } else { + u.Scheme = "wss" + } + } + + return u.String() +} + +func (lc *LokiClient) Ready(ctx context.Context) error { + tick := time.NewTicker(500 * time.Millisecond) + url := lc.getURLFor("ready", nil) + for { + select { + case <-ctx.Done(): + tick.Stop() + return ctx.Err() + case <-lc.t.Dying(): + tick.Stop() + return lc.t.Err() + case <-tick.C: + lc.Logger.Debug("Checking if Loki is ready") + resp, err := lc.Get(url) + if err != nil { + lc.Logger.Warnf("Error checking if Loki is ready: %s", err) + continue + } + _ = resp.Body.Close() + if resp.StatusCode != http.StatusOK { + lc.Logger.Debugf("Loki is not ready, status code: %d", resp.StatusCode) + continue + } + lc.Logger.Info("Loki is ready") + return nil + } + } +} + +func (lc *LokiClient) Tail(ctx context.Context) (chan *LokiResponse, error) { + responseChan := make(chan *LokiResponse) + dialer := &websocket.Dialer{} + u := lc.getURLFor("loki/api/v1/tail", map[string]string{ + "limit": strconv.Itoa(lc.config.Limit), + "start": strconv.Itoa(int(time.Now().Add(-lc.config.Since).UnixNano())), + "query": lc.config.Query, + "delay_for": strconv.Itoa(lc.config.DelayFor), + }) + + lc.Logger.Debugf("Since: %s (%s)", lc.config.Since, time.Now().Add(-lc.config.Since)) + + if lc.config.Username != "" || lc.config.Password != "" { + dialer.Proxy = func(req *http.Request) (*url.URL, error) { + req.SetBasicAuth(lc.config.Username, lc.config.Password) + return nil, nil + } + } + + requestHeader := http.Header{} + for k, v := range lc.requestHeaders { + requestHeader.Add(k, v) + } + lc.Logger.Infof("Connecting to %s", u) + + conn, _, err := dialer.Dial(u, requestHeader) + if err != nil { + lc.Logger.Errorf("Error connecting to websocket, err: %s", err) + return responseChan, errors.New("error connecting to websocket") + } + + lc.t.Go(func() error { + for { + jsonResponse := &LokiResponse{} + + err = conn.ReadJSON(jsonResponse) + if err != nil { + lc.Logger.Errorf("Error reading from websocket: %s", err) + return fmt.Errorf("websocket error: %w", err) + } + + responseChan <- jsonResponse + } + }) + + return responseChan, nil +} + +func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQueryRangeResponse { + url := lc.getURLFor("loki/api/v1/query_range", map[string]string{ + "query": lc.config.Query, + "start": strconv.Itoa(int(time.Now().Add(-lc.config.Since).UnixNano())), + "end": strconv.Itoa(int(time.Now().UnixNano())), + "limit": strconv.Itoa(lc.config.Limit), + "direction": "forward", + }) + + c := make(chan *LokiQueryRangeResponse) + + lc.Logger.Debugf("Since: %s (%s)", lc.config.Since, time.Now().Add(-lc.config.Since)) + + lc.Logger.Infof("Connecting to %s", url) + lc.t.Go(func() error { + return lc.queryRange(ctx, url, c, infinite) + }) + return c +} + +// Create a wrapper for http.Get to be able to set headers and auth +func (lc *LokiClient) Get(url string) (*http.Response, error) { + request, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + for k, v := range lc.requestHeaders { + request.Header.Add(k, v) + } + return http.DefaultClient.Do(request) +} + +func NewLokiClient(config Config) *LokiClient { + headers := make(map[string]string) + for k, 
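getURLFor special-cases the tail endpoint because Loki serves it over WebSocket; the scheme translation is the conventional http→ws / https→wss swap:

u, _ := url.Parse("https://localhost:3100/loki/api/v1/tail") // illustrative URL
if u.Scheme == "http" {
	u.Scheme = "ws"
} else {
	u.Scheme = "wss"
}
fmt.Println(u.String()) // wss://localhost:3100/loki/api/v1/tail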
v := range config.Headers { + headers[k] = v + } + if config.Username != "" || config.Password != "" { + headers["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(config.Username+":"+config.Password)) + } + headers["User-Agent"] = useragent.Default() + return &LokiClient{Logger: log.WithField("component", "lokiclient"), config: config, requestHeaders: headers} +} diff --git a/pkg/acquisition/modules/loki/internal/lokiclient/types.go b/pkg/acquisition/modules/loki/internal/lokiclient/types.go new file mode 100644 index 00000000000..d5aed204406 --- /dev/null +++ b/pkg/acquisition/modules/loki/internal/lokiclient/types.go @@ -0,0 +1,55 @@ +package lokiclient + +import ( + "encoding/json" + "strconv" + "time" +) + +type Entry struct { + Timestamp time.Time + Line string +} + +func (e *Entry) UnmarshalJSON(b []byte) error { + var values []string + err := json.Unmarshal(b, &values) + if err != nil { + return err + } + t, err := strconv.Atoi(values[0]) + if err != nil { + return err + } + e.Timestamp = time.Unix(0, int64(t)) + e.Line = values[1] + return nil +} + +type Stream struct { + Stream map[string]string `json:"stream"` + Entries []Entry `json:"values"` +} + +type DroppedEntry struct { + Labels map[string]string `json:"labels"` + Timestamp time.Time `json:"timestamp"` +} + +type LokiResponse struct { + Streams []Stream `json:"streams"` + DroppedEntries []interface{} `json:"dropped_entries"` //We don't care about the actual content i think ? +} + +// LokiQuery GET response. +// See https://grafana.com/docs/loki/latest/api/#get-lokiapiv1query +type LokiQueryRangeResponse struct { + Status string `json:"status"` + Data Data `json:"data"` +} + +type Data struct { + ResultType string `json:"resultType"` + Result []Stream `json:"result"` // Warning, just stream value is handled + Stats interface{} `json:"stats"` // Stats is boring, just ignore it +} diff --git a/pkg/acquisition/modules/loki/loki.go b/pkg/acquisition/modules/loki/loki.go new file mode 100644 index 00000000000..f867feeb84b --- /dev/null +++ b/pkg/acquisition/modules/loki/loki.go @@ -0,0 +1,374 @@ +package loki + +/* +https://grafana.com/docs/loki/latest/api/#get-lokiapiv1tail +*/ + +import ( + "context" + "errors" + "fmt" + "net/url" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + tomb "gopkg.in/tomb.v2" + yaml "gopkg.in/yaml.v2" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki/internal/lokiclient" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +const ( + readyTimeout time.Duration = 3 * time.Second + readyLoop int = 3 + readySleep time.Duration = 10 * time.Second + lokiLimit int = 100 +) + +var linesRead = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_lokisource_hits_total", + Help: "Total lines that were read.", + }, + []string{"source"}) + +type LokiAuthConfiguration struct { + Username string `yaml:"username"` + Password string `yaml:"password"` +} + +type LokiConfiguration struct { + URL string `yaml:"url"` // Loki url + Prefix string `yaml:"prefix"` // Loki prefix + Query string `yaml:"query"` // LogQL query + Limit int `yaml:"limit"` // Limit of logs to read + DelayFor time.Duration `yaml:"delay_for"` + Since time.Duration `yaml:"since"` + Headers map[string]string `yaml:"headers"` // HTTP headers for talking to Loki + WaitForReady time.Duration `yaml:"wait_for_ready"` // Retry interval, default is 10 seconds + Auth 
LokiAuthConfiguration `yaml:"auth"` + MaxFailureDuration time.Duration `yaml:"max_failure_duration"` // Max duration of failure before stopping the source + configuration.DataSourceCommonCfg `yaml:",inline"` +} + +type LokiSource struct { + metricsLevel int + Config LokiConfiguration + + Client *lokiclient.LokiClient + + logger *log.Entry + lokiWebsocket string +} + +func (l *LokiSource) GetMetrics() []prometheus.Collector { + return []prometheus.Collector{linesRead} +} + +func (l *LokiSource) GetAggregMetrics() []prometheus.Collector { + return []prometheus.Collector{linesRead} +} + +func (l *LokiSource) UnmarshalConfig(yamlConfig []byte) error { + err := yaml.UnmarshalStrict(yamlConfig, &l.Config) + if err != nil { + return fmt.Errorf("cannot parse loki acquisition configuration: %w", err) + } + + if l.Config.Query == "" { + return errors.New("loki query is mandatory") + } + + if l.Config.WaitForReady == 0 { + l.Config.WaitForReady = 10 * time.Second + } + + if l.Config.DelayFor < 0*time.Second || l.Config.DelayFor > 5*time.Second { + return errors.New("delay_for should be a value between 1s and 5s") + } + + if l.Config.Mode == "" { + l.Config.Mode = configuration.TAIL_MODE + } + if l.Config.Prefix == "" { + l.Config.Prefix = "/" + } + + if !strings.HasSuffix(l.Config.Prefix, "/") { + l.Config.Prefix += "/" + } + + if l.Config.Limit == 0 { + l.Config.Limit = lokiLimit + } + + if l.Config.Mode == configuration.TAIL_MODE { + l.logger.Infof("Resetting since") + l.Config.Since = 0 + } + + if l.Config.MaxFailureDuration == 0 { + l.Config.MaxFailureDuration = 30 * time.Second + } + + return nil +} + +func (l *LokiSource) Configure(config []byte, logger *log.Entry, MetricsLevel int) error { + l.Config = LokiConfiguration{} + l.logger = logger + l.metricsLevel = MetricsLevel + err := l.UnmarshalConfig(config) + if err != nil { + return err + } + + l.logger.Infof("Since value: %s", l.Config.Since.String()) + + clientConfig := lokiclient.Config{ + LokiURL: l.Config.URL, + Headers: l.Config.Headers, + Limit: l.Config.Limit, + Query: l.Config.Query, + Since: l.Config.Since, + Username: l.Config.Auth.Username, + Password: l.Config.Auth.Password, + FailMaxDuration: l.Config.MaxFailureDuration, + } + + l.Client = lokiclient.NewLokiClient(clientConfig) + l.Client.Logger = logger.WithFields(log.Fields{"component": "lokiclient", "source": l.Config.URL}) + return nil +} + +func (l *LokiSource) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { + l.logger = logger + l.Config = LokiConfiguration{} + l.Config.Mode = configuration.CAT_MODE + l.Config.Labels = labels + l.Config.UniqueId = uuid + + u, err := url.Parse(dsn) + if err != nil { + return fmt.Errorf("while parsing dsn '%s': %w", dsn, err) + } + if u.Scheme != "loki" { + return fmt.Errorf("invalid DSN %s for loki source, must start with loki://", dsn) + } + if u.Host == "" { + return errors.New("empty loki host") + } + scheme := "http" + + params := u.Query() + if q := params.Get("ssl"); q != "" { + scheme = "https" + } + if q := params.Get("query"); q != "" { + l.Config.Query = q + } + if w := params.Get("wait_for_ready"); w != "" { + l.Config.WaitForReady, err = time.ParseDuration(w) + if err != nil { + return err + } + } else { + l.Config.WaitForReady = 10 * time.Second + } + + if d := params.Get("delay_for"); d != "" { + l.Config.DelayFor, err = time.ParseDuration(d) + if err != nil { + return fmt.Errorf("invalid duration: %w", err) + } + if l.Config.DelayFor < 0*time.Second || l.Config.DelayFor > 
5*time.Second { + return errors.New("delay_for should be a value between 1s and 5s") + } + } else { + l.Config.DelayFor = 0 * time.Second + } + + if s := params.Get("since"); s != "" { + l.Config.Since, err = time.ParseDuration(s) + if err != nil { + return fmt.Errorf("invalid since in dsn: %w", err) + } + } + + if max_failure_duration := params.Get("max_failure_duration"); max_failure_duration != "" { + duration, err := time.ParseDuration(max_failure_duration) + if err != nil { + return fmt.Errorf("invalid max_failure_duration in dsn: %w", err) + } + l.Config.MaxFailureDuration = duration + } else { + l.Config.MaxFailureDuration = 5 * time.Second // for OneShot mode it doesn't make sense to have longer duration + } + + if limit := params.Get("limit"); limit != "" { + limit, err := strconv.Atoi(limit) + if err != nil { + return fmt.Errorf("invalid limit in dsn: %w", err) + } + l.Config.Limit = limit + } else { + l.Config.Limit = 5000 // max limit allowed by loki + } + + if logLevel := params.Get("log_level"); logLevel != "" { + level, err := log.ParseLevel(logLevel) + if err != nil { + return fmt.Errorf("invalid log_level in dsn: %w", err) + } + l.Config.LogLevel = &level + l.logger.Logger.SetLevel(level) + } + + l.Config.URL = fmt.Sprintf("%s://%s", scheme, u.Host) + if u.User != nil { + l.Config.Auth.Username = u.User.Username() + l.Config.Auth.Password, _ = u.User.Password() + } + + clientConfig := lokiclient.Config{ + LokiURL: l.Config.URL, + Headers: l.Config.Headers, + Limit: l.Config.Limit, + Query: l.Config.Query, + Since: l.Config.Since, + Username: l.Config.Auth.Username, + Password: l.Config.Auth.Password, + DelayFor: int(l.Config.DelayFor / time.Second), + } + + l.Client = lokiclient.NewLokiClient(clientConfig) + l.Client.Logger = logger.WithFields(log.Fields{"component": "lokiclient", "source": l.Config.URL}) + + return nil +} + +func (l *LokiSource) GetMode() string { + return l.Config.Mode +} + +func (l *LokiSource) GetName() string { + return "loki" +} + +// OneShotAcquisition reads a set of file and returns when done +func (l *LokiSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { + l.logger.Debug("Loki one shot acquisition") + l.Client.SetTomb(t) + readyCtx, cancel := context.WithTimeout(context.Background(), l.Config.WaitForReady) + defer cancel() + err := l.Client.Ready(readyCtx) + if err != nil { + return fmt.Errorf("loki is not ready: %w", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + c := l.Client.QueryRange(ctx, false) + + for { + select { + case <-t.Dying(): + l.logger.Debug("Loki one shot acquisition stopped") + cancel() + return nil + case resp, ok := <-c: + if !ok { + l.logger.Info("Loki acquisition done, chan closed") + cancel() + return nil + } + for _, stream := range resp.Data.Result { + for _, entry := range stream.Entries { + l.readOneEntry(entry, l.Config.Labels, out) + } + } + } + } +} + +func (l *LokiSource) readOneEntry(entry lokiclient.Entry, labels map[string]string, out chan types.Event) { + ll := types.Line{} + ll.Raw = entry.Line + ll.Time = entry.Timestamp + ll.Src = l.Config.URL + ll.Labels = labels + ll.Process = true + ll.Module = l.GetName() + + if l.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": l.Config.URL}).Inc() + } + expectMode := types.LIVE + if l.Config.UseTimeMachine { + expectMode = types.TIMEMACHINE + } + out <- types.Event{ + Line: ll, + Process: true, + Type: types.LOG, + ExpectMode: expectMode, + } +} + +func (l *LokiSource) 
StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + l.Client.SetTomb(t) + readyCtx, cancel := context.WithTimeout(ctx, l.Config.WaitForReady) + defer cancel() + err := l.Client.Ready(readyCtx) + if err != nil { + return fmt.Errorf("loki is not ready: %w", err) + } + ll := l.logger.WithField("websocket_url", l.lokiWebsocket) + t.Go(func() error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + respChan := l.Client.QueryRange(ctx, true) + for { + select { + case resp, ok := <-respChan: + if !ok { + ll.Warnf("loki channel closed") + return nil + } + for _, stream := range resp.Data.Result { + for _, entry := range stream.Entries { + l.readOneEntry(entry, l.Config.Labels, out) + } + } + case <-t.Dying(): + return nil + } + } + }) + return nil +} + +func (l *LokiSource) CanRun() error { + return nil +} + +func (l *LokiSource) GetUuid() string { + return l.Config.UniqueId +} + +func (l *LokiSource) Dump() interface{} { + return l +} + +// SupportedModes returns the modes supported by the acquisition module +func (l *LokiSource) SupportedModes() []string { + return []string{configuration.TAIL_MODE, configuration.CAT_MODE} +} diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go new file mode 100644 index 00000000000..627200217f5 --- /dev/null +++ b/pkg/acquisition/modules/loki/loki_test.go @@ -0,0 +1,564 @@ +package loki_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "runtime" + "strings" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + tomb "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func TestConfiguration(t *testing.T) { + log.Infof("Test 'TestConfiguration'") + + tests := []struct { + config string + expectedErr string + password string + waitForReady time.Duration + delayFor time.Duration + testName string + }{ + { + config: `foobar: asd`, + expectedErr: "line 1: field foobar not found in type loki.LokiConfiguration", + testName: "Unknown field", + }, + { + config: ` +mode: tail +source: loki`, + expectedErr: "loki query is mandatory", + testName: "Missing url", + }, + { + config: ` +mode: tail +source: loki +url: http://localhost:3100/ +`, + expectedErr: "loki query is mandatory", + testName: "Missing query", + }, + { + config: ` +mode: tail +source: loki +url: http://localhost:3100/ +query: > + {server="demo"} +`, + expectedErr: "", + testName: "Correct config", + }, + { + config: ` +mode: tail +source: loki +url: http://localhost:3100/ +wait_for_ready: 5s +query: > + {server="demo"} +`, + expectedErr: "", + testName: "Correct config with wait_for_ready", + waitForReady: 5 * time.Second, + }, + { + config: ` +mode: tail +source: loki +url: http://localhost:3100/ +delay_for: 1s +query: > + {server="demo"} +`, + expectedErr: "", + testName: "Correct config with delay_for", + delayFor: 1 * time.Second, + }, + { + config: ` +mode: tail +source: loki +url: http://localhost:3100/ +auth: + username: foo + password: bar +query: > + {server="demo"} +`, + expectedErr: "", + password: "bar", + testName: "Correct config with password", + }, + { + config: ` +mode: tail 
+source: loki +url: http://localhost:3100/ +delay_for: 10s +query: > + {server="demo"} +`, + expectedErr: "delay_for should be a value between 1s and 5s", + testName: "Invalid DelayFor", + }, + } + subLogger := log.WithField("type", "loki") + + for _, test := range tests { + t.Run(test.testName, func(t *testing.T) { + lokiSource := loki.LokiSource{} + err := lokiSource.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) + cstest.AssertErrorContains(t, err, test.expectedErr) + + if test.password != "" { + p := lokiSource.Config.Auth.Password + if test.password != p { + t.Fatalf("Password mismatch : %s != %s", test.password, p) + } + } + + if test.waitForReady != 0 { + if lokiSource.Config.WaitForReady != test.waitForReady { + t.Fatalf("Wrong WaitForReady %v != %v", lokiSource.Config.WaitForReady, test.waitForReady) + } + } + + if test.delayFor != 0 { + if lokiSource.Config.DelayFor != test.delayFor { + t.Fatalf("Wrong DelayFor %v != %v", lokiSource.Config.DelayFor, test.delayFor) + } + } + }) + } +} + +func TestConfigureDSN(t *testing.T) { + log.Infof("Test 'TestConfigureDSN'") + + tests := []struct { + name string + dsn string + expectedErr string + since time.Time + password string + scheme string + waitForReady time.Duration + delayFor time.Duration + }{ + { + name: "Wrong scheme", + dsn: "wrong://", + expectedErr: "invalid DSN wrong:// for loki source, must start with loki://", + }, + { + name: "Correct DSN", + dsn: `loki://localhost:3100/?query={server="demo"}`, + expectedErr: "", + }, + { + name: "Empty host", + dsn: "loki://", + expectedErr: "empty loki host", + }, + { + name: "Invalid DSN", + dsn: "loki", + expectedErr: "invalid DSN loki for loki source, must start with loki://", + }, + { + name: "Invalid Delay", + dsn: `loki://localhost:3100/?query={server="demo"}&delay_for=10s`, + expectedErr: "delay_for should be a value between 1s and 5s", + }, + { + name: "Bad since param", + dsn: `loki://127.0.0.1:3100/?since=3h&query={server="demo"}`, + since: time.Now().Add(-3 * time.Hour), + }, + { + name: "Basic Auth", + dsn: `loki://login:password@localhost:3102/?query={server="demo"}`, + password: "password", + }, + { + name: "Correct DSN", + dsn: `loki://localhost:3100/?query={server="demo"}&wait_for_ready=5s&delay_for=1s`, + expectedErr: "", + waitForReady: 5 * time.Second, + delayFor: 1 * time.Second, + }, + { + name: "SSL DSN", + dsn: `loki://localhost:3100/?ssl=true`, + scheme: "https", + }, + } + + for _, test := range tests { + subLogger := log.WithFields(log.Fields{ + "type": "loki", + "name": test.name, + }) + + t.Logf("Test : %s", test.name) + + lokiSource := &loki.LokiSource{} + err := lokiSource.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "") + cstest.AssertErrorContains(t, err, test.expectedErr) + + noDuration, _ := time.ParseDuration("0s") + if lokiSource.Config.Since != noDuration && lokiSource.Config.Since.Round(time.Second) != time.Since(test.since).Round(time.Second) { + t.Fatalf("Invalid since %v", lokiSource.Config.Since) + } + + if test.password != "" { + p := lokiSource.Config.Auth.Password + if test.password != p { + t.Fatalf("Password mismatch : %s != %s", test.password, p) + } + } + + if test.scheme != "" { + url, _ := url.Parse(lokiSource.Config.URL) + if test.scheme != url.Scheme { + t.Fatalf("Scheme mismatch : %s != %s", test.scheme, url.Scheme) + } + } + + if test.waitForReady != 0 { + if lokiSource.Config.WaitForReady != test.waitForReady { + t.Fatalf("Wrong WaitForReady %v != %v", 
lokiSource.Config.WaitForReady, test.waitForReady) + } + } + + if test.delayFor != 0 { + if lokiSource.Config.DelayFor != test.delayFor { + t.Fatalf("Wrong DelayFor %v != %v", lokiSource.Config.DelayFor, test.delayFor) + } + } + } +} + +func feedLoki(ctx context.Context, logger *log.Entry, n int, title string) error { + streams := LogStreams{ + Streams: []LogStream{ + { + Stream: map[string]string{ + "server": "demo", + "domain": "cw.example.com", + "key": title, + }, + Values: make([]LogValue, n), + }, + }, + } + for i := range n { + streams.Streams[0].Values[i] = LogValue{ + Time: time.Now(), + Line: fmt.Sprintf("Log line #%d %v", i, title), + } + } + + buff, err := json.Marshal(streams) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff)) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Scope-Orgid", "1234") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + b, _ := io.ReadAll(resp.Body) + logger.Error(string(b)) + + return fmt.Errorf("bad post status %d", resp.StatusCode) + } + + logger.Info(n, " Events sent") + + return nil +} + +func TestOneShotAcquisition(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping test on windows") + } + + log.SetOutput(os.Stdout) + log.SetLevel(log.InfoLevel) + log.Info("Test 'TestOneShotAcquisition'") + + title := time.Now().String() // Loki will be messy, with a lot of stuff, let's use a unique key + tests := []struct { + config string + }{ + { + config: fmt.Sprintf(` +mode: cat +source: loki +url: http://127.0.0.1:3100 +query: '{server="demo",key="%s"}' +headers: + x-scope-orgid: "1234" +since: 1h +`, title), + }, + } + + for _, ts := range tests { + logger := log.New() + subLogger := logger.WithField("type", "loki") + lokiSource := loki.LokiSource{} + err := lokiSource.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) + if err != nil { + t.Fatalf("Unexpected error : %s", err) + } + + ctx := context.Background() + + err = feedLoki(ctx, subLogger, 20, title) + if err != nil { + t.Fatalf("Unexpected error : %s", err) + } + + out := make(chan types.Event) + read := 0 + + go func() { + for { + <-out + + read++ + } + }() + + lokiTomb := tomb.Tomb{} + + err = lokiSource.OneShotAcquisition(out, &lokiTomb) + if err != nil { + t.Fatalf("Unexpected error : %s", err) + } + + assert.Equal(t, 20, read) + } +} + +func TestStreamingAcquisition(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping test on windows") + } + + log.SetOutput(os.Stdout) + log.SetLevel(log.InfoLevel) + log.Info("Test 'TestStreamingAcquisition'") + + title := time.Now().String() + tests := []struct { + name string + config string + expectedErr string + streamErr string + expectedLines int + }{ + { + name: "Bad port", + config: `mode: tail +source: loki +url: "http://127.0.0.1:3101" +headers: + x-scope-orgid: "1234" +query: > + {server="demo"}`, // No Loki server here + expectedErr: "", + streamErr: `loki is not ready: context deadline exceeded`, + expectedLines: 0, + }, + { + name: "ok", + config: `mode: tail +source: loki +url: "http://127.0.0.1:3100" +headers: + x-scope-orgid: "1234" +query: > + {server="demo"}`, + expectedErr: "", + streamErr: "", + expectedLines: 20, + }, + } + + ctx := context.Background() + + for _, ts := range tests { + t.Run(ts.name, func(t 
*testing.T) { + logger := log.New() + subLogger := logger.WithFields(log.Fields{ + "type": "loki", + "name": ts.name, + }) + + out := make(chan types.Event) + lokiTomb := tomb.Tomb{} + lokiSource := loki.LokiSource{} + + err := lokiSource.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) + if err != nil { + t.Fatalf("Unexpected error : %s", err) + } + + err = lokiSource.StreamingAcquisition(ctx, out, &lokiTomb) + cstest.AssertErrorContains(t, err, ts.streamErr) + + if ts.streamErr != "" { + return + } + + time.Sleep(time.Second * 2) // We need to give time to start reading from the WS + + readTomb := tomb.Tomb{} + readCtx, cancel := context.WithTimeout(ctx, time.Second*10) + count := 0 + + readTomb.Go(func() error { + defer cancel() + + for { + select { + case <-readCtx.Done(): + return readCtx.Err() + case evt := <-out: + count++ + + if !strings.HasSuffix(evt.Line.Raw, title) { + return fmt.Errorf("Incorrect suffix : %s", evt.Line.Raw) + } + + if count == ts.expectedLines { + return nil + } + } + } + }) + + err = feedLoki(ctx, subLogger, ts.expectedLines, title) + if err != nil { + t.Fatalf("Unexpected error : %s", err) + } + + err = readTomb.Wait() + + cancel() + + if err != nil { + t.Fatalf("Unexpected error : %s", err) + } + + assert.Equal(t, ts.expectedLines, count) + }) + } +} + +func TestStopStreaming(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { + t.Skip("Skipping test on windows") + } + + config := ` +mode: tail +source: loki +url: http://127.0.0.1:3100 +headers: + x-scope-orgid: "1234" +query: > + {server="demo"} +` + logger := log.New() + subLogger := logger.WithField("type", "loki") + title := time.Now().String() + lokiSource := loki.LokiSource{} + + err := lokiSource.Configure([]byte(config), subLogger, configuration.METRICS_NONE) + if err != nil { + t.Fatalf("Unexpected error : %s", err) + } + + out := make(chan types.Event) + + lokiTomb := &tomb.Tomb{} + + err = lokiSource.StreamingAcquisition(ctx, out, lokiTomb) + if err != nil { + t.Fatalf("Unexpected error : %s", err) + } + + time.Sleep(time.Second * 2) + + err = feedLoki(ctx, subLogger, 1, title) + if err != nil { + t.Fatalf("Unexpected error : %s", err) + } + + lokiTomb.Kill(nil) + + err = lokiTomb.Wait() + if err != nil { + t.Fatalf("Unexpected error : %s", err) + } +} + +type LogStreams struct { + Streams []LogStream `json:"streams"` +} + +type LogStream struct { + Stream map[string]string `json:"stream"` + Values []LogValue `json:"values"` +} + +type LogValue struct { + Time time.Time + Line string +} + +func (l *LogValue) MarshalJSON() ([]byte, error) { + line, err := json.Marshal(l.Line) + if err != nil { + return nil, err + } + + return []byte(fmt.Sprintf(`["%d",%s]`, l.Time.UnixNano(), string(line))), nil +} diff --git a/pkg/acquisition/modules/loki/timestamp.go b/pkg/acquisition/modules/loki/timestamp.go new file mode 100644 index 00000000000..f1ec246eaa8 --- /dev/null +++ b/pkg/acquisition/modules/loki/timestamp.go @@ -0,0 +1,29 @@ +package loki + +import ( + "fmt" + "time" +) + +type timestamp time.Time + +func (t *timestamp) UnmarshalYAML(unmarshal func(interface{}) error) error { + var tt time.Time + err := unmarshal(&tt) + if err == nil { + *t = timestamp(tt) + return nil + } + var d time.Duration + err = unmarshal(&d) + if err == nil { + *t = timestamp(time.Now().Add(-d)) + fmt.Println("t", time.Time(*t).Format(time.RFC3339)) + return nil + } + return err +} + +func (t *timestamp) IsZero() bool { + return time.Time(*t).IsZero() +} diff --git 
a/pkg/acquisition/modules/loki/timestamp_test.go b/pkg/acquisition/modules/loki/timestamp_test.go new file mode 100644 index 00000000000..a583cc057d3 --- /dev/null +++ b/pkg/acquisition/modules/loki/timestamp_test.go @@ -0,0 +1,47 @@ +package loki + +import ( + "testing" + "time" + + "gopkg.in/yaml.v2" +) + +func TestTimestampFail(t *testing.T) { + var tt timestamp + err := yaml.Unmarshal([]byte("plop"), tt) + if err == nil { + t.Fail() + } +} + +func TestTimestampTime(t *testing.T) { + var tt timestamp + const ts string = "2022-06-14T12:56:39+02:00" + err := yaml.Unmarshal([]byte(ts), &tt) + if err != nil { + t.Error(err) + t.Fail() + } + if ts != time.Time(tt).Format(time.RFC3339) { + t.Fail() + } +} + +func TestTimestampDuration(t *testing.T) { + var tt timestamp + err := yaml.Unmarshal([]byte("3h"), &tt) + if err != nil { + t.Error(err) + t.Fail() + } + d, err := time.ParseDuration("3h") + if err != nil { + t.Error(err) + t.Fail() + } + z := time.Now().Add(-d) + if z.Round(time.Second) != time.Time(tt).Round(time.Second) { + t.Fail() + } +} diff --git a/pkg/acquisition/modules/s3/s3.go b/pkg/acquisition/modules/s3/s3.go index 651d40d3d50..ed1964edebf 100644 --- a/pkg/acquisition/modules/s3/s3.go +++ b/pkg/acquisition/modules/s3/s3.go @@ -38,7 +38,7 @@ type S3Configuration struct { AwsEndpoint string `yaml:"aws_endpoint"` BucketName string `yaml:"bucket_name"` Prefix string `yaml:"prefix"` - Key string `yaml:"-"` //Only for DSN acquisition + Key string `yaml:"-"` // Only for DSN acquisition PollingMethod string `yaml:"polling_method"` PollingInterval int `yaml:"polling_interval"` SQSName string `yaml:"sqs_name"` @@ -47,15 +47,16 @@ type S3Configuration struct { } type S3Source struct { - Config S3Configuration - logger *log.Entry - s3Client s3iface.S3API - sqsClient sqsiface.SQSAPI - readerChan chan S3Object - t *tomb.Tomb - out chan types.Event - ctx aws.Context - cancel context.CancelFunc + MetricsLevel int + Config S3Configuration + logger *log.Entry + s3Client s3iface.S3API + sqsClient sqsiface.SQSAPI + readerChan chan S3Object + t *tomb.Tomb + out chan types.Event + ctx aws.Context + cancel context.CancelFunc } type S3Object struct { @@ -92,10 +93,12 @@ type S3Event struct { } `json:"detail"` } -const PollMethodList = "list" -const PollMethodSQS = "sqs" -const SQSFormatEventBridge = "eventbridge" -const SQSFormatS3Notification = "s3notification" +const ( + PollMethodList = "list" + PollMethodSQS = "sqs" + SQSFormatEventBridge = "eventbridge" + SQSFormatS3Notification = "s3notification" +) var linesRead = prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -130,7 +133,6 @@ func (s *S3Source) newS3Client() error { } sess, err := session.NewSessionWithOptions(options) - if err != nil { return fmt.Errorf("failed to create aws session: %w", err) } @@ -145,7 +147,7 @@ func (s *S3Source) newS3Client() error { s.s3Client = s3.New(sess, config) if s.s3Client == nil { - return fmt.Errorf("failed to create S3 client") + return errors.New("failed to create S3 client") } return nil @@ -166,7 +168,7 @@ func (s *S3Source) newSQSClient() error { } if sess == nil { - return fmt.Errorf("failed to create aws session") + return errors.New("failed to create aws session") } config := aws.NewConfig() if s.Config.AwsRegion != "" { @@ -177,7 +179,7 @@ func (s *S3Source) newSQSClient() error { } s.sqsClient = sqs.New(sess, config) if s.sqsClient == nil { - return fmt.Errorf("failed to create SQS client") + return errors.New("failed to create SQS client") } return nil } @@ -204,7 +206,7 @@ func (s 
*S3Source) getBucketContent() ([]*s3.Object, error) { logger := s.logger.WithField("method", "getBucketContent") logger.Debugf("Getting bucket content for %s", s.Config.BucketName) bucketObjects := make([]*s3.Object, 0) - var continuationToken *string = nil + var continuationToken *string for { out, err := s.s3Client.ListObjectsV2WithContext(s.ctx, &s3.ListObjectsV2Input{ Bucket: aws.String(s.Config.BucketName), @@ -250,16 +252,15 @@ func (s *S3Source) listPoll() error { continue } for i := len(bucketObjects) - 1; i >= 0; i-- { - if bucketObjects[i].LastModified.After(lastObjectDate) { - newObject = true - logger.Debugf("Found new object %s", *bucketObjects[i].Key) - s.readerChan <- S3Object{ - Bucket: s.Config.BucketName, - Key: *bucketObjects[i].Key, - } - } else { + if !bucketObjects[i].LastModified.After(lastObjectDate) { break } + newObject = true + logger.Debugf("Found new object %s", *bucketObjects[i].Key) + s.readerChan <- S3Object{ + Bucket: s.Config.BucketName, + Key: *bucketObjects[i].Key, + } } if newObject { lastObjectDate = *bucketObjects[len(bucketObjects)-1].LastModified @@ -277,7 +278,7 @@ func extractBucketAndPrefixFromEventBridge(message *string) (string, string, err if eventBody.Detail.Bucket.Name != "" { return eventBody.Detail.Bucket.Name, eventBody.Detail.Object.Key, nil } - return "", "", fmt.Errorf("invalid event body for event bridge format") + return "", "", errors.New("invalid event body for event bridge format") } func extractBucketAndPrefixFromS3Notif(message *string) (string, string, error) { @@ -287,7 +288,7 @@ func extractBucketAndPrefixFromS3Notif(message *string) (string, string, error) return "", "", err } if len(s3notifBody.Records) == 0 { - return "", "", fmt.Errorf("no records found in S3 notification") + return "", "", errors.New("no records found in S3 notification") } if !strings.HasPrefix(s3notifBody.Records[0].EventName, "ObjectCreated:") { return "", "", fmt.Errorf("event %s is not supported", s3notifBody.Records[0].EventName) @@ -296,19 +297,20 @@ func extractBucketAndPrefixFromS3Notif(message *string) (string, string, error) } func (s *S3Source) extractBucketAndPrefix(message *string) (string, string, error) { - if s.Config.SQSFormat == SQSFormatEventBridge { + switch s.Config.SQSFormat { + case SQSFormatEventBridge: bucket, key, err := extractBucketAndPrefixFromEventBridge(message) if err != nil { return "", "", err } return bucket, key, nil - } else if s.Config.SQSFormat == SQSFormatS3Notification { + case SQSFormatS3Notification: bucket, key, err := extractBucketAndPrefixFromS3Notif(message) if err != nil { return "", "", err } return bucket, key, nil - } else { + default: bucket, key, err := extractBucketAndPrefixFromEventBridge(message) if err == nil { s.Config.SQSFormat = SQSFormatEventBridge @@ -319,7 +321,7 @@ func (s *S3Source) extractBucketAndPrefix(message *string) (string, string, erro s.Config.SQSFormat = SQSFormatS3Notification return bucket, key, nil } - return "", "", fmt.Errorf("SQS message format not supported") + return "", "", errors.New("SQS message format not supported") } } @@ -336,7 +338,7 @@ func (s *S3Source) sqsPoll() error { out, err := s.sqsClient.ReceiveMessageWithContext(s.ctx, &sqs.ReceiveMessageInput{ QueueUrl: aws.String(s.Config.SQSName), MaxNumberOfMessages: aws.Int64(10), - WaitTimeSeconds: aws.Int64(20), //Probably no need to make it configurable ? + WaitTimeSeconds: aws.Int64(20), // Probably no need to make it configurable ? 
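+ // 20 seconds is the maximum long-poll wait SQS supports; together with MaxNumberOfMessages=10 it avoids busy-looping on an empty queue.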
}) if err != nil { logger.Errorf("Error while polling SQS: %s", err) @@ -345,11 +347,13 @@ func (s *S3Source) sqsPoll() error { logger.Tracef("SQS output: %v", out) logger.Debugf("Received %d messages from SQS", len(out.Messages)) for _, message := range out.Messages { - sqsMessagesReceived.WithLabelValues(s.Config.SQSName).Inc() + if s.MetricsLevel != configuration.METRICS_NONE { + sqsMessagesReceived.WithLabelValues(s.Config.SQSName).Inc() + } bucket, key, err := s.extractBucketAndPrefix(message.Body) if err != nil { logger.Errorf("Error while parsing SQS message: %s", err) - //Always delete the message to avoid infinite loop + // Always delete the message to avoid infinite loop _, err = s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{ QueueUrl: aws.String(s.Config.SQSName), ReceiptHandle: message.ReceiptHandle, @@ -375,7 +379,7 @@ func (s *S3Source) sqsPoll() error { } func (s *S3Source) readFile(bucket string, key string) error { - //TODO: Handle SSE-C + // TODO: Handle SSE-C var scanner *bufio.Scanner logger := s.logger.WithFields(log.Fields{ @@ -388,14 +392,13 @@ func (s *S3Source) readFile(bucket string, key string) error { Bucket: aws.String(bucket), Key: aws.String(key), }) - if err != nil { return fmt.Errorf("failed to get object %s/%s: %w", bucket, key, err) } defer output.Body.Close() if strings.HasSuffix(key, ".gz") { - //This *might* be a gzipped file, but sometimes the SDK will decompress the data for us (it's not clear when it happens, only had the issue with cloudtrail logs) + // This *might* be a gzipped file, but sometimes the SDK will decompress the data for us (it's not clear when it happens, only had the issue with cloudtrail logs) header := make([]byte, 2) _, err := output.Body.Read(header) if err != nil { @@ -426,14 +429,20 @@ func (s *S3Source) readFile(bucket string, key string) error { default: text := scanner.Text() logger.Tracef("Read line %s", text) - linesRead.WithLabelValues(bucket).Inc() + if s.MetricsLevel != configuration.METRICS_NONE { + linesRead.WithLabelValues(bucket).Inc() + } l := types.Line{} l.Raw = text l.Labels = s.Config.Labels l.Time = time.Now().UTC() l.Process = true l.Module = s.GetName() - l.Src = bucket + "/" + key + if s.MetricsLevel == configuration.METRICS_FULL { + l.Src = bucket + "/" + key + } else if s.MetricsLevel == configuration.METRICS_AGGREGATE { + l.Src = bucket + } var evt types.Event if !s.Config.UseTimeMachine { evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} @@ -446,7 +455,9 @@ func (s *S3Source) readFile(bucket string, key string) error { if err := scanner.Err(); err != nil { return fmt.Errorf("failed to read object %s/%s: %s", bucket, key, err) } - objectsRead.WithLabelValues(bucket).Inc() + if s.MetricsLevel != configuration.METRICS_NONE { + objectsRead.WithLabelValues(bucket).Inc() + } return nil } @@ -457,6 +468,7 @@ func (s *S3Source) GetUuid() string { func (s *S3Source) GetMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived} } + func (s *S3Source) GetAggregMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived} } @@ -487,15 +499,15 @@ func (s *S3Source) UnmarshalConfig(yamlConfig []byte) error { } if s.Config.BucketName != "" && s.Config.SQSName != "" { - return fmt.Errorf("bucket_name and sqs_name are mutually exclusive") + return errors.New("bucket_name and sqs_name are mutually exclusive") } if s.Config.PollingMethod == PollMethodSQS && s.Config.SQSName == "" { - 
return fmt.Errorf("sqs_name is required when using sqs polling method") + return errors.New("sqs_name is required when using sqs polling method") } if s.Config.BucketName == "" && s.Config.PollingMethod == PollMethodList { - return fmt.Errorf("bucket_name is required") + return errors.New("bucket_name is required") } if s.Config.SQSFormat != "" && s.Config.SQSFormat != SQSFormatEventBridge && s.Config.SQSFormat != SQSFormatS3Notification { @@ -505,7 +517,7 @@ func (s *S3Source) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (s *S3Source) Configure(yamlConfig []byte, logger *log.Entry) error { +func (s *S3Source) Configure(yamlConfig []byte, logger *log.Entry, metricsLevel int) error { err := s.UnmarshalConfig(yamlConfig) if err != nil { return err @@ -557,11 +569,11 @@ func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger * }) dsn = strings.TrimPrefix(dsn, "s3://") args := strings.Split(dsn, "?") - if len(args[0]) == 0 { - return fmt.Errorf("empty s3:// DSN") + if args[0] == "" { + return errors.New("empty s3:// DSN") } - if len(args) == 2 && len(args[1]) != 0 { + if len(args) == 2 && args[1] != "" { params, err := url.ParseQuery(args[1]) if err != nil { return fmt.Errorf("could not parse s3 args: %w", err) @@ -600,7 +612,7 @@ func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger * pathParts := strings.Split(args[0], "/") s.logger.Debugf("pathParts: %v", pathParts) - //FIXME: handle s3://bucket/ + // FIXME: handle s3://bucket/ if len(pathParts) == 1 { s.Config.BucketName = pathParts[0] s.Config.Prefix = "" @@ -643,7 +655,7 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error return err } } else { - //No key, get everything in the bucket based on the prefix + // No key, get everything in the bucket based on the prefix objects, err := s.getBucketContent() if err != nil { return err @@ -659,11 +671,11 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error return nil } -func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (s *S3Source) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { s.t = t s.out = out - s.readerChan = make(chan S3Object, 100) //FIXME: does this needs to be buffered? - s.ctx, s.cancel = context.WithCancel(context.Background()) + s.readerChan = make(chan S3Object, 100) // FIXME: does this needs to be buffered? 
+ s.ctx, s.cancel = context.WithCancel(ctx) s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix) t.Go(func() error { s.readManager() diff --git a/pkg/acquisition/modules/s3/s3_test.go b/pkg/acquisition/modules/s3/s3_test.go index 02423b1392c..05a974517a0 100644 --- a/pkg/acquisition/modules/s3/s3_test.go +++ b/pkg/acquisition/modules/s3/s3_test.go @@ -14,10 +14,12 @@ import ( "github.com/aws/aws-sdk-go/service/s3/s3iface" "github.com/aws/aws-sdk-go/service/sqs" "github.com/aws/aws-sdk-go/service/sqs/sqsiface" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func TestBadConfiguration(t *testing.T) { @@ -66,7 +68,7 @@ sqs_name: foobar for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := S3Source{} - err := f.Configure([]byte(test.config), nil) + err := f.Configure([]byte(test.config), nil, configuration.METRICS_NONE) if err == nil { t.Fatalf("expected error, got none") } @@ -111,7 +113,7 @@ polling_method: list t.Run(test.name, func(t *testing.T) { f := S3Source{} logger := log.NewEntry(log.New()) - err := f.Configure([]byte(test.config), logger) + err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -265,13 +267,12 @@ func TestDSNAcquis(t *testing.T) { time.Sleep(2 * time.Second) done <- true assert.Equal(t, test.expectedCount, linesRead) - }) } - } func TestListPolling(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -306,7 +307,7 @@ prefix: foo/ f := S3Source{} logger := log.NewEntry(log.New()) logger.Logger.SetLevel(log.TraceLevel) - err := f.Configure([]byte(test.config), logger) + err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -331,8 +332,7 @@ prefix: foo/ } }() - err = f.StreamingAcquisition(out, &tb) - + err = f.StreamingAcquisition(ctx, out, &tb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -349,6 +349,7 @@ prefix: foo/ } func TestSQSPoll(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -381,7 +382,7 @@ sqs_name: test linesRead := 0 f := S3Source{} logger := log.NewEntry(log.New()) - err := f.Configure([]byte(test.config), logger) + err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -412,8 +413,7 @@ sqs_name: test } }() - err = f.StreamingAcquisition(out, &tb) - + err = f.StreamingAcquisition(ctx, out, &tb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go index 3b59a806b8b..66d842ed519 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go @@ -1,7 +1,7 @@ package rfc3164 import ( - "fmt" + "errors" "time" "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog/internal/parser/utils" @@ -52,7 +52,7 @@ func (r *RFC3164) parsePRI() error { pri := 0 if r.buf[r.position] != '<' { - return fmt.Errorf("PRI must start with '<'") + return errors.New("PRI must start with '<'") } r.position++ @@ 
-64,18 +64,18 @@ func (r *RFC3164) parsePRI() error { break } if c < '0' || c > '9' { - return fmt.Errorf("PRI must be a number") + return errors.New("PRI must be a number") } pri = pri*10 + int(c-'0') r.position++ } if pri > 999 { - return fmt.Errorf("PRI must be up to 3 characters long") + return errors.New("PRI must be up to 3 characters long") } if r.position == r.len && r.buf[r.position-1] != '>' { - return fmt.Errorf("PRI must end with '>'") + return errors.New("PRI must end with '>'") } r.PRI = pri @@ -98,7 +98,7 @@ func (r *RFC3164) parseTimestamp() error { } } if !validTs { - return fmt.Errorf("timestamp is not valid") + return errors.New("timestamp is not valid") } if r.useCurrentYear { if r.Timestamp.Year() == 0 { @@ -122,11 +122,11 @@ func (r *RFC3164) parseHostname() error { } if r.strictHostname { if !utils.IsValidHostnameOrIP(string(hostname)) { - return fmt.Errorf("hostname is not valid") + return errors.New("hostname is not valid") } } if len(hostname) == 0 { - return fmt.Errorf("hostname is empty") + return errors.New("hostname is empty") } r.Hostname = string(hostname) return nil @@ -147,7 +147,7 @@ func (r *RFC3164) parseTag() error { r.position++ } if len(tag) == 0 { - return fmt.Errorf("tag is empty") + return errors.New("tag is empty") } r.Tag = string(tag) @@ -167,7 +167,7 @@ func (r *RFC3164) parseTag() error { break } if c < '0' || c > '9' { - return fmt.Errorf("pid inside tag must be a number") + return errors.New("pid inside tag must be a number") } tmpPid = append(tmpPid, c) r.position++ @@ -175,7 +175,7 @@ func (r *RFC3164) parseTag() error { } if hasPid && !pidEnd { - return fmt.Errorf("pid inside tag must be closed with ']'") + return errors.New("pid inside tag must be closed with ']'") } if hasPid { @@ -191,7 +191,7 @@ func (r *RFC3164) parseMessage() error { } if r.position == r.len { - return fmt.Errorf("message is empty") + return errors.New("message is empty") } c := r.buf[r.position] @@ -202,7 +202,7 @@ func (r *RFC3164) parseMessage() error { for { if r.position >= r.len { - return fmt.Errorf("message is empty") + return errors.New("message is empty") } c := r.buf[r.position] if c != ' ' { @@ -219,7 +219,7 @@ func (r *RFC3164) parseMessage() error { func (r *RFC3164) Parse(message []byte) error { r.len = len(message) if r.len == 0 { - return fmt.Errorf("message is empty") + return errors.New("message is empty") } r.buf = message diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go index 48772d596f4..3af6614bce6 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go @@ -4,6 +4,10 @@ import ( "fmt" "testing" "time" + + "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestPri(t *testing.T) { @@ -22,33 +26,24 @@ func TestPri(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { r := &RFC3164{} r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parsePRI() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else if r.PRI != test.expected { - t.Errorf("expected %d, got %d", test.expected, r.PRI) - } + 
cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.PRI) }) } } func TestTimestamp(t *testing.T) { - tests := []struct { input string expected string @@ -64,31 +59,24 @@ func TestTimestamp(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { opts := []RFC3164Option{} if test.currentYear { opts = append(opts, WithCurrentYear()) } + r := NewRFC3164Parser(opts...) r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseTimestamp() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else if r.Timestamp.Format(time.RFC3339) != test.expected { - t.Errorf("expected %s, got %s", test.expected, r.Timestamp.Format(time.RFC3339)) - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Timestamp.Format(time.RFC3339)) }) } } @@ -118,31 +106,24 @@ func TestHostname(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { opts := []RFC3164Option{} if test.strictHostname { opts = append(opts, WithStrictHostname()) } + r := NewRFC3164Parser(opts...) r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseHostname() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else if r.Hostname != test.expected { - t.Errorf("expected %s, got %s", test.expected, r.Hostname) - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Hostname) }) } } @@ -163,32 +144,20 @@ func TestTag(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { r := &RFC3164{} r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseTag() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else { - if r.Tag != test.expected { - t.Errorf("expected %s, got %s", test.expected, r.Tag) - } - if r.PID != test.expectedPID { - t.Errorf("expected %s, got %s", test.expected, r.Message) - } - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Tag) + assert.Equal(t, test.expectedPID, r.PID) }) } } @@ -207,27 +176,19 @@ func TestMessage(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { r := &RFC3164{} r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseMessage() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected 
error %s, got no error", test.expectedErr) - } else if r.Message != test.expected { - t.Errorf("expected message %s, got %s", test.expected, r.Tag) - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Message) }) } } @@ -241,6 +202,7 @@ func TestParse(t *testing.T) { Message string PRI int } + tests := []struct { input string expected expected @@ -329,42 +291,22 @@ func TestParse(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { r := NewRFC3164Parser(test.opts...) + err := r.Parse([]byte(test.input)) - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error '%s', got '%s'", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: '%s'", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error '%s', got no error", test.expectedErr) - } else { - if r.Timestamp != test.expected.Timestamp { - t.Errorf("expected timestamp '%s', got '%s'", test.expected.Timestamp, r.Timestamp) - } - if r.Hostname != test.expected.Hostname { - t.Errorf("expected hostname '%s', got '%s'", test.expected.Hostname, r.Hostname) - } - if r.Tag != test.expected.Tag { - t.Errorf("expected tag '%s', got '%s'", test.expected.Tag, r.Tag) - } - if r.PID != test.expected.PID { - t.Errorf("expected pid '%s', got '%s'", test.expected.PID, r.PID) - } - if r.Message != test.expected.Message { - t.Errorf("expected message '%s', got '%s'", test.expected.Message, r.Message) - } - if r.PRI != test.expected.PRI { - t.Errorf("expected pri '%d', got '%d'", test.expected.PRI, r.PRI) - } - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected.Timestamp, r.Timestamp) + assert.Equal(t, test.expected.Hostname, r.Hostname) + assert.Equal(t, test.expected.Tag, r.Tag) + assert.Equal(t, test.expected.PID, r.PID) + assert.Equal(t, test.expected.Message, r.Message) + assert.Equal(t, test.expected.PRI, r.PRI) }) } } diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/perf_test.go b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/perf_test.go index 42073cafbae..3805090f57f 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/perf_test.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/perf_test.go @@ -51,7 +51,6 @@ func BenchmarkParse(b *testing.B) { } var err error for _, test := range tests { - test := test b.Run(string(test.input), func(b *testing.B) { for i := 0; i < b.N; i++ { r := NewRFC3164Parser(test.opts...) 
diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go index 8b71a77e2e3..639e91e1224 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go @@ -1,7 +1,7 @@ package rfc5424 import ( - "fmt" + "errors" "time" "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog/internal/parser/utils" @@ -52,7 +52,7 @@ func (r *RFC5424) parsePRI() error { pri := 0 if r.buf[r.position] != '<' { - return fmt.Errorf("PRI must start with '<'") + return errors.New("PRI must start with '<'") } r.position++ @@ -64,18 +64,18 @@ func (r *RFC5424) parsePRI() error { break } if c < '0' || c > '9' { - return fmt.Errorf("PRI must be a number") + return errors.New("PRI must be a number") } pri = pri*10 + int(c-'0') r.position++ } if pri > 999 { - return fmt.Errorf("PRI must be up to 3 characters long") + return errors.New("PRI must be up to 3 characters long") } if r.position == r.len && r.buf[r.position-1] != '>' { - return fmt.Errorf("PRI must end with '>'") + return errors.New("PRI must end with '>'") } r.PRI = pri @@ -84,11 +84,11 @@ func (r *RFC5424) parsePRI() error { func (r *RFC5424) parseVersion() error { if r.buf[r.position] != '1' { - return fmt.Errorf("version must be 1") + return errors.New("version must be 1") } r.position += 2 if r.position >= r.len { - return fmt.Errorf("version must be followed by a space") + return errors.New("version must be followed by a space") } return nil } @@ -113,17 +113,17 @@ func (r *RFC5424) parseTimestamp() error { } if len(timestamp) == 0 { - return fmt.Errorf("timestamp is empty") + return errors.New("timestamp is empty") } if r.position == r.len { - return fmt.Errorf("EOL after timestamp") + return errors.New("EOL after timestamp") } date, err := time.Parse(VALID_TIMESTAMP, string(timestamp)) if err != nil { - return fmt.Errorf("timestamp is not valid") + return errors.New("timestamp is not valid") } r.Timestamp = date @@ -131,7 +131,7 @@ func (r *RFC5424) parseTimestamp() error { r.position++ if r.position >= r.len { - return fmt.Errorf("EOL after timestamp") + return errors.New("EOL after timestamp") } return nil @@ -156,11 +156,11 @@ func (r *RFC5424) parseHostname() error { } if r.strictHostname { if !utils.IsValidHostnameOrIP(string(hostname)) { - return fmt.Errorf("hostname is not valid") + return errors.New("hostname is not valid") } } if len(hostname) == 0 { - return fmt.Errorf("hostname is empty") + return errors.New("hostname is empty") } r.Hostname = string(hostname) return nil @@ -185,11 +185,11 @@ func (r *RFC5424) parseAppName() error { } if len(appname) == 0 { - return fmt.Errorf("appname is empty") + return errors.New("appname is empty") } if len(appname) > 48 { - return fmt.Errorf("appname is too long") + return errors.New("appname is too long") } r.Tag = string(appname) @@ -215,11 +215,11 @@ func (r *RFC5424) parseProcID() error { } if len(procid) == 0 { - return fmt.Errorf("procid is empty") + return errors.New("procid is empty") } if len(procid) > 128 { - return fmt.Errorf("procid is too long") + return errors.New("procid is too long") } r.PID = string(procid) @@ -245,11 +245,11 @@ func (r *RFC5424) parseMsgID() error { } if len(msgid) == 0 { - return fmt.Errorf("msgid is empty") + return errors.New("msgid is empty") } if len(msgid) > 32 { - return fmt.Errorf("msgid is too long") + return errors.New("msgid is too long") } r.MsgID = string(msgid) @@ -263,7 +263,7 @@ 
func (r *RFC5424) parseStructuredData() error { return nil } if r.buf[r.position] != '[' { - return fmt.Errorf("structured data must start with '[' or be '-'") + return errors.New("structured data must start with '[' or be '-'") } prev := byte(0) for r.position < r.len { @@ -281,14 +281,14 @@ func (r *RFC5424) parseStructuredData() error { } r.position++ if !done { - return fmt.Errorf("structured data must end with ']'") + return errors.New("structured data must end with ']'") } return nil } func (r *RFC5424) parseMessage() error { if r.position == r.len { - return fmt.Errorf("message is empty") + return errors.New("message is empty") } message := []byte{} @@ -305,7 +305,7 @@ func (r *RFC5424) parseMessage() error { func (r *RFC5424) Parse(message []byte) error { r.len = len(message) if r.len == 0 { - return fmt.Errorf("syslog line is empty") + return errors.New("syslog line is empty") } r.buf = message @@ -315,7 +315,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after PRI") + return errors.New("EOL after PRI") } err = r.parseVersion() @@ -324,7 +324,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after Version") + return errors.New("EOL after Version") } err = r.parseTimestamp() @@ -333,7 +333,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after Timestamp") + return errors.New("EOL after Timestamp") } err = r.parseHostname() @@ -342,7 +342,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after hostname") + return errors.New("EOL after hostname") } err = r.parseAppName() @@ -351,7 +351,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after appname") + return errors.New("EOL after appname") } err = r.parseProcID() @@ -360,7 +360,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after ProcID") + return errors.New("EOL after ProcID") } err = r.parseMsgID() @@ -369,7 +369,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after MSGID") + return errors.New("EOL after MSGID") } err = r.parseStructuredData() @@ -378,7 +378,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after SD") + return errors.New("EOL after SD") } err = r.parseMessage() diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go index 9a030e6fef4..0938e947fe7 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go @@ -4,9 +4,9 @@ import ( "testing" "time" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" - "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestPri(t *testing.T) { @@ -25,7 +25,6 @@ func TestPri(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { r := &RFC5424{} r.buf = []byte(test.input) @@ -61,7 +60,6 @@ func TestHostname(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { opts := []RFC5424Option{} if test.strictHostname { @@ -200,7 +198,6 @@ func TestParse(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.name, 
func(t *testing.T) { r := NewRFC5424Parser(test.opts...) err := r.Parse([]byte(test.input)) diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/perf_test.go b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/perf_test.go index 318571e91ee..a86c17e8ddf 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/perf_test.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/perf_test.go @@ -92,7 +92,6 @@ func BenchmarkParse(b *testing.B) { } var err error for _, test := range tests { - test := test b.Run(test.label, func(b *testing.B) { for i := 0; i < b.N; i++ { r := NewRFC5424Parser() diff --git a/pkg/acquisition/modules/syslog/internal/parser/utils/utils.go b/pkg/acquisition/modules/syslog/internal/parser/utils/utils.go index 8fe717a6ab2..5e0bf8fe771 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/utils/utils.go +++ b/pkg/acquisition/modules/syslog/internal/parser/utils/utils.go @@ -34,7 +34,7 @@ func isValidHostname(s string) bool { last := byte('.') nonNumeric := false // true once we've seen a letter or hyphen partlen := 0 - for i := 0; i < len(s); i++ { + for i := range len(s) { c := s[i] switch { default: diff --git a/pkg/acquisition/modules/syslog/syslog.go b/pkg/acquisition/modules/syslog/syslog.go index 840e372007b..5315096fb9b 100644 --- a/pkg/acquisition/modules/syslog/syslog.go +++ b/pkg/acquisition/modules/syslog/syslog.go @@ -1,6 +1,8 @@ package syslogacquisition import ( + "context" + "errors" "fmt" "net" "strings" @@ -11,7 +13,7 @@ import ( "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog/internal/parser/rfc3164" @@ -29,10 +31,11 @@ type SyslogConfiguration struct { } type SyslogSource struct { - config SyslogConfiguration - logger *log.Entry - server *syslogserver.SyslogServer - serverTomb *tomb.Tomb + metricsLevel int + config SyslogConfiguration + logger *log.Entry + server *syslogserver.SyslogServer + serverTomb *tomb.Tomb } var linesReceived = prometheus.NewCounterVec( @@ -78,11 +81,11 @@ func (s *SyslogSource) GetAggregMetrics() []prometheus.Collector { } func (s *SyslogSource) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { - return fmt.Errorf("syslog datasource does not support one shot acquisition") + return errors.New("syslog datasource does not support one shot acquisition") } func (s *SyslogSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { - return fmt.Errorf("syslog datasource does not support one shot acquisition") + return errors.New("syslog datasource does not support one shot acquisition") } func validatePort(port int) bool { @@ -103,7 +106,7 @@ func (s *SyslogSource) UnmarshalConfig(yamlConfig []byte) error { } if s.config.Addr == "" { - s.config.Addr = "127.0.0.1" //do we want a usable or secure default ? + s.config.Addr = "127.0.0.1" // do we want a usable or secure default ? 
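+ // Defaulting to loopback keeps the listener private; accepting remote syslog senders requires an explicit listen_addr.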
} if s.config.Port == 0 { s.config.Port = 514 @@ -121,10 +124,10 @@ func (s *SyslogSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (s *SyslogSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (s *SyslogSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { s.logger = logger s.logger.Infof("Starting syslog datasource configuration") - + s.metricsLevel = MetricsLevel err := s.UnmarshalConfig(yamlConfig) if err != nil { return err @@ -133,7 +136,7 @@ func (s *SyslogSource) Configure(yamlConfig []byte, logger *log.Entry) error { return nil } -func (s *SyslogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (s *SyslogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { c := make(chan syslogserver.SyslogMessage) s.server = &syslogserver.SyslogServer{Logger: s.logger.WithField("syslog", "internal"), MaxMessageLen: s.config.MaxMessageLen} s.server.SetChannel(c) @@ -150,7 +153,8 @@ func (s *SyslogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) } func (s *SyslogSource) buildLogFromSyslog(ts time.Time, hostname string, - appname string, pid string, msg string) string { + appname string, pid string, msg string, +) string { ret := "" if !ts.IsZero() { ret += ts.Format("Jan 2 15:04:05") @@ -176,7 +180,6 @@ func (s *SyslogSource) buildLogFromSyslog(ts time.Time, hostname string, ret += msg } return ret - } func (s *SyslogSource) handleSyslogMsg(out chan types.Event, t *tomb.Tomb, c chan syslogserver.SyslogMessage) error { @@ -198,7 +201,9 @@ func (s *SyslogSource) handleSyslogMsg(out chan types.Event, t *tomb.Tomb, c cha logger := s.logger.WithField("client", syslogLine.Client) logger.Tracef("raw: %s", syslogLine) - linesReceived.With(prometheus.Labels{"source": syslogLine.Client}).Inc() + if s.metricsLevel != configuration.METRICS_NONE { + linesReceived.With(prometheus.Labels{"source": syslogLine.Client}).Inc() + } p := rfc3164.NewRFC3164Parser(rfc3164.WithCurrentYear()) err := p.Parse(syslogLine.Message) if err != nil { @@ -211,8 +216,14 @@ func (s *SyslogSource) handleSyslogMsg(out chan types.Event, t *tomb.Tomb, c cha continue } line = s.buildLogFromSyslog(p2.Timestamp, p2.Hostname, p2.Tag, p2.PID, p2.Message) + if s.metricsLevel != configuration.METRICS_NONE { + linesParsed.With(prometheus.Labels{"source": syslogLine.Client, "type": "rfc5424"}).Inc() + } } else { line = s.buildLogFromSyslog(p.Timestamp, p.Hostname, p.Tag, p.PID, p.Message) + if s.metricsLevel != configuration.METRICS_NONE { + linesParsed.With(prometheus.Labels{"source": syslogLine.Client, "type": "rfc3164"}).Inc() + } } line = strings.TrimSuffix(line, "\n") diff --git a/pkg/acquisition/modules/syslog/syslog_test.go b/pkg/acquisition/modules/syslog/syslog_test.go index 2557f26d502..57fa3e8747b 100644 --- a/pkg/acquisition/modules/syslog/syslog_test.go +++ b/pkg/acquisition/modules/syslog/syslog_test.go @@ -1,19 +1,21 @@ package syslogacquisition import ( + "context" "fmt" "net" "runtime" "testing" "time" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" - - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "gopkg.in/tomb.v2" - "github.com/stretchr/testify/assert" + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func TestConfigure(t *testing.T) { @@ -51,12 +53,10 @@ listen_addr: 10.0.0`, }, } - 
subLogger := log.WithFields(log.Fields{ - "type": "syslog", - }) + subLogger := log.WithField("type", "syslog") for _, test := range tests { s := SyslogSource{} - err := s.Configure([]byte(test.config), subLogger) + err := s.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) cstest.AssertErrorContains(t, err, test.expectedErr) } } @@ -81,6 +81,7 @@ func writeToSyslog(logs []string) { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -101,8 +102,10 @@ listen_addr: 127.0.0.1`, listen_port: 4242 listen_addr: 127.0.0.1`, expectedLines: 2, - logs: []string{`<13>1 2021-05-18T11:58:40.828081+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`, - `<13>1 2021-05-18T12:12:37.560695+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla2[foobar]`}, + logs: []string{ + `<13>1 2021-05-18T11:58:40.828081+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`, + `<13>1 2021-05-18T12:12:37.560695+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla2[foobar]`, + }, }, { name: "RFC3164", @@ -110,10 +113,12 @@ listen_addr: 127.0.0.1`, listen_port: 4242 listen_addr: 127.0.0.1`, expectedLines: 3, - logs: []string{`<13>May 18 12:37:56 mantis sshd[49340]: blabla2[foobar]`, + logs: []string{ + `<13>May 18 12:37:56 mantis sshd[49340]: blabla2[foobar]`, `<13>May 18 12:37:56 mantis sshd[49340]: blabla2`, `<13>May 18 12:37:56 mantis sshd: blabla2`, - `<13>May 18 12:37:56 mantis sshd`}, + `<13>May 18 12:37:56 mantis sshd`, + }, }, } if runtime.GOOS != "windows" { @@ -131,19 +136,16 @@ listen_addr: 127.0.0.1`, } for _, ts := range tests { - ts := ts t.Run(ts.name, func(t *testing.T) { - subLogger := log.WithFields(log.Fields{ - "type": "syslog", - }) + subLogger := log.WithField("type", "syslog") s := SyslogSource{} - err := s.Configure([]byte(ts.config), subLogger) + err := s.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("could not configure syslog source : %s", err) } tomb := tomb.Tomb{} out := make(chan types.Event) - err = s.StreamingAcquisition(out, &tomb) + err = s.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) if ts.expectedErr != "" { return diff --git a/pkg/acquisition/modules/wineventlog/test_files/Setup.evtx b/pkg/acquisition/modules/wineventlog/test_files/Setup.evtx new file mode 100644 index 00000000000..2c4f8b0f680 Binary files /dev/null and b/pkg/acquisition/modules/wineventlog/test_files/Setup.evtx differ diff --git a/pkg/acquisition/modules/wineventlog/wineventlog.go b/pkg/acquisition/modules/wineventlog/wineventlog.go index f0eca5d13d7..6d522d8d8cb 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog.go @@ -3,6 +3,7 @@ package wineventlogacquisition import ( + "context" "errors" "github.com/prometheus/client_golang/prometheus" @@ -23,7 +24,7 @@ func (w *WinEventLogSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry, metricsLevel int) error { return nil } @@ -59,7 +60,7 @@ func (w *WinEventLogSource) CanRun() error { return errors.New("windows event log acquisition is only supported on Windows") } -func (w *WinEventLogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w 
*WinEventLogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { return nil } diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go index 3a78a193267..ca40363155b 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go @@ -1,10 +1,13 @@ package wineventlogacquisition import ( + "context" "encoding/xml" "errors" "fmt" + "net/url" "runtime" + "strconv" "strings" "syscall" "time" @@ -17,7 +20,7 @@ import ( "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -29,16 +32,17 @@ type WinEventLogConfiguration struct { EventLevel string `yaml:"event_level"` EventIDs []int `yaml:"event_ids"` XPathQuery string `yaml:"xpath_query"` - EventFile string `yaml:"event_file"` + EventFile string PrettyName string `yaml:"pretty_name"` } type WinEventLogSource struct { - config WinEventLogConfiguration - logger *log.Entry - evtConfig *winlog.SubscribeConfig - query string - name string + metricsLevel int + config WinEventLogConfiguration + logger *log.Entry + evtConfig *winlog.SubscribeConfig + query string + name string } type QueryList struct { @@ -46,10 +50,13 @@ type QueryList struct { } type Select struct { - Path string `xml:"Path,attr"` + Path string `xml:"Path,attr,omitempty"` Query string `xml:",chardata"` } +// 0 identifies the local machine in windows APIs +const localMachine = 0 + var linesRead = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "cs_winevtlogsource_hits_total", @@ -148,7 +155,7 @@ func (w *WinEventLogSource) buildXpathQuery() (string, error) { queryList := QueryList{Select: Select{Path: w.config.EventChannel, Query: query}} xpathQuery, err := xml.Marshal(queryList) if err != nil { - w.logger.Errorf("Marshal failed: %v", err) + w.logger.Errorf("Serialize failed: %v", err) return "", err } w.logger.Debugf("xpathQuery: %s", xpathQuery) @@ -188,7 +195,9 @@ func (w *WinEventLogSource) getEvents(out chan types.Event, t *tomb.Tomb) error continue } for _, event := range renderedEvents { - linesRead.With(prometheus.Labels{"source": w.name}).Inc() + if w.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": w.name}).Inc() + } l := types.Line{} l.Raw = event l.Module = w.GetName() @@ -208,20 +217,28 @@ func (w *WinEventLogSource) getEvents(out chan types.Event, t *tomb.Tomb) error } } -func (w *WinEventLogSource) generateConfig(query string) (*winlog.SubscribeConfig, error) { +func (w *WinEventLogSource) generateConfig(query string, live bool) (*winlog.SubscribeConfig, error) { var config winlog.SubscribeConfig var err error - // Create a subscription signaler. - config.SignalEvent, err = windows.CreateEvent( - nil, // Default security descriptor. - 1, // Manual reset. - 1, // Initial state is signaled. - nil) // Optional name. - if err != nil { - return &config, fmt.Errorf("windows.CreateEvent failed: %v", err) + if live { + // Create a subscription signaler. + config.SignalEvent, err = windows.CreateEvent( + nil, // Default security descriptor. + 1, // Manual reset. + 1, // Initial state is signaled. + nil) // Optional name. 
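The gating added around linesRead above mirrors the syslog changes earlier in this diff: when a datasource is configured with metrics disabled, counter updates are skipped before any label lookup. A minimal sketch of the idiom as a standalone helper — the helper and package names are hypothetical, only configuration.METRICS_NONE and the Prometheus types come from the patch, which inlines this check at each update site instead:

    package datasourcemetrics // hypothetical package, for illustration only

    import (
        "github.com/prometheus/client_golang/prometheus"

        "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
    )

    // IncIfEnabled bumps a per-source counter unless metrics are disabled.
    func IncIfEnabled(metricsLevel int, c *prometheus.CounterVec, source string) {
        if metricsLevel == configuration.METRICS_NONE {
            return // metrics disabled: skip label allocation and increment entirely
        }
        c.With(prometheus.Labels{"source": source}).Inc()
    }
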
+ if err != nil { + return &config, fmt.Errorf("windows.CreateEvent failed: %v", err) + } + config.Flags = wevtapi.EvtSubscribeToFutureEvents + } else { + config.ChannelPath, err = syscall.UTF16PtrFromString(w.config.EventFile) + if err != nil { + return &config, fmt.Errorf("syscall.UTF16PtrFromString failed: %v", err) + } + config.Flags = wevtapi.EvtQueryFilePath | wevtapi.EvtQueryForwardDirection } - config.Flags = wevtapi.EvtSubscribeToFutureEvents config.Query, err = syscall.UTF16PtrFromString(query) if err != nil { return &config, fmt.Errorf("syscall.UTF16PtrFromString failed: %v", err) @@ -243,11 +260,11 @@ func (w *WinEventLogSource) UnmarshalConfig(yamlConfig []byte) error { } if w.config.EventChannel != "" && w.config.XPathQuery != "" { - return fmt.Errorf("event_channel and xpath_query are mutually exclusive") + return errors.New("event_channel and xpath_query are mutually exclusive") } if w.config.EventChannel == "" && w.config.XPathQuery == "" { - return fmt.Errorf("event_channel or xpath_query must be set") + return errors.New("event_channel or xpath_query must be set") } w.config.Mode = configuration.TAIL_MODE @@ -270,15 +287,16 @@ func (w *WinEventLogSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { w.logger = logger + w.metricsLevel = MetricsLevel err := w.UnmarshalConfig(yamlConfig) if err != nil { return err } - w.evtConfig, err = w.generateConfig(w.query) + w.evtConfig, err = w.generateConfig(w.query, true) if err != nil { return err } @@ -287,6 +305,78 @@ func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry) erro } func (w *WinEventLogSource) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { + if !strings.HasPrefix(dsn, "wineventlog://") { + return fmt.Errorf("invalid DSN %s for wineventlog source, must start with wineventlog://", dsn) + } + + w.logger = logger + w.config = WinEventLogConfiguration{} + + dsn = strings.TrimPrefix(dsn, "wineventlog://") + + args := strings.Split(dsn, "?") + + if args[0] == "" { + return errors.New("empty wineventlog:// DSN") + } + + if len(args) > 2 { + return errors.New("too many arguments in DSN") + } + + w.config.EventFile = args[0] + + if len(args) == 2 && args[1] != "" { + params, err := url.ParseQuery(args[1]) + if err != nil { + return fmt.Errorf("failed to parse DSN parameters: %w", err) + } + + for key, value := range params { + switch key { + case "log_level": + if len(value) != 1 { + return errors.New("log_level must be a single value") + } + lvl, err := log.ParseLevel(value[0]) + if err != nil { + return fmt.Errorf("failed to parse log_level: %s", err) + } + w.logger.Logger.SetLevel(lvl) + case "event_id": + for _, id := range value { + evtid, err := strconv.Atoi(id) + if err != nil { + return fmt.Errorf("failed to parse event_id: %s", err) + } + w.config.EventIDs = append(w.config.EventIDs, evtid) + } + case "event_level": + if len(value) != 1 { + return errors.New("event_level must be a single value") + } + w.config.EventLevel = value[0] + } + } + } + + var err error + + //FIXME: handle custom xpath query + w.query, err = w.buildXpathQuery() + + if err != nil { + return fmt.Errorf("buildXpathQuery failed: %w", err) + } + + w.logger.Debugf("query: %s\n", w.query) + + w.evtConfig, err = w.generateConfig(w.query, false) + + if err != nil { + return 
fmt.Errorf("generateConfig failed: %w", err) + } + return nil } @@ -295,10 +385,57 @@ func (w *WinEventLogSource) GetMode() string { } func (w *WinEventLogSource) SupportedModes() []string { - return []string{configuration.TAIL_MODE} + return []string{configuration.TAIL_MODE, configuration.CAT_MODE} } func (w *WinEventLogSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { + + handle, err := wevtapi.EvtQuery(localMachine, w.evtConfig.ChannelPath, w.evtConfig.Query, w.evtConfig.Flags) + + if err != nil { + return fmt.Errorf("EvtQuery failed: %v", err) + } + + defer winlog.Close(handle) + + publisherCache := make(map[string]windows.Handle) + defer func() { + for _, h := range publisherCache { + winlog.Close(h) + } + }() + +OUTER_LOOP: + for { + select { + case <-t.Dying(): + w.logger.Infof("wineventlog is dying") + return nil + default: + evts, err := w.getXMLEvents(w.evtConfig, publisherCache, handle, 500) + if err == windows.ERROR_NO_MORE_ITEMS { + log.Info("No more items") + break OUTER_LOOP + } else if err != nil { + return fmt.Errorf("getXMLEvents failed: %v", err) + } + w.logger.Debugf("Got %d events", len(evts)) + for _, evt := range evts { + w.logger.Tracef("Event: %s", evt) + if w.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": w.name}).Inc() + } + l := types.Line{} + l.Raw = evt + l.Module = w.GetName() + l.Labels = w.config.Labels + l.Time = time.Now() + l.Src = w.name + l.Process = true + out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} + } + } + } return nil } @@ -321,7 +458,7 @@ func (w *WinEventLogSource) CanRun() error { return nil } -func (w *WinEventLogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *WinEventLogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/wineventlog/streaming") return w.getEvents(out, t) diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_test.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go similarity index 64% rename from pkg/acquisition/modules/wineventlog/wineventlog_test.go rename to pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go index 54fddc3d8cb..9afef963669 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_test.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go @@ -1,25 +1,25 @@ //go:build windows -// +build windows package wineventlogacquisition import ( - "runtime" + "context" "testing" "time" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/sys/windows/svc/eventlog" "gopkg.in/tomb.v2" ) func TestBadConfiguration(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Skipping test on non-windows OS") - } + exprhelpers.Init(nil) + tests := []struct { config string expectedErr string @@ -53,20 +53,17 @@ xpath_query: test`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "windowseventlog", - }) + subLogger := log.WithField("type", "windowseventlog") for _, test := range tests { f := WinEventLogSource{} - err := f.Configure([]byte(test.config), subLogger) + err := f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) assert.Contains(t, err.Error(), test.expectedErr) 
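With CAT_MODE and the DSN parsing above, a saved .evtx file can now be replayed through the pipeline. A hedged usage sketch (Windows only; the DSN mirrors the tests below, everything else — the consumer goroutine, the logging — is illustrative, not part of the patch):

    package main

    import (
        log "github.com/sirupsen/logrus"
        "gopkg.in/tomb.v2"

        wineventlog "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/wineventlog"
        "github.com/crowdsecurity/crowdsec/pkg/types"
    )

    func main() {
        logger := log.WithField("type", "windowseventlog")
        src := wineventlog.WinEventLogSource{}

        // Replay a saved event log, keeping only event id 2 (DSN taken from the tests).
        dsn := "wineventlog://test_files/Setup.evtx?event_id=2"
        if err := src.ConfigureByDSN(dsn, map[string]string{"type": "wineventlog"}, logger, ""); err != nil {
            logger.Fatal(err)
        }

        out := make(chan types.Event)
        to := tomb.Tomb{}

        go func() {
            for evt := range out {
                logger.Infof("event: %s", evt.Line.Raw)
            }
        }()

        // Returns once EvtQuery reports ERROR_NO_MORE_ITEMS.
        if err := src.OneShotAcquisition(out, &to); err != nil {
            logger.Fatal(err)
        }
    }
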
} } func TestQueryBuilder(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Skipping test on non-windows OS") - } + exprhelpers.Init(nil) + tests := []struct { config string expectedQuery string @@ -112,12 +109,10 @@ event_level: bla`, expectedErr: "invalid log level", }, } - subLogger := log.WithFields(log.Fields{ - "type": "windowseventlog", - }) + subLogger := log.WithField("type", "windowseventlog") for _, test := range tests { f := WinEventLogSource{} - f.Configure([]byte(test.config), subLogger) + f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) q, err := f.buildXpathQuery() if test.expectedErr != "" { if err == nil { @@ -125,16 +120,15 @@ event_level: bla`, } assert.Contains(t, err.Error(), test.expectedErr) } else { - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectedQuery, q) } } } func TestLiveAcquisition(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Skipping test on non-windows OS") - } + exprhelpers.Init(nil) + ctx := context.Background() tests := []struct { config string @@ -180,9 +174,7 @@ event_ids: expectedLines: nil, }, } - subLogger := log.WithFields(log.Fields{ - "type": "windowseventlog", - }) + subLogger := log.WithField("type", "windowseventlog") evthandler, err := eventlog.Open("Application") @@ -194,8 +186,8 @@ event_ids: to := &tomb.Tomb{} c := make(chan types.Event) f := WinEventLogSource{} - f.Configure([]byte(test.config), subLogger) - f.StreamingAcquisition(c, to) + f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) + f.StreamingAcquisition(ctx, c, to) time.Sleep(time.Second) lines := test.expectedLines go func() { @@ -222,12 +214,90 @@ event_ids: } } if test.expectedLines == nil { - assert.Equal(t, 0, len(linesRead)) + assert.Empty(t, linesRead) } else { - assert.Equal(t, len(test.expectedLines), len(linesRead)) assert.Equal(t, test.expectedLines, linesRead) } to.Kill(nil) to.Wait() } } + +func TestOneShotAcquisition(t *testing.T) { + tests := []struct { + name string + dsn string + expectedCount int + expectedErr string + expectedConfigureErr string + }{ + { + name: "non-existing file", + dsn: `wineventlog://foo.evtx`, + expectedCount: 0, + expectedErr: "The system cannot find the file specified.", + }, + { + name: "empty DSN", + dsn: `wineventlog://`, + expectedCount: 0, + expectedConfigureErr: "empty wineventlog:// DSN", + }, + { + name: "existing file", + dsn: `wineventlog://test_files/Setup.evtx`, + expectedCount: 24, + expectedErr: "", + }, + { + name: "filter on event_id", + dsn: `wineventlog://test_files/Setup.evtx?event_id=2`, + expectedCount: 1, + }, + { + name: "filter on event_id", + dsn: `wineventlog://test_files/Setup.evtx?event_id=2&event_id=3`, + expectedCount: 24, + }, + } + + exprhelpers.Init(nil) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + lineCount := 0 + to := &tomb.Tomb{} + c := make(chan types.Event) + f := WinEventLogSource{} + err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "wineventlog"}, log.WithField("type", "windowseventlog"), "") + + if test.expectedConfigureErr != "" { + assert.Contains(t, err.Error(), test.expectedConfigureErr) + return + } + + require.NoError(t, err) + + go func() { + for { + select { + case <-c: + lineCount++ + case <-to.Dying(): + return + } + } + }() + + err = f.OneShotAcquisition(c, to) + if test.expectedErr != "" { + assert.Contains(t, err.Error(), test.expectedErr) + } else { + require.NoError(t, err) + + time.Sleep(2 * time.Second) + assert.Equal(t, test.expectedCount, 
lineCount) + } + }) + } +} diff --git a/pkg/acquisition/s3.go b/pkg/acquisition/s3.go new file mode 100644 index 00000000000..73343b0408d --- /dev/null +++ b/pkg/acquisition/s3.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_s3 + +package acquisition + +import ( + s3acquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/s3" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("s3", func() DataSource { return &s3acquisition.S3Source{} }) +} diff --git a/pkg/acquisition/syslog.go b/pkg/acquisition/syslog.go new file mode 100644 index 00000000000..f62cc23b916 --- /dev/null +++ b/pkg/acquisition/syslog.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_syslog + +package acquisition + +import ( + syslogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("syslog", func() DataSource { return &syslogacquisition.SyslogSource{} }) +} diff --git a/pkg/acquisition/wineventlog.go b/pkg/acquisition/wineventlog.go new file mode 100644 index 00000000000..0c4889a3f5c --- /dev/null +++ b/pkg/acquisition/wineventlog.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_wineventlog + +package acquisition + +import ( + wineventlogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/wineventlog" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("wineventlog", func() DataSource { return &wineventlogacquisition.WinEventLogSource{} }) +} diff --git a/pkg/alertcontext/alertcontext.go b/pkg/alertcontext/alertcontext.go index 29c13e3f3bc..16ebc6d0ac2 100644 --- a/pkg/alertcontext/alertcontext.go +++ b/pkg/alertcontext/alertcontext.go @@ -3,12 +3,12 @@ package alertcontext import ( "encoding/json" "fmt" + "slices" "strconv" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" log "github.com/sirupsen/logrus" - "golang.org/x/exp/slices" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -16,12 +16,10 @@ import ( ) const ( - maxContextValueLen = 4000 + MaxContextValueLen = 4000 ) -var ( - alertContext = Context{} -) +var alertContext = Context{} type Context struct { ContextToSend map[string][]string @@ -34,25 +32,27 @@ func ValidateContextExpr(key string, expressions []string) error { for _, expression := range expressions { _, err := expr.Compile(expression, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) 
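The new top-level registration files above (s3.go, syslog.go, wineventlog.go) all follow the same shape: an init-time registration guarded by a negative build tag, so building with, for example, -tags no_datasource_s3 drops that datasource and its dependencies from the binary. A sketch of the pattern for a hypothetical datasource — the module name and source type are invented; registerDataSource and DataSource come from the patch's acquisition package:

    //go:build !no_datasource_example

    package acquisition

    import (
        exampleacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/example" // hypothetical module
    )

    //nolint:gochecknoinits
    func init() {
        registerDataSource("example", func() DataSource { return &exampleacquisition.ExampleSource{} })
    }
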
if err != nil { - return fmt.Errorf("compilation of '%s' failed: %v", expression, err) + return fmt.Errorf("compilation of '%s' failed: %w", expression, err) } } + return nil } func NewAlertContext(contextToSend map[string][]string, valueLength int) error { - var clog = log.New() + clog := log.New() if err := types.ConfigureLogger(clog); err != nil { - return fmt.Errorf("couldn't create logger for alert context: %s", err) + return fmt.Errorf("couldn't create logger for alert context: %w", err) } if valueLength == 0 { - clog.Debugf("No console context value length provided, using default: %d", maxContextValueLen) - valueLength = maxContextValueLen + clog.Debugf("No console context value length provided, using default: %d", MaxContextValueLen) + valueLength = MaxContextValueLen } - if valueLength > maxContextValueLen { - clog.Debugf("Provided console context value length (%d) is higher than the maximum, using default: %d", valueLength, maxContextValueLen) - valueLength = maxContextValueLen + + if valueLength > MaxContextValueLen { + clog.Debugf("Provided console context value length (%d) is higher than the maximum, using default: %d", valueLength, MaxContextValueLen) + valueLength = MaxContextValueLen } alertContext = Context{ @@ -63,30 +63,36 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error { } for key, values := range contextToSend { - alertContext.ContextToSendCompiled[key] = make([]*vm.Program, 0) + if _, ok := alertContext.ContextToSend[key]; !ok { + alertContext.ContextToSend[key] = make([]string, 0) + } + + if _, ok := alertContext.ContextToSendCompiled[key]; !ok { + alertContext.ContextToSendCompiled[key] = make([]*vm.Program, 0) + } + for _, value := range values { valueCompiled, err := expr.Compile(value, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) 
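ValidateContextExpr and NewAlertContext above compile the same expressions, so a caller can validate user-supplied context keys before committing them. A small illustrative wiring — the key names and expressions are hypothetical, though the evt.Meta/evt.Parsed form matches the tests, and a valueLength of 0 falls back to MaxContextValueLen (4000):

    package main

    import (
        log "github.com/sirupsen/logrus"

        "github.com/crowdsecurity/crowdsec/pkg/alertcontext"
    )

    func main() {
        toSend := map[string][]string{
            "source_ip": {"evt.Meta.source_ip"},
            "user":      {"evt.Parsed.user"},
        }

        // Reject bad expressions up front, before they are persisted anywhere.
        for key, exprs := range toSend {
            if err := alertcontext.ValidateContextExpr(key, exprs); err != nil {
                log.Fatalf("invalid context expression for %s: %s", key, err)
            }
        }

        // 0 means "use the default value length"; larger values are clamped.
        if err := alertcontext.NewAlertContext(toSend, 0); err != nil {
            log.Fatalf("configuring alert context: %s", err)
        }
    }
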
if err != nil { - return fmt.Errorf("compilation of '%s' context value failed: %v", value, err) + return fmt.Errorf("compilation of '%s' context value failed: %w", value, err) } + alertContext.ContextToSendCompiled[key] = append(alertContext.ContextToSendCompiled[key], valueCompiled) + alertContext.ContextToSend[key] = append(alertContext.ContextToSend[key], value) } } return nil } -func truncate(values []string, contextValueLen int) (string, error) { - var ret string +func TruncateContext(values []string, contextValueLen int) (string, error) { valueByte, err := json.Marshal(values) if err != nil { - return "", fmt.Errorf("unable to dump metas: %s", err) + return "", fmt.Errorf("unable to dump metas: %w", err) } - ret = string(valueByte) - for { - if len(ret) <= contextValueLen { - break - } + + ret := string(valueByte) + for len(ret) > contextValueLen { // if there is only 1 value left and that the size is too big, truncate it if len(values) == 1 { valueToTruncate := values[0] @@ -98,12 +104,15 @@ func truncate(values []string, contextValueLen int) (string, error) { // if there is multiple value inside, just remove the last one values = values[:len(values)-1] } + valueByte, err = json.Marshal(values) if err != nil { - return "", fmt.Errorf("unable to dump metas: %s", err) + return "", fmt.Errorf("unable to dump metas: %w", err) } + ret = string(valueByte) } + return ret, nil } @@ -112,41 +121,49 @@ func EventToContext(events []types.Event) (models.Meta, []error) { metas := make([]*models.MetaItems0, 0) tmpContext := make(map[string][]string) + for _, evt := range events { for key, values := range alertContext.ContextToSendCompiled { if _, ok := tmpContext[key]; !ok { tmpContext[key] = make([]string, 0) } + for _, value := range values { var val string + output, err := expr.Run(value, map[string]interface{}{"evt": evt}) if err != nil { - errors = append(errors, fmt.Errorf("failed to get value for %s : %v", key, err)) + errors = append(errors, fmt.Errorf("failed to get value for %s: %w", key, err)) continue } + switch out := output.(type) { case string: val = out case int: val = strconv.Itoa(out) default: - errors = append(errors, fmt.Errorf("unexpected return type for %s : %T", key, output)) + errors = append(errors, fmt.Errorf("unexpected return type for %s: %T", key, output)) continue } + if val != "" && !slices.Contains(tmpContext[key], val) { tmpContext[key] = append(tmpContext[key], val) } } } } + for key, values := range tmpContext { if len(values) == 0 { continue } - valueStr, err := truncate(values, alertContext.ContextValueLen) + + valueStr, err := TruncateContext(values, alertContext.ContextValueLen) if err != nil { - log.Warningf(err.Error()) + log.Warning(err.Error()) } + meta := models.MetaItems0{ Key: key, Value: valueStr, @@ -155,5 +172,6 @@ func EventToContext(events []types.Event) (models.Meta, []error) { } ret := models.Meta(metas) + return ret, errors } diff --git a/pkg/alertcontext/alertcontext_test.go b/pkg/alertcontext/alertcontext_test.go index 2e7e71bd62f..c111d1bbcfb 100644 --- a/pkg/alertcontext/alertcontext_test.go +++ b/pkg/alertcontext/alertcontext_test.go @@ -4,9 +4,11 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/stretchr/testify/assert" ) func TestNewAlertContext(t *testing.T) { @@ -29,8 +31,7 @@ func TestNewAlertContext(t *testing.T) { for _, test := range tests { fmt.Printf("Running test 
'%s'\n", test.name) err := NewAlertContext(test.contextToSend, test.valueLength) - assert.ErrorIs(t, err, test.expectedErr) - + require.ErrorIs(t, err, test.expectedErr) } } @@ -193,7 +194,7 @@ func TestEventToContext(t *testing.T) { for _, test := range tests { fmt.Printf("Running test '%s'\n", test.name) err := NewAlertContext(test.contextToSend, test.valueLength) - assert.ErrorIs(t, err, nil) + require.NoError(t, err) metas, _ := EventToContext(test.events) assert.ElementsMatch(t, test.expectedResult, metas) diff --git a/pkg/alertcontext/config.go b/pkg/alertcontext/config.go new file mode 100644 index 00000000000..6ef877619e4 --- /dev/null +++ b/pkg/alertcontext/config.go @@ -0,0 +1,142 @@ +package alertcontext + +import ( + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "slices" + + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +var ErrNoContextData = errors.New("no context to send") + +// this file is here to avoid circular dependencies between the configuration and the hub + +// HubItemWrapper is a wrapper around a hub item to unmarshal only the context part +// because there are other fields like name etc. +type HubItemWrapper struct { + Context map[string][]string `yaml:"context"` +} + +// mergeContext adds the context from src to dest. +func mergeContext(dest map[string][]string, src map[string][]string) error { + if len(src) == 0 { + return ErrNoContextData + } + + for k, v := range src { + if _, ok := dest[k]; !ok { + dest[k] = make([]string, 0) + } + + for _, s := range v { + if !slices.Contains(dest[k], s) { + dest[k] = append(dest[k], s) + } + } + } + + return nil +} + +// addContextFromItem merges the context from an item into the context to send to the console. +func addContextFromItem(toSend map[string][]string, item *cwhub.Item) error { + filePath := item.State.LocalPath + log.Tracef("loading console context from %s", filePath) + + content, err := os.ReadFile(filePath) + if err != nil { + return err + } + + wrapper := &HubItemWrapper{} + + err = yaml.Unmarshal(content, wrapper) + if err != nil { + return fmt.Errorf("%s: %w", filePath, err) + } + + err = mergeContext(toSend, wrapper.Context) + if err != nil { + // having an empty hub item deserves an error + log.Errorf("while merging context from %s: %s. Note that context data should be under the 'context:' key, the top-level is metadata.", filePath, err) + } + + return nil +} + +// addContextFromFile merges the context from a file into the context to send to the console. +func addContextFromFile(toSend map[string][]string, filePath string) error { + log.Tracef("loading console context from %s", filePath) + + content, err := os.ReadFile(filePath) + if err != nil { + return err + } + + newContext := make(map[string][]string, 0) + + err = yaml.Unmarshal(content, newContext) + if err != nil { + return fmt.Errorf("%s: %w", filePath, err) + } + + err = mergeContext(toSend, newContext) + if err != nil && !errors.Is(err, ErrNoContextData) { + // having an empty console/context.yaml is not an error + return err + } + + return nil +} + +// LoadConsoleContext loads the context from the hub (if provided) and the file console_context_path. 
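For reference, the two YAML shapes merged above differ only in nesting: addContextFromFile reads a bare map of keys to expressions, while hub context items parsed through HubItemWrapper keep the same map under a top-level context: key, the remaining top-level keys being item metadata. A sketch with hypothetical keys:

    # console/context.yaml, read by addContextFromFile
    source_ip:
      - evt.Meta.source_ip

    # a hub context item, read by addContextFromItem
    context:
      source_ip:
        - evt.Meta.source_ip
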
+func LoadConsoleContext(c *csconfig.Config, hub *cwhub.Hub) error { + c.Crowdsec.ContextToSend = make(map[string][]string, 0) + + if hub != nil { + for _, item := range hub.GetInstalledByType(cwhub.CONTEXTS, true) { + // context in item files goes under the key 'context' + if err := addContextFromItem(c.Crowdsec.ContextToSend, item); err != nil { + return err + } + } + } + + ignoreMissing := false + + if c.Crowdsec.ConsoleContextPath != "" { + // if it's provided, it must exist + if _, err := os.Stat(c.Crowdsec.ConsoleContextPath); err != nil { + return fmt.Errorf("while checking console_context_path: %w", err) + } + } else { + c.Crowdsec.ConsoleContextPath = filepath.Join(c.ConfigPaths.ConfigDir, "console", "context.yaml") + ignoreMissing = true + } + + if err := addContextFromFile(c.Crowdsec.ContextToSend, c.Crowdsec.ConsoleContextPath); err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return err + } else if !ignoreMissing { + log.Warningf("while merging context from %s: %s", c.Crowdsec.ConsoleContextPath, err) + } + } + + feedback, err := json.Marshal(c.Crowdsec.ContextToSend) + if err != nil { + return fmt.Errorf("serializing console context: %s", err) + } + + log.Debugf("console context to send: %s", feedback) + + return nil +} diff --git a/pkg/apiclient/alerts_service.go b/pkg/apiclient/alerts_service.go index dd2ba2975ff..a3da84d306e 100644 --- a/pkg/apiclient/alerts_service.go +++ b/pkg/apiclient/alerts_service.go @@ -10,8 +10,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) -// type ApiAlerts service - type AlertsService service type AlertsListOpts struct { @@ -49,35 +47,35 @@ type AlertsDeleteOpts struct { } func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) { - - var added_ids models.AddAlertsResponse - u := fmt.Sprintf("%s/alerts", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodPost, u, &alerts) if err != nil { return nil, nil, err } - resp, err := s.client.Do(ctx, req, &added_ids) + addedIds := models.AddAlertsResponse{} + + resp, err := s.client.Do(ctx, req, &addedIds) if err != nil { return nil, resp, err } - return &added_ids, resp, nil + + return &addedIds, resp, nil } // to demo query arguments func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.GetAlertsResponse, *Response, error) { - var alerts models.GetAlertsResponse - var URI string u := fmt.Sprintf("%s/alerts", s.client.URLPrefix) + params, err := qs.Values(opts) if err != nil { return nil, nil, fmt.Errorf("building query: %w", err) } + + URI := u if len(params) > 0 { - URI = fmt.Sprintf("%s?%s", u, params.Encode()) - } else { - URI = u + URI = fmt.Sprintf("%s?%s", URI, params.Encode()) } req, err := s.client.NewRequest(http.MethodGet, URI, nil) @@ -85,20 +83,23 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models. 
return nil, nil, fmt.Errorf("building request: %w", err) } + alerts := models.GetAlertsResponse{} + resp, err := s.client.Do(ctx, req, &alerts) if err != nil { return nil, resp, fmt.Errorf("performing request: %w", err) } + return &alerts, resp, nil } // to demo query arguments func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*models.DeleteAlertsResponse, *Response, error) { - var alerts models.DeleteAlertsResponse params, err := qs.Values(opts) if err != nil { return nil, nil, err } + u := fmt.Sprintf("%s/alerts?%s", s.client.URLPrefix, params.Encode()) req, err := s.client.NewRequest(http.MethodDelete, u, nil) @@ -106,31 +107,35 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod return nil, nil, err } + alerts := models.DeleteAlertsResponse{} + resp, err := s.client.Do(ctx, req, &alerts) if err != nil { return nil, resp, err } + return &alerts, resp, nil } -func (s *AlertsService) DeleteOne(ctx context.Context, alert_id string) (*models.DeleteAlertsResponse, *Response, error) { - var alerts models.DeleteAlertsResponse - u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alert_id) +func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.DeleteAlertsResponse, *Response, error) { + u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alertID) req, err := s.client.NewRequest(http.MethodDelete, u, nil) if err != nil { return nil, nil, err } + alerts := models.DeleteAlertsResponse{} + resp, err := s.client.Do(ctx, req, &alerts) if err != nil { return nil, resp, err } + return &alerts, resp, nil } func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) { - var alert models.Alert u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID) req, err := s.client.NewRequest(http.MethodGet, u, nil) @@ -138,9 +143,12 @@ func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert return nil, nil, err } + alert := models.Alert{} + resp, err := s.client.Do(ctx, req, &alert) if err != nil { return nil, nil, err } + return &alert, resp, nil } diff --git a/pkg/apiclient/alerts_service_test.go b/pkg/apiclient/alerts_service_test.go index aa5039f0bc7..0d1ff41685f 100644 --- a/pkg/apiclient/alerts_service_test.go +++ b/pkg/apiclient/alerts_service_test.go @@ -5,14 +5,14 @@ import ( "fmt" "net/http" "net/url" - "reflect" "testing" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -25,31 +25,28 @@ func TestAlertsListAsMachine(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { - if r.URL.RawQuery == "ip=1.2.3.4" { testMethod(t, r, "GET") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `null`) + return } @@ -105,36 +102,26 @@ 
func TestAlertsListAsMachine(t *testing.T) { ]`) }) - tcapacity := int32(5) - tduration := "59m49.264032632s" - torigin := "crowdsec" tscenario := "crowdsecurity/ssh-bf" tscope := "Ip" - ttype := "ban" tvalue := "1.1.1.172" ttimestamp := "2020-11-28 10:20:46 +0000 UTC" - teventscount := int32(6) - tleakspeed := "10s" tmessage := "Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761" - tscenariohash := "4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f" - tscenarioversion := "0.1" - tstartat := "2020-11-28 10:20:46.842701127 +0100 +0100" - tstopat := "2020-11-28 10:20:46.845621385 +0100 +0100" expected := models.GetAlertsResponse{ &models.Alert{ - Capacity: &tcapacity, + Capacity: ptr.Of(int32(5)), CreatedAt: "2020-11-28T10:20:47+01:00", Decisions: []*models.Decision{ { - Duration: &tduration, + Duration: ptr.Of("59m49.264032632s"), ID: 1, - Origin: &torigin, + Origin: ptr.Of("crowdsec"), Scenario: &tscenario, Scope: &tscope, - Simulated: new(bool), //false, - Type: &ttype, + Simulated: ptr.Of(false), + Type: ptr.Of("ban"), Value: &tvalue, }, }, @@ -165,16 +152,16 @@ func TestAlertsListAsMachine(t *testing.T) { Timestamp: &ttimestamp, }, }, - EventsCount: &teventscount, + EventsCount: ptr.Of(int32(6)), ID: 1, - Leakspeed: &tleakspeed, + Leakspeed: ptr.Of("10s"), MachineID: "test", Message: &tmessage, Remediation: false, Scenario: &tscenario, - ScenarioHash: &tscenariohash, - ScenarioVersion: &tscenarioversion, - Simulated: new(bool), //(false), + ScenarioHash: ptr.Of("4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"), + ScenarioVersion: ptr.Of("0.1"), + Simulated: ptr.Of(false), Source: &models.Source{ AsName: "Cloudflare Inc", AsNumber: "", @@ -186,37 +173,27 @@ func TestAlertsListAsMachine(t *testing.T) { Scope: &tscope, Value: &tvalue, }, - StartAt: &tstartat, - StopAt: &tstopat, + StartAt: ptr.Of("2020-11-28 10:20:46.842701127 +0100 +0100"), + StopAt: ptr.Of("2020-11-28 10:20:46.845621385 +0100 +0100"), }, } - //log.Debugf("data : -> %s", spew.Sdump(alerts)) - //log.Debugf("resp : -> %s", spew.Sdump(resp)) - //log.Debugf("expected : -> %s", spew.Sdump(expected)) - //first one returns data + // log.Debugf("data : -> %s", spew.Sdump(alerts)) + // log.Debugf("resp : -> %s", spew.Sdump(resp)) + // log.Debugf("expected : -> %s", spew.Sdump(expected)) + // first one returns data alerts, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{}) - if err != nil { - log.Errorf("test Unable to list alerts : %+v", err) - } - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, expected, *alerts) + + // this one doesn't + filter := AlertsListOpts{IPEquals: ptr.Of("1.2.3.4")} - if !reflect.DeepEqual(*alerts, expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } - //this one doesn't - filter := AlertsListOpts{IPEquals: new(string)} - *filter.IPEquals = "1.2.3.4" alerts, resp, err = client.Alerts.List(context.Background(), filter) - if err != nil { - log.Errorf("test Unable to list alerts : %+v", err) - } - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - assert.Equal(t, 0, len(*alerts)) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, 
resp.Response.StatusCode) + assert.Empty(t, *alerts) } func TestAlertsGetAsMachine(t *testing.T) { @@ -228,23 +205,20 @@ func TestAlertsGetAsMachine(t *testing.T) { w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() + mux.HandleFunc("/alerts/2", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") w.WriteHeader(http.StatusNotFound) @@ -304,34 +278,24 @@ func TestAlertsGetAsMachine(t *testing.T) { }`) }) - tcapacity := int32(5) - tduration := "59m49.264032632s" - torigin := "crowdsec" tscenario := "crowdsecurity/ssh-bf" tscope := "Ip" ttype := "ban" tvalue := "1.1.1.172" ttimestamp := "2020-11-28 10:20:46 +0000 UTC" - teventscount := int32(6) - tleakspeed := "10s" - tmessage := "Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761" - tscenariohash := "4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f" - tscenarioversion := "0.1" - tstartat := "2020-11-28 10:20:46.842701127 +0100 +0100" - tstopat := "2020-11-28 10:20:46.845621385 +0100 +0100" expected := &models.Alert{ - Capacity: &tcapacity, + Capacity: ptr.Of(int32(5)), CreatedAt: "2020-11-28T10:20:47+01:00", Decisions: []*models.Decision{ { - Duration: &tduration, + Duration: ptr.Of("59m49.264032632s"), ID: 1, - Origin: &torigin, + Origin: ptr.Of("crowdsec"), Scenario: &tscenario, Scope: &tscope, - Simulated: new(bool), //false, + Simulated: ptr.Of(false), Type: &ttype, Value: &tvalue, }, @@ -363,16 +327,16 @@ func TestAlertsGetAsMachine(t *testing.T) { Timestamp: &ttimestamp, }, }, - EventsCount: &teventscount, + EventsCount: ptr.Of(int32(6)), ID: 1, - Leakspeed: &tleakspeed, + Leakspeed: ptr.Of("10s"), MachineID: "test", - Message: &tmessage, + Message: ptr.Of("Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761"), Remediation: false, Scenario: &tscenario, - ScenarioHash: &tscenariohash, - ScenarioVersion: &tscenarioversion, - Simulated: new(bool), //(false), + ScenarioHash: ptr.Of("4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"), + ScenarioVersion: ptr.Of("0.1"), + Simulated: ptr.Of(false), Source: &models.Source{ AsName: "Cloudflare Inc", AsNumber: "", @@ -384,24 +348,18 @@ func TestAlertsGetAsMachine(t *testing.T) { Scope: &tscope, Value: &tvalue, }, - StartAt: &tstartat, - StopAt: &tstopat, + StartAt: ptr.Of("2020-11-28 10:20:46.842701127 +0100 +0100"), + StopAt: ptr.Of("2020-11-28 10:20:46.845621385 +0100 +0100"), } alerts, resp, err := client.Alerts.GetByID(context.Background(), 1) require.NoError(t, err) - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *alerts) - if !reflect.DeepEqual(*alerts, *expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } - - //fail + // fail _, _, err = client.Alerts.GetByID(context.Background(), 2) - 
assert.Contains(t, fmt.Sprintf("%s", err), "API error: object not found") - + cstest.RequireErrorMessage(t, err, "API error: object not found") } func TestAlertsCreateAsMachine(t *testing.T) { @@ -412,39 +370,36 @@ func TestAlertsCreateAsMachine(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") w.WriteHeader(http.StatusOK) w.Write([]byte(`["3"]`)) }) + log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() + alert := models.AddAlertsRequest{} alerts, resp, err := client.Alerts.Add(context.Background(), alert) require.NoError(t, err) + expected := &models.AddAlertsResponse{"3"} - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - if !reflect.DeepEqual(*alerts, *expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } + + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *alerts) } func TestAlertsDeleteAsMachine(t *testing.T) { @@ -455,40 +410,35 @@ func TestAlertsDeleteAsMachine(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") - assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") + assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) w.Write([]byte(`{"message":"0 deleted alerts"}`)) }) + log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() - alert := AlertsDeleteOpts{IPEquals: new(string)} - *alert.IPEquals = "1.2.3.4" + + alert := AlertsDeleteOpts{IPEquals: ptr.Of("1.2.3.4")} alerts, resp, err := client.Alerts.Delete(context.Background(), alert) require.NoError(t, err) expected := &models.DeleteAlertsResponse{NbDeleted: ""} - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - if !reflect.DeepEqual(*alerts, *expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } + + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *alerts) } diff --git a/pkg/apiclient/auth.go b/pkg/apiclient/auth.go deleted file mode 100644 index 84df74456b1..00000000000 --- a/pkg/apiclient/auth.go +++ /dev/null @@ -1,330 +0,0 @@ -package apiclient - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "math/rand" - "net/http" - "net/http/httputil" - "net/url" - "sync" - "time" - - "github.com/go-openapi/strfmt" - "github.com/pkg/errors" 
- log "github.com/sirupsen/logrus" - - "github.com/crowdsecurity/crowdsec/pkg/fflag" - "github.com/crowdsecurity/crowdsec/pkg/models" -) - -type APIKeyTransport struct { - APIKey string - // Transport is the underlying HTTP transport to use when making requests. - // It will default to http.DefaultTransport if nil. - Transport http.RoundTripper - URL *url.URL - VersionPrefix string - UserAgent string -} - -// RoundTrip implements the RoundTripper interface. -func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) { - if t.APIKey == "" { - return nil, errors.New("APIKey is empty") - } - - // We must make a copy of the Request so - // that we don't modify the Request we were given. This is required by the - // specification of http.RoundTripper. - req = cloneRequest(req) - req.Header.Add("X-Api-Key", t.APIKey) - if t.UserAgent != "" { - req.Header.Add("User-Agent", t.UserAgent) - } - log.Debugf("req-api: %s %s", req.Method, req.URL.String()) - if log.GetLevel() >= log.TraceLevel { - dump, _ := httputil.DumpRequest(req, true) - log.Tracef("auth-api request: %s", string(dump)) - } - // Make the HTTP request. - resp, err := t.transport().RoundTrip(req) - if err != nil { - log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err) - return resp, err - } - if log.GetLevel() >= log.TraceLevel { - dump, _ := httputil.DumpResponse(resp, true) - log.Tracef("auth-api response: %s", string(dump)) - } - - log.Debugf("resp-api: http %d", resp.StatusCode) - - return resp, err -} - -func (t *APIKeyTransport) Client() *http.Client { - return &http.Client{Transport: t} -} - -func (t *APIKeyTransport) transport() http.RoundTripper { - if t.Transport != nil { - return t.Transport - } - return http.DefaultTransport -} - -type retryRoundTripper struct { - next http.RoundTripper - maxAttempts int - retryStatusCodes []int - withBackOff bool - onBeforeRequest func(attempt int) -} - -func (r retryRoundTripper) ShouldRetry(statusCode int) bool { - for _, code := range r.retryStatusCodes { - if code == statusCode { - return true - } - } - return false -} - -func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - var resp *http.Response - var err error - backoff := 0 - for i := 0; i < r.maxAttempts; i++ { - if i > 0 { - if r.withBackOff && !fflag.DisableHttpRetryBackoff.IsEnabled() { - backoff += 10 + rand.Intn(20) - } - log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts) - select { - case <-req.Context().Done(): - return resp, req.Context().Err() - case <-time.After(time.Duration(backoff) * time.Second): - } - } - if r.onBeforeRequest != nil { - r.onBeforeRequest(i) - } - clonedReq := cloneRequest(req) - resp, err = r.next.RoundTrip(clonedReq) - if err != nil { - log.Errorf("error while performing request: %s; %d retries left", err, r.maxAttempts-i-1) - continue - } - if !r.ShouldRetry(resp.StatusCode) { - return resp, nil - } - } - return resp, err -} - -type JWTTransport struct { - MachineID *string - Password *strfmt.Password - Token string - Expiration time.Time - Scenarios []string - URL *url.URL - VersionPrefix string - UserAgent string - // Transport is the underlying HTTP transport to use when making requests. - // It will default to http.DefaultTransport if nil. 
- Transport http.RoundTripper - UpdateScenario func() ([]string, error) - refreshTokenMutex sync.Mutex -} - -func (t *JWTTransport) refreshJwtToken() error { - var err error - if t.UpdateScenario != nil { - t.Scenarios, err = t.UpdateScenario() - if err != nil { - return fmt.Errorf("can't update scenario list: %s", err) - } - log.Debugf("scenarios list updated for '%s'", *t.MachineID) - } - - var auth = models.WatcherAuthRequest{ - MachineID: t.MachineID, - Password: t.Password, - Scenarios: t.Scenarios, - } - - var response models.WatcherAuthResponse - - /* - we don't use the main client, so let's build the body - */ - var buf io.ReadWriter = &bytes.Buffer{} - enc := json.NewEncoder(buf) - enc.SetEscapeHTML(false) - err = enc.Encode(auth) - if err != nil { - return fmt.Errorf("could not encode jwt auth body: %w", err) - } - req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf) - if err != nil { - return fmt.Errorf("could not create request: %w", err) - } - req.Header.Add("Content-Type", "application/json") - client := &http.Client{ - Transport: &retryRoundTripper{ - next: http.DefaultTransport, - maxAttempts: 5, - withBackOff: true, - retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError}, - }, - } - if t.UserAgent != "" { - req.Header.Add("User-Agent", t.UserAgent) - } - if log.GetLevel() >= log.TraceLevel { - dump, _ := httputil.DumpRequest(req, true) - log.Tracef("auth-jwt request: %s", string(dump)) - } - - log.Debugf("auth-jwt(auth): %s %s", req.Method, req.URL.String()) - - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("could not get jwt token: %w", err) - } - log.Debugf("auth-jwt : http %d", resp.StatusCode) - - if log.GetLevel() >= log.TraceLevel { - dump, _ := httputil.DumpResponse(resp, true) - log.Tracef("auth-jwt response: %s", string(dump)) - } - - defer resp.Body.Close() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - log.Debugf("received response status %q when fetching %v", resp.Status, req.URL) - - err = CheckResponse(resp) - if err != nil { - return err - } - } - - if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { - return fmt.Errorf("unable to decode response: %w", err) - } - if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil { - return fmt.Errorf("unable to parse jwt expiration: %w", err) - } - t.Token = response.Token - - log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String()) - return nil -} - -// RoundTrip implements the RoundTripper interface. 
-func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // in a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI - // we use a mutex to avoid this - //We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request) - t.refreshTokenMutex.Lock() - if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) { - if err := t.refreshJwtToken(); err != nil { - t.refreshTokenMutex.Unlock() - return nil, err - } - } - t.refreshTokenMutex.Unlock() - - if t.UserAgent != "" { - req.Header.Add("User-Agent", t.UserAgent) - } - - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token)) - - if log.GetLevel() >= log.TraceLevel { - //requestToDump := cloneRequest(req) - dump, _ := httputil.DumpRequest(req, true) - log.Tracef("req-jwt: %s", string(dump)) - } - - // Make the HTTP request. - resp, err := t.transport().RoundTrip(req) - if log.GetLevel() >= log.TraceLevel { - dump, _ := httputil.DumpResponse(resp, true) - log.Tracef("resp-jwt: %s (err:%v)", string(dump), err) - } - if err != nil { - /*we had an error (network error for example, or 401 because token is refused), reset the token ?*/ - t.Token = "" - return resp, fmt.Errorf("performing jwt auth: %w", err) - } - - log.Debugf("resp-jwt: %d", resp.StatusCode) - - return resp, nil -} - -func (t *JWTTransport) Client() *http.Client { - return &http.Client{Transport: t} -} - -func (t *JWTTransport) ResetToken() { - log.Debug("resetting jwt token") - t.refreshTokenMutex.Lock() - t.Token = "" - t.refreshTokenMutex.Unlock() -} - -func (t *JWTTransport) transport() http.RoundTripper { - var transport http.RoundTripper - if t.Transport != nil { - transport = t.Transport - } else { - transport = http.DefaultTransport - } - // a round tripper that retries once when the status is unauthorized and 5 times when infrastructure is overloaded - return &retryRoundTripper{ - next: &retryRoundTripper{ - next: transport, - maxAttempts: 5, - withBackOff: true, - retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout}, - }, - maxAttempts: 2, - withBackOff: false, - retryStatusCodes: []int{http.StatusUnauthorized, http.StatusForbidden}, - onBeforeRequest: func(attempt int) { - // reset the token only in the second attempt as this is when we know we had a 401 or 403 - // the second attempt is supposed to refresh the token - if attempt > 0 { - t.ResetToken() - } - }, - } -} - -// cloneRequest returns a clone of the provided *http.Request. The clone is a -// shallow copy of the struct and its Header map. -func cloneRequest(r *http.Request) *http.Request { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r - // deep copy of the Header - r2.Header = make(http.Header, len(r.Header)) - for k, s := range r.Header { - r2.Header[k] = append([]string(nil), s...) 
- } - - if r.Body != nil { - var b bytes.Buffer - b.ReadFrom(r.Body) - r.Body = io.NopCloser(&b) - r2.Body = io.NopCloser(bytes.NewReader(b.Bytes())) - } - return r2 -} diff --git a/pkg/apiclient/auth_jwt.go b/pkg/apiclient/auth_jwt.go new file mode 100644 index 00000000000..193486ff065 --- /dev/null +++ b/pkg/apiclient/auth_jwt.go @@ -0,0 +1,253 @@ +package apiclient + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httputil" + "net/url" + "sync" + "time" + + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +type JWTTransport struct { + MachineID *string + Password *strfmt.Password + Token string + Expiration time.Time + Scenarios []string + URL *url.URL + VersionPrefix string + UserAgent string + RetryConfig *RetryConfig + // Transport is the underlying HTTP transport to use when making requests. + // It will default to http.DefaultTransport if nil. + Transport http.RoundTripper + UpdateScenario func(context.Context) ([]string, error) + refreshTokenMutex sync.Mutex +} + +func (t *JWTTransport) refreshJwtToken() error { + var err error + + ctx := context.TODO() + + if t.UpdateScenario != nil { + t.Scenarios, err = t.UpdateScenario(ctx) + if err != nil { + return fmt.Errorf("can't update scenario list: %w", err) + } + + log.Debugf("scenarios list updated for '%s'", *t.MachineID) + } + + auth := models.WatcherAuthRequest{ + MachineID: t.MachineID, + Password: t.Password, + Scenarios: t.Scenarios, + } + + /* + we don't use the main client, so let's build the body + */ + var buf io.ReadWriter = &bytes.Buffer{} + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(false) + err = enc.Encode(auth) + + if err != nil { + return fmt.Errorf("could not encode jwt auth body: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf) + if err != nil { + return fmt.Errorf("could not create request: %w", err) + } + + req.Header.Add("Content-Type", "application/json") + + transport := t.Transport + if transport == nil { + transport = http.DefaultTransport + } + + client := &http.Client{ + Transport: &retryRoundTripper{ + next: transport, + maxAttempts: 5, + withBackOff: true, + retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError}, + }, + } + + if t.UserAgent != "" { + req.Header.Add("User-Agent", t.UserAgent) + } + + if log.GetLevel() >= log.TraceLevel { + dump, _ := httputil.DumpRequest(req, true) + log.Tracef("auth-jwt request: %s", string(dump)) + } + + log.Debugf("auth-jwt(auth): %s %s", req.Method, req.URL.String()) + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("could not get jwt token: %w", err) + } + + log.Debugf("auth-jwt : http %d", resp.StatusCode) + + if log.GetLevel() >= log.TraceLevel { + dump, _ := httputil.DumpResponse(resp, true) + log.Tracef("auth-jwt response: %s", string(dump)) + } + + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + log.Debugf("received response status %q when fetching %v", resp.Status, req.URL) + + err = CheckResponse(resp) + if err != nil { + return err + } + } + + var response models.WatcherAuthResponse + + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return fmt.Errorf("unable to decode response: %w", err) + } + + if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil { + return fmt.Errorf("unable to 
parse jwt expiration: %w", err) + } + + t.Token = response.Token + + log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String()) + + return nil +} + +func (t *JWTTransport) needsTokenRefresh() bool { + return t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) +} + +// prepareRequest returns a copy of the request with the necessary authentication headers. +func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) { + // Several goroutines may otherwise execute refreshJwtToken concurrently, which is useless + // and would overload CAPI. We use a mutex to avoid this. + t.refreshTokenMutex.Lock() + defer t.refreshTokenMutex.Unlock() + + // We bypass the refresh if we are requesting the login endpoint, as it does not require a token, + // and it would lead to two requests instead of one (refresh + actual login request). + if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && t.needsTokenRefresh() { + if err := t.refreshJwtToken(); err != nil { + return nil, err + } + } + + if t.UserAgent != "" { + req.Header.Add("User-Agent", t.UserAgent) + } + + req.Header.Add("Authorization", "Bearer "+t.Token) + + return req, nil +} + +// RoundTrip implements the RoundTripper interface. +func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { + + var resp *http.Response + attemptsCount := make(map[int]int) + + for { + if log.GetLevel() >= log.TraceLevel { + // requestToDump := cloneRequest(req) + dump, _ := httputil.DumpRequest(req, true) + log.Tracef("req-jwt: %s", string(dump)) + } + // Clone the request: RoundTrippers must not mutate the caller's request, and each retry needs a fresh copy. + clonedReq := cloneRequest(req) + + clonedReq, err := t.prepareRequest(clonedReq) + if err != nil { + return nil, err + } + + resp, err = t.transport().RoundTrip(clonedReq) + if log.GetLevel() >= log.TraceLevel { + dump, _ := httputil.DumpResponse(resp, true) + log.Tracef("resp-jwt: %s (err:%v)", string(dump), err) + } + + if err != nil { + // we had an error (a network error, for example): reset the token so the next attempt re-authenticates + t.ResetToken() + return resp, fmt.Errorf("performing jwt auth: %w", err) + } + + if resp != nil { + log.Debugf("resp-jwt: %d", resp.StatusCode) + } + + config, shouldRetry := t.RetryConfig.StatusCodeConfig[resp.StatusCode] + if !shouldRetry { + break + } + + if attemptsCount[resp.StatusCode] >= config.MaxAttempts { + log.Infof("max attempts reached for status code %d", resp.StatusCode) + break + } + + if config.InvalidateToken { + log.Debugf("invalidating token for status code %d", resp.StatusCode) + t.ResetToken() + } + + log.Debugf("retrying request to %s", req.URL.String()) + attemptsCount[resp.StatusCode]++ + log.Infof("attempt %d out of %d", attemptsCount[resp.StatusCode], config.MaxAttempts) + + if config.Backoff { + backoff := 2*attemptsCount[resp.StatusCode] + 5 + log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, attemptsCount[resp.StatusCode], config.MaxAttempts) + time.Sleep(time.Duration(backoff) * time.Second) + } + } + return resp, nil + +} + +func (t *JWTTransport) Client() *http.Client { + return &http.Client{Transport: t} +} + +func (t *JWTTransport) ResetToken() { + log.Debug("resetting jwt token") + t.refreshTokenMutex.Lock() + t.Token = "" + t.refreshTokenMutex.Unlock() +}
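The loop above drives retries from a per-status-code table instead of the deleted nested round trippers. The RetryConfig and StatusCodeConfig types are referenced here but defined elsewhere (not shown in this excerpt); the fragment below, written as if inside package apiclient, infers their shape purely from the fields RoundTrip reads (MaxAttempts, InvalidateToken, Backoff) and is an illustrative sketch, not the patch's actual wiring:

    // Hypothetical configuration, shape inferred from RoundTrip above: one quick
    // retry on 401 with token invalidation, backed-off retries on 429.
    t := &JWTTransport{
        MachineID:     &machineID, // assumed to be in scope
        Password:      &password,  // assumed to be in scope
        URL:           apiURL,
        VersionPrefix: "v1",
        RetryConfig: &RetryConfig{
            StatusCodeConfig: map[int]StatusCodeConfig{
                http.StatusUnauthorized:    {MaxAttempts: 1, InvalidateToken: true},
                http.StatusTooManyRequests: {MaxAttempts: 5, Backoff: true},
            },
        },
    }
    httpClient := t.Client()

+
+// transport() returns the configured underlying transport, defaulting to http.DefaultTransport.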
+func (t *JWTTransport) transport() http.RoundTripper { + if t.Transport != nil { + return t.Transport + } + return http.DefaultTransport +} diff --git a/pkg/apiclient/auth_key.go b/pkg/apiclient/auth_key.go new file mode 100644 index 00000000000..e2213aca227 --- /dev/null +++ b/pkg/apiclient/auth_key.go @@ -0,0 +1,73 @@ +package apiclient + +import ( + "errors" + "net/http" + "net/http/httputil" + "net/url" + + log "github.com/sirupsen/logrus" +) + +type APIKeyTransport struct { + APIKey string + // Transport is the underlying HTTP transport to use when making requests. + // It will default to http.DefaultTransport if nil. + Transport http.RoundTripper + URL *url.URL + VersionPrefix string + UserAgent string +} + +// RoundTrip implements the RoundTripper interface. +func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if t.APIKey == "" { + return nil, errors.New("APIKey is empty") + } + + // We must make a copy of the Request so + // that we don't modify the Request we were given. This is required by the + // specification of http.RoundTripper. + req = cloneRequest(req) + req.Header.Add("X-Api-Key", t.APIKey) + + if t.UserAgent != "" { + req.Header.Add("User-Agent", t.UserAgent) + } + + log.Debugf("req-api: %s %s", req.Method, req.URL.String()) + + if log.GetLevel() >= log.TraceLevel { + dump, _ := httputil.DumpRequest(req, true) + log.Tracef("auth-api request: %s", string(dump)) + } + + // Make the HTTP request. + resp, err := t.transport().RoundTrip(req) + if err != nil { + log.Errorf("auth-api: auth with api key failed, no response received: %s", err) + + return resp, err + } + + if log.GetLevel() >= log.TraceLevel { + dump, _ := httputil.DumpResponse(resp, true) + log.Tracef("auth-api response: %s", string(dump)) + } + + log.Debugf("resp-api: http %d", resp.StatusCode) + + return resp, err +} + +func (t *APIKeyTransport) Client() *http.Client { + return &http.Client{Transport: t} +} + +func (t *APIKeyTransport) transport() http.RoundTripper { + if t.Transport != nil { + return t.Transport + } + + return http.DefaultTransport +} diff --git a/pkg/apiclient/auth_test.go b/pkg/apiclient/auth_key_test.go similarity index 66% rename from pkg/apiclient/auth_test.go rename to pkg/apiclient/auth_key_test.go index f28a0ea051a..f686de6227a 100644 --- a/pkg/apiclient/auth_test.go +++ b/pkg/apiclient/auth_key_test.go @@ -9,6 +9,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" ) func TestApiAuth(t *testing.T) { @@ -17,8 +20,9 @@ func TestApiAuth(t *testing.T) { mux, urlx, teardown := setup() mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") + if r.Header.Get("X-Api-Key") == "ixu" { - assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") + assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) w.Write([]byte(`null`)) } else { @@ -26,11 +30,11 @@ func TestApiAuth(t *testing.T) { w.Write([]byte(`{"message":"access forbidden"}`)) } }) + log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) defer teardown() @@ -40,18 +44,12 @@ func TestApiAuth(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - alert := 
DecisionsListOpts{IPEquals: new(string)} - *alert.IPEquals = "1.2.3.4" + alert := DecisionsListOpts{IPEquals: ptr.Of("1.2.3.4")} _, resp, err := newcli.Decisions.List(context.Background(), alert) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) //ko bad token auth = &APIKeyTransport{ @@ -59,28 +57,25 @@ func TestApiAuth(t *testing.T) { } newcli, err = NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) _, resp, err = newcli.Decisions.List(context.Background(), alert) log.Infof("--> %s", err) - if resp.Response.StatusCode != http.StatusForbidden { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - assert.Contains(t, err.Error(), "API error: access forbidden") + + assert.Equal(t, http.StatusForbidden, resp.Response.StatusCode) + + cstest.RequireErrorMessage(t, err, "API error: access forbidden") + //ko empty token auth = &APIKeyTransport{} + newcli, err = NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) _, _, err = newcli.Decisions.List(context.Background(), alert) require.Error(t, err) log.Infof("--> %s", err) assert.Contains(t, err.Error(), "APIKey is empty") - } diff --git a/pkg/apiclient/auth_retry.go b/pkg/apiclient/auth_retry.go new file mode 100644 index 00000000000..a17725439bc --- /dev/null +++ b/pkg/apiclient/auth_retry.go @@ -0,0 +1,81 @@ +package apiclient + +import ( + "math/rand" + "net/http" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/fflag" +) + +type retryRoundTripper struct { + next http.RoundTripper + maxAttempts int + retryStatusCodes []int + withBackOff bool + onBeforeRequest func(attempt int) +} + +func (r retryRoundTripper) ShouldRetry(statusCode int) bool { + for _, code := range r.retryStatusCodes { + if code == statusCode { + return true + } + } + + return false +} + +func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + var ( + resp *http.Response + err error + ) + + backoff := 0 + maxAttempts := r.maxAttempts + + if fflag.DisableHttpRetryBackoff.IsEnabled() { + maxAttempts = 1 + } + + for i := range maxAttempts { + if i > 0 { + if r.withBackOff { + //nolint:gosec + backoff += 10 + rand.Intn(20) + } + + log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts) + + select { + case <-req.Context().Done(): + return nil, req.Context().Err() + case <-time.After(time.Duration(backoff) * time.Second): + } + } + + if r.onBeforeRequest != nil { + r.onBeforeRequest(i) + } + + clonedReq := cloneRequest(req) + + resp, err = r.next.RoundTrip(clonedReq) + if err != nil { + if left := maxAttempts - i - 1; left > 0 { + log.Errorf("error while performing request: %s; %d retries left", err, left) + } + + continue + } + + if !r.ShouldRetry(resp.StatusCode) { + return resp, nil + } + } + + return resp, err +} diff --git a/pkg/apiclient/auth_service.go b/pkg/apiclient/auth_service.go index 64284902e8c..e7a423cfd95 100644 --- a/pkg/apiclient/auth_service.go +++ b/pkg/apiclient/auth_service.go @@ -8,8 +8,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) -// type ApiAlerts service - type AuthService service // Don't add it to the models, as they are used with LAPI, but 
the enroll endpoint is specific to CAPI @@ -22,6 +20,7 @@ type enrollRequest struct { func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error) { u := fmt.Sprintf("%s/watchers", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodDelete, u, nil) if err != nil { return nil, err @@ -31,6 +30,7 @@ func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error) if err != nil { return resp, err } + return resp, nil } @@ -46,6 +46,7 @@ func (s *AuthService) RegisterWatcher(ctx context.Context, registration models.W if err != nil { return resp, err } + return resp, nil } @@ -53,6 +54,7 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch var authResp models.WatcherAuthResponse u := fmt.Sprintf("%s/watchers/login", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodPost, u, &auth) if err != nil { return authResp, nil, err @@ -62,11 +64,13 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch if err != nil { return authResp, resp, err } + return authResp, resp, nil } func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name string, tags []string, overwrite bool) (*Response, error) { u := fmt.Sprintf("%s/watchers/enroll", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodPost, u, &enrollRequest{EnrollKey: enrollKey, Name: name, Tags: tags, Overwrite: overwrite}) if err != nil { return nil, err @@ -76,5 +80,6 @@ func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name if err != nil { return resp, err } + return resp, nil } diff --git a/pkg/apiclient/auth_service_test.go b/pkg/apiclient/auth_service_test.go index 32ba1890f62..d22c9394014 100644 --- a/pkg/apiclient/auth_service_test.go +++ b/pkg/apiclient/auth_service_test.go @@ -12,8 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - - "github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/stretchr/testify/require" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -24,39 +23,41 @@ type BasicMockPayload struct { } func getLoginsForMockErrorCases() map[string]int { - loginsForMockErrorCases := map[string]int{ + return map[string]int{ "login_400": http.StatusBadRequest, "login_409": http.StatusConflict, "login_500": http.StatusInternalServerError, } - - return loginsForMockErrorCases } func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { loginsForMockErrorCases := getLoginsForMockErrorCases() + mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") + buf := new(bytes.Buffer) _, _ = buf.ReadFrom(r.Body) newStr := buf.String() var payload BasicMockPayload + err := json.Unmarshal([]byte(newStr), &payload) if err != nil || payload.MachineID == "" || payload.Password == "" { log.Printf("Bad payload") w.WriteHeader(http.StatusBadRequest) } - responseBody := "" - responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID] + var responseBody string + responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID] if !hasFoundErrorMock { responseCode = http.StatusOK responseBody = `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}` } else { responseBody = fmt.Sprintf("Error %d", responseCode) } + log.Printf("MockServerReceived > %s // Login : [%s] => Mux response [%d]", newStr, payload.MachineID, responseCode) w.WriteHeader(responseCode) @@ -71,134 +72,123 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { * 
400, 409, 500 => Error */ func TestWatcherRegister(t *testing.T) { - log.SetLevel(log.DebugLevel) mux, urlx, teardown := setup() defer teardown() - //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} + + // body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} initBasicMuxMock(t, mux, "/watchers") log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) // Valid registration: should return the client and no error clientconfig := Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", } - client, err := RegisterClient(&clientconfig, &http.Client{}) - if client == nil || err != nil { - t.Fatalf("while registering client : %s", err) - } + + ctx := context.Background() + + client, err := RegisterClient(ctx, &clientconfig, &http.Client{}) + require.NoError(t, err) + log.Printf("->%T", client) // Testing error handling on Registration (400, 409, 500): should return an error errorCodesToTest := [3]int{http.StatusBadRequest, http.StatusConflict, http.StatusInternalServerError} for _, errorCodeToTest := range errorCodesToTest { clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) - client, err = RegisterClient(&clientconfig, &http.Client{}) - if client != nil || err == nil { - t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest) - } else { - log.Printf("The RegisterClient function handled the error code %d as expected \n\r", errorCodeToTest) - } + + client, err = RegisterClient(ctx, &clientconfig, &http.Client{}) + require.Nil(t, client, "nil expected for the response code %d", errorCodeToTest) + require.Error(t, err, "error expected for the response code %d", errorCodeToTest) } } func TestWatcherAuth(t *testing.T) { - log.SetLevel(log.DebugLevel) mux, urlx, teardown := setup() defer teardown() - //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} + // body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} initBasicMuxMock(t, mux, "/watchers/login") log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) - //ok auth + // ok auth clientConfig := &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, } - client, err := NewClient(clientConfig) - if err != nil { - t.Fatalf("new api client: %s", err) - } + client, err := NewClient(clientConfig) + require.NoError(t, err) _, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ MachineID: &clientConfig.MachineID, Password: &clientConfig.Password, Scenarios: clientConfig.Scenarios, }) - if err != nil { - t.Fatalf("unexpect auth err 0: %s", err) - } + require.NoError(t, err) // Testing error handling on AuthenticateWatcher (400, 409): should return an error // Not testing 500 because it loops and tries to re-authenticate, 
but you can test it manually by adding it to the array errorCodesToTest := [2]int{http.StatusBadRequest, http.StatusConflict} for _, errorCodeToTest := range errorCodesToTest { clientConfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) - client, err := NewClient(clientConfig) - if err != nil { - t.Fatalf("new api client: %s", err) - } + client, err := NewClient(clientConfig) + require.NoError(t, err) - var resp *Response - _, resp, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + _, resp, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ MachineID: &clientConfig.MachineID, Password: &clientConfig.Password, }) if err == nil { resp.Response.Body.Close() + bodyBytes, err := io.ReadAll(resp.Response.Body) - if err != nil { - t.Fatalf("error while reading body: %s", err.Error()) - } + require.NoError(t, err) - log.Printf(string(bodyBytes)) + log.Print(string(bodyBytes)) t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest) - } else { - log.Printf("The AuthenticateWatcher function handled the error code %d as expected \n\r", errorCodeToTest) } + + log.Printf("The AuthenticateWatcher function handled the error code %d as expected \n\r", errorCodeToTest) } } func TestWatcherUnregister(t *testing.T) { - log.SetLevel(log.DebugLevel) mux, urlx, teardown := setup() defer teardown() - //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} + // body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") - assert.Equal(t, r.ContentLength, int64(0)) + assert.Equal(t, int64(0), r.ContentLength) w.WriteHeader(http.StatusOK) }) + mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") + buf := new(bytes.Buffer) _, _ = buf.ReadFrom(r.Body) + newStr := buf.String() if newStr == `{"machine_id":"test_login","password":"test_password","scenarios":["crowdsecurity/test"]} ` { @@ -211,27 +201,24 @@ func TestWatcherUnregister(t *testing.T) { }) log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + mycfg := &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, } + client, err := NewClient(mycfg) + require.NoError(t, err) - if err != nil { - t.Fatalf("new api client: %s", err) - } _, err = client.Auth.UnregisterWatcher(context.Background()) - if err != nil { - t.Fatalf("while registering client : %s", err) - } + require.NoError(t, err) + log.Printf("->%T", client) } @@ -243,10 +230,12 @@ func TestWatcherEnroll(t *testing.T) { mux.HandleFunc("/watchers/enroll", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") + buf := new(bytes.Buffer) _, _ = buf.ReadFrom(r.Body) newStr := buf.String() log.Debugf("body -> %s", newStr) + if newStr == `{"attachment_key":"goodkey","name":"","tags":[],"overwrite":false} ` { log.Print("good key") @@ -258,35 +247,31 @@ func TestWatcherEnroll(t *testing.T) { fmt.Fprintf(w, `{"message":"the attachment key provided is not valid"}`) } }) + mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") 
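+ // The far-future expiry in the canned response below keeps needsTokenRefresh()
+ // false after the first login, so both enroll attempts reuse the same token.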
w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`) }) + log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) mycfg := &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, } - client, err := NewClient(mycfg) - if err != nil { - t.Fatalf("new api client: %s", err) - } + client, err := NewClient(mycfg) + require.NoError(t, err) _, err = client.Auth.EnrollWatcher(context.Background(), "goodkey", "", []string{}, false) - if err != nil { - t.Fatalf("unexpect enroll err: %s", err) - } + require.NoError(t, err) _, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false) assert.Contains(t, err.Error(), "the attachment key provided is not valid", "got %s", err.Error()) diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index d95f7749041..47d97a28344 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -4,12 +4,15 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/json" "fmt" - "io" + "net" "net/http" "net/url" + "strings" + "github.com/golang-jwt/jwt/v4" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -37,35 +40,78 @@ type ApiClient struct { Metrics *MetricsService Signal *SignalService HeartBeat *HeartBeatService + UsageMetrics *UsageMetricsService } func (a *ApiClient) GetClient() *http.Client { return a.client } +func (a *ApiClient) IsEnrolled() bool { + jwtTransport := a.client.Transport.(*JWTTransport) + tokenStr := jwtTransport.Token + + token, _ := jwt.Parse(tokenStr, nil) + if token == nil { + return false + } + + claims := token.Claims.(jwt.MapClaims) + _, ok := claims["organization_id"] + + return ok +} + type service struct { client *ApiClient } func NewClient(config *Config) (*ApiClient, error) { + userAgent := config.UserAgent + if userAgent == "" { + userAgent = useragent.Default() + } + t := &JWTTransport{ MachineID: &config.MachineID, Password: &config.Password, Scenarios: config.Scenarios, - URL: config.URL, - UserAgent: config.UserAgent, + UserAgent: userAgent, VersionPrefix: config.VersionPrefix, UpdateScenario: config.UpdateScenario, + RetryConfig: NewRetryConfig( + WithStatusCodeConfig(http.StatusUnauthorized, 2, false, true), + WithStatusCodeConfig(http.StatusForbidden, 2, false, true), + WithStatusCodeConfig(http.StatusTooManyRequests, 5, true, false), + WithStatusCodeConfig(http.StatusServiceUnavailable, 5, true, false), + WithStatusCodeConfig(http.StatusGatewayTimeout, 5, true, false), + ), } + + transport, baseURL := createTransport(config.URL) + if transport != nil { + t.Transport = transport + } else { + // can be httpmock.MockTransport + if ht, ok := http.DefaultTransport.(*http.Transport); ok { + t.Transport = ht.Clone() + } + } + + t.URL = baseURL + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig.RootCAs = CaCertPool + if Cert != nil { tlsconfig.Certificates = []tls.Certificate{*Cert} } - if ht, ok := http.DefaultTransport.(*http.Transport); ok { - ht.TLSClientConfig = &tlsconfig + + if t.Transport != nil { + t.Transport.(*http.Transport).TLSClientConfig = &tlsconfig } - c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: 
config.PapiURL} + + c := &ApiClient{client: t.Client(), BaseURL: baseURL, UserAgent: userAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) c.Alerts = (*AlertsService)(&c.common) @@ -74,24 +120,40 @@ func NewClient(config *Config) (*ApiClient, error) { c.Signal = (*SignalService)(&c.common) c.DecisionDelete = (*DecisionDeleteService)(&c.common) c.HeartBeat = (*HeartBeatService)(&c.common) + c.UsageMetrics = (*UsageMetricsService)(&c.common) return c, nil } func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *http.Client) (*ApiClient, error) { + transport, baseURL := createTransport(URL) + if client == nil { client = &http.Client{} - if ht, ok := http.DefaultTransport.(*http.Transport); ok { - tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} - tlsconfig.RootCAs = CaCertPool - if Cert != nil { - tlsconfig.Certificates = []tls.Certificate{*Cert} + + if transport != nil { + client.Transport = transport + } else { + if ht, ok := http.DefaultTransport.(*http.Transport); ok { + ht = ht.Clone() + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} + tlsconfig.RootCAs = CaCertPool + + if Cert != nil { + tlsconfig.Certificates = []tls.Certificate{*Cert} + } + + ht.TLSClientConfig = &tlsconfig + client.Transport = ht } - ht.TLSClientConfig = &tlsconfig - client.Transport = ht } } - c := &ApiClient{client: client, BaseURL: URL, UserAgent: userAgent, URLPrefix: prefix} + + if userAgent == "" { + userAgent = useragent.Default() + } + + c := &ApiClient{client: client, BaseURL: baseURL, UserAgent: userAgent, URLPrefix: prefix} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) c.Alerts = (*AlertsService)(&c.common) @@ -100,89 +162,96 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt c.Signal = (*SignalService)(&c.common) c.DecisionDelete = (*DecisionDeleteService)(&c.common) c.HeartBeat = (*HeartBeatService)(&c.common) + c.UsageMetrics = (*UsageMetricsService)(&c.common) return c, nil } -func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) { +func RegisterClient(ctx context.Context, config *Config, client *http.Client) (*ApiClient, error) { + transport, baseURL := createTransport(config.URL) + if client == nil { client = &http.Client{} + if transport != nil { + client.Transport = transport + } else { + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} + if Cert != nil { + tlsconfig.RootCAs = CaCertPool + tlsconfig.Certificates = []tls.Certificate{*Cert} + } + + client.Transport = http.DefaultTransport.(*http.Transport).Clone() + client.Transport.(*http.Transport).TLSClientConfig = &tlsconfig + } + } else if client.Transport == nil && transport != nil { + client.Transport = transport } - tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} - if Cert != nil { - tlsconfig.RootCAs = CaCertPool - tlsconfig.Certificates = []tls.Certificate{*Cert} + + userAgent := config.UserAgent + if userAgent == "" { + userAgent = useragent.Default() } - http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig - c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix} + + c := &ApiClient{client: client, BaseURL: baseURL, UserAgent: userAgent, URLPrefix: config.VersionPrefix} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) c.Alerts = (*AlertsService)(&c.common) c.Auth = (*AuthService)(&c.common) - resp, err := 
c.Auth.RegisterWatcher(context.Background(), models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}) - /*if we have http status, return it*/ + resp, err := c.Auth.RegisterWatcher(ctx, models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password, RegistrationToken: config.RegistrationToken}) if err != nil { + /*if we have http status, return it*/ if resp != nil && resp.Response != nil { return nil, fmt.Errorf("api register (%s) http %s: %w", c.BaseURL, resp.Response.Status, err) } + return nil, fmt.Errorf("api register (%s): %w", c.BaseURL, err) } - return c, nil - -} -type Response struct { - Response *http.Response - //add our pagination stuff - //NextPage int - //... + return c, nil } -type ErrorResponse struct { - models.ErrorResponse -} +func createTransport(url *url.URL) (*http.Transport, *url.URL) { + urlString := url.String() -func (e *ErrorResponse) Error() string { - err := fmt.Sprintf("API error: %s", *e.Message) - if len(e.Errors) > 0 { - err += fmt.Sprintf(" (%s)", e.Errors) + // TCP transport + if !strings.HasPrefix(urlString, "/") { + return nil, url } - return err + + // Unix transport + url.Path = "/" + url.Host = "unix" + url.Scheme = "http" + + return &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", strings.TrimSuffix(urlString, "/")) + }, + }, url } -func newResponse(r *http.Response) *Response { - response := &Response{Response: r} - return response +type Response struct { + Response *http.Response + // add our pagination stuff + // NextPage int + // ... } -func CheckResponse(r *http.Response) error { - if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { - return nil - } - errorResponse := &ErrorResponse{} - data, err := io.ReadAll(r.Body) - if err == nil && data != nil { - err := json.Unmarshal(data, errorResponse) - if err != nil { - return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err) - } - } else { - errorResponse.Message = new(string) - *errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode) - } - return errorResponse +func newResponse(r *http.Response) *Response { + return &Response{Response: r} } type ListOpts struct { - //Page int - //PerPage int + // Page int + // PerPage int } type DeleteOpts struct { - //?? + // ?? } type AddOpts struct { - //?? + // ?? 
} diff --git a/pkg/apiclient/client_http.go b/pkg/apiclient/client_http.go index 2c55128e1df..0240618f535 100644 --- a/pkg/apiclient/client_http.go +++ b/pkg/apiclient/client_http.go @@ -19,6 +19,7 @@ func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Requ if !strings.HasSuffix(c.BaseURL.Path, "/") { return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL) } + u, err := c.BaseURL.Parse(url) if err != nil { return nil, err @@ -29,8 +30,8 @@ func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Requ buf = &bytes.Buffer{} enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) - err := enc.Encode(body) - if err != nil { + + if err = enc.Encode(body); err != nil { return nil, err } } @@ -51,6 +52,7 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* if ctx == nil { return nil, errors.New("context must be non-nil") } + req = req.WithContext(ctx) // Check rate limit @@ -62,6 +64,7 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* if log.GetLevel() >= log.DebugLevel { log.Debugf("[URL] %s %s", req.Method, req.URL) } + resp, err := c.client.Do(req) if resp != nil && resp.Body != nil { defer resp.Body.Close() @@ -82,14 +85,16 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* e.URL = url.String() return newResponse(resp), e } + return newResponse(resp), err } + return newResponse(resp), err } if log.GetLevel() >= log.DebugLevel { for k, v := range resp.Header { - log.Debugf("[headers] %s : %s", k, v) + log.Debugf("[headers] %s: %s", k, v) } dump, err := httputil.DumpResponse(resp, true) @@ -112,9 +117,12 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* if errors.Is(decErr, io.EOF) { decErr = nil // ignore EOF errors caused by empty response body } + return response, decErr } + io.Copy(w, resp.Body) } + return response, err } diff --git a/pkg/apiclient/client_http_test.go b/pkg/apiclient/client_http_test.go index 9e082cf51cf..45cd8410a8e 100644 --- a/pkg/apiclient/client_http_test.go +++ b/pkg/apiclient/client_http_test.go @@ -2,35 +2,32 @@ package apiclient import ( "context" - "fmt" "net/http" "net/url" "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestNewRequestInvalid(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //missing slash in uri + + // missing slash in uri apiURL, err := url.Parse(urlx) - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) + /*mock login*/ mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) @@ -43,27 +40,25 @@ func TestNewRequestInvalid(t *testing.T) { }) _, _, err = client.Alerts.List(context.Background(), AlertsListOpts{}) - assert.Contains(t, err.Error(), `building request: BaseURL must have a trailing slash, but `) + cstest.RequireErrorContains(t, err, "building request: BaseURL must have a trailing slash, but ") } func TestNewRequestTimeout(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //missing slash in 
uri + + // missing slash in uri apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) + /*mock login*/ mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { time.Sleep(2 * time.Second) @@ -73,5 +68,5 @@ func TestNewRequestTimeout(t *testing.T) { defer cancel() _, _, err = client.Alerts.List(ctx, AlertsListOpts{}) - assert.Contains(t, err.Error(), `performing request: context deadline exceeded`) + cstest.RequireErrorMessage(t, err, "performing request: context deadline exceeded") } diff --git a/pkg/apiclient/client_test.go b/pkg/apiclient/client_test.go index 08f56730b86..d1f58f33ad2 100644 --- a/pkg/apiclient/client_test.go +++ b/pkg/apiclient/client_test.go @@ -3,16 +3,20 @@ package apiclient import ( "context" "fmt" + "net" "net/http" "net/http/httptest" "net/url" + "path" "runtime" + "strings" "testing" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/crowdsecurity/go-cs-lib/cstest" ) /*this is a ripoff of google/go-github approach : @@ -20,42 +24,113 @@ import ( - each test will then bind handler for the method(s) they want to try */ -func setup() (mux *http.ServeMux, serverURL string, teardown func()) { +func setup() (*http.ServeMux, string, func()) { return setupWithPrefix("v1") } -func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) { +func setupWithPrefix(urlPrefix string) (*http.ServeMux, string, func()) { // mux is the HTTP request multiplexer used with the test server. - mux = http.NewServeMux() + mux := http.NewServeMux() baseURLPath := "/" + urlPrefix apiHandler := http.NewServeMux() apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux)) - // server is a test HTTP server used to provide mock API responses. server := httptest.NewServer(apiHandler) return mux, server.URL, server.Close } +// toUNCPath converts a Windows file path to a UNC path. +// This is necessary because the Go http package does not support Windows file paths. 
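+// For example, C:\tmp\crowdsec.sock becomes //localhost/C$/tmp/crowdsec.sock.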
+func toUNCPath(path string) (string, error) { + colonIdx := strings.Index(path, ":") + if colonIdx == -1 { + return "", fmt.Errorf("invalid path format, missing drive letter: %s", path) + } + + // URL parsing does not like backslashes + remaining := strings.ReplaceAll(path[colonIdx+1:], "\\", "/") + uncPath := "//localhost/" + path[:colonIdx] + "$" + remaining + + return uncPath, nil +} + +func setupUnixSocketWithPrefix(socket string, urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) { + var err error + if runtime.GOOS == "windows" { + socket, err = toUNCPath(socket) + if err != nil { + log.Fatalf("converting to UNC path: %s", err) + } + } + + mux = http.NewServeMux() + baseURLPath := "/" + urlPrefix + + apiHandler := http.NewServeMux() + apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux)) + + server := httptest.NewUnstartedServer(apiHandler) + l, _ := net.Listen("unix", socket) + _ = server.Listener.Close() + server.Listener = l + server.Start() + + return mux, socket, server.Close +} + func testMethod(t *testing.T, r *http.Request, want string) { t.Helper() - if got := r.Method; got != want { - t.Errorf("Request method: %v, want %v", got, want) - } + assert.Equal(t, want, r.Method) } func TestNewClientOk(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + apiURL, err := url.Parse(urlx + "/") + require.NoError(t, err) + + client, err := NewClient(&Config{ + MachineID: "test_login", + Password: "test_password", + URL: apiURL, + VersionPrefix: "v1", + }) + require.NoError(t, err) + + /*mock login*/ + mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) + }) + + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + w.WriteHeader(http.StatusOK) + }) + + _, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{}) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) +} + +func TestNewClientOk_UnixSocket(t *testing.T) { + tmpDir := t.TempDir() + socket := path.Join(tmpDir, "socket") + + mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1") + defer teardown() + + apiURL, err := url.Parse(urlx) if err != nil { t.Fatalf("parsing api url: %s", apiURL) } + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) @@ -77,6 +152,7 @@ func TestNewClientOk(t *testing.T) { if err != nil { t.Fatalf("test Unable to list alerts : %+v", err) } + if resp.Response.StatusCode != http.StatusOK { t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated) } @@ -85,20 +161,18 @@ func TestNewClientOk(t *testing.T) { func TestNewClientKo(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) + /*mock login*/ mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) @@ -111,25 +185,54 @@ func 
TestNewClientKo(t *testing.T) { }) _, _, err = client.Alerts.List(context.Background(), AlertsListOpts{}) - assert.Contains(t, err.Error(), `API error: bad login/password`) + cstest.RequireErrorContains(t, err, `API error: bad login/password`) + log.Printf("err-> %s", err) } func TestNewDefaultClient(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + apiURL, err := url.Parse(urlx + "/") + require.NoError(t, err) + + client, err := NewDefaultClient(apiURL, "/v1", "", nil) + require.NoError(t, err) + + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"code": 401, "message" : "brr"}`)) + }) + + _, _, err = client.Alerts.List(context.Background(), AlertsListOpts{}) + cstest.RequireErrorMessage(t, err, "performing request: API error: brr") + + log.Printf("err-> %s", err) +} + +func TestNewDefaultClient_UnixSocket(t *testing.T) { + tmpDir := t.TempDir() + socket := path.Join(tmpDir, "socket") + + mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1") + defer teardown() + + apiURL, err := url.Parse(urlx) if err != nil { t.Fatalf("parsing api url: %s", apiURL) } + client, err := NewDefaultClient(apiURL, "/v1", "", nil) if err != nil { t.Fatalf("new api client: %s", err) } + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(`{"code": 401, "message" : "brr"}`)) }) + _, _, err = client.Alerts.List(context.Background(), AlertsListOpts{}) assert.Contains(t, err.Error(), `performing request: API error: brr`) log.Printf("err-> %s", err) @@ -137,25 +240,27 @@ func TestNewDefaultClient(t *testing.T) { func TestNewClientRegisterKO(t *testing.T) { apiURL, err := url.Parse("http://127.0.0.1:4242/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } - _, err = RegisterClient(&Config{ + require.NoError(t, err) + + ctx := context.Background() + + _, err = RegisterClient(ctx, &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) - if runtime.GOOS != "windows" { - assert.Contains(t, fmt.Sprintf("%s", err), "dial tcp 127.0.0.1:4242: connect: connection refused") + + if runtime.GOOS == "windows" { + cstest.RequireErrorContains(t, err, " No connection could be made because the target machine actively refused it.") } else { - assert.Contains(t, fmt.Sprintf("%s", err), " No connection could be made because the target machine actively refused it.") + cstest.RequireErrorContains(t, err, "dial tcp 127.0.0.1:4242: connect: connection refused") } } func TestNewClientRegisterOK(t *testing.T) { log.SetLevel(log.TraceLevel) + mux, urlx, teardown := setup() defer teardown() @@ -167,24 +272,60 @@ func TestNewClientRegisterOK(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") + require.NoError(t, err) + + ctx := context.Background() + + client, err := RegisterClient(ctx, &Config{ + MachineID: "test_login", + Password: "test_password", + URL: apiURL, + VersionPrefix: "v1", + }, &http.Client{}) + require.NoError(t, err) + + log.Printf("->%T", client) +} + +func TestNewClientRegisterOK_UnixSocket(t *testing.T) { + log.SetLevel(log.TraceLevel) + + tmpDir := t.TempDir() + socket := path.Join(tmpDir, "socket") + + mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1") + defer teardown() + + /*mock login*/ + mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "POST") + 
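+ // Only the HTTP status matters to this mock: the register flow does not decode
+ // the response body, so the login-style payload below is just a canned answer.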
w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) + }) + + apiURL, err := url.Parse(urlx) if err != nil { t.Fatalf("parsing api url: %s", apiURL) } - client, err := RegisterClient(&Config{ + + ctx := context.Background() + + client, err := RegisterClient(ctx, &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) if err != nil { t.Fatalf("while registering client : %s", err) } + log.Printf("->%T", client) } func TestNewClientBadAnswer(t *testing.T) { log.SetLevel(log.TraceLevel) + mux, urlx, teardown := setup() defer teardown() @@ -194,16 +335,17 @@ func TestNewClientBadAnswer(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(`bad`)) }) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } - _, err = RegisterClient(&Config{ + require.NoError(t, err) + + ctx := context.Background() + + _, err = RegisterClient(ctx, &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) - assert.Contains(t, fmt.Sprintf("%s", err), `invalid body: invalid character 'b' looking for beginning of value`) + cstest.RequireErrorContains(t, err, "API error: http code 401, response: bad") } diff --git a/pkg/apiclient/clone.go b/pkg/apiclient/clone.go new file mode 100644 index 00000000000..e8f47429639 --- /dev/null +++ b/pkg/apiclient/clone.go @@ -0,0 +1,32 @@ +package apiclient + +import ( + "bytes" + "io" + "net/http" +) + +// cloneRequest returns a clone of the provided *http.Request. The clone is a +// shallow copy of the struct and its Header map. +func cloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) 
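+ // appending to a nil slice allocates a fresh backing array, so the
+ // clone's header values never alias the original's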
+ } + + if r.Body != nil { + var b bytes.Buffer + + b.ReadFrom(r.Body) + + r.Body = io.NopCloser(&b) + r2.Body = io.NopCloser(bytes.NewReader(b.Bytes())) + } + + return r2 +} diff --git a/pkg/apiclient/config.go b/pkg/apiclient/config.go index 4dfeb3e863f..29a8acf185e 100644 --- a/pkg/apiclient/config.go +++ b/pkg/apiclient/config.go @@ -1,18 +1,20 @@ package apiclient import ( + "context" "net/url" "github.com/go-openapi/strfmt" ) type Config struct { - MachineID string - Password strfmt.Password - Scenarios []string - URL *url.URL - PapiURL *url.URL - VersionPrefix string - UserAgent string - UpdateScenario func() ([]string, error) + MachineID string + Password strfmt.Password + Scenarios []string + URL *url.URL + PapiURL *url.URL + VersionPrefix string + UserAgent string + RegistrationToken string + UpdateScenario func(context.Context) ([]string, error) } diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index e96394f5611..98f26cad9ae 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -3,14 +3,14 @@ package apiclient import ( "bufio" "context" + "errors" "fmt" "net/http" qs "github.com/google/go-querystring/query" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/modelscapi" @@ -42,6 +42,7 @@ func (o *DecisionsStreamOpts) addQueryParamsToURL(url string) (string, error) { if err != nil { return "", err } + return fmt.Sprintf("%s?%s", url, params.Encode()), nil } @@ -60,11 +61,11 @@ type DecisionsDeleteOpts struct { // to demo query arguments func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*models.GetDecisionsResponse, *Response, error) { - var decisions models.GetDecisionsResponse params, err := qs.Values(opts) if err != nil { return nil, nil, err } + u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode()) req, err := s.client.NewRequest(http.MethodGet, u, nil) @@ -72,6 +73,8 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m return nil, nil, err } + var decisions models.GetDecisionsResponse + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err @@ -81,13 +84,13 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m } func (s *DecisionsService) FetchV2Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) { - var decisions models.DecisionsStreamResponse - req, err := s.client.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, nil, err } + var decisions models.DecisionsStreamResponse + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err @@ -97,7 +100,7 @@ func (s *DecisionsService) FetchV2Decisions(ctx context.Context, url string) (*m } func (s *DecisionsService) GetDecisionsFromGroups(decisionsGroups []*modelscapi.GetDecisionsStreamResponseNewItem) []*models.Decision { - var decisions []*models.Decision + decisions := make([]*models.Decision, 0) for _, decisionsGroup := range decisionsGroups { partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions)) @@ -111,15 +114,14 @@ func (s *DecisionsService) GetDecisionsFromGroups(decisionsGroups []*modelscapi. Origin: ptr.Of(types.CAPIOrigin), } } + decisions = append(decisions, partialDecisions...) 
} + + return decisions } func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) { - var decisions modelscapi.GetDecisionsStreamResponse - var v2Decisions models.DecisionsStreamResponse - scenarioDeleted := "deleted" durationDeleted := "1h" @@ -128,16 +130,21 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m return nil, nil, err } + decisions := modelscapi.GetDecisionsStreamResponse{} + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err } + v2Decisions := models.DecisionsStreamResponse{} v2Decisions.New = s.GetDecisionsFromGroups(decisions.New) + for _, decisionsGroup := range decisions.Deleted { partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions)) + for idx, decision := range decisionsGroup.Decisions { - decision := decision // fix exportloopref linter message + decision := decision //nolint:copyloopvar // fix exportloopref linter message partialDecisions[idx] = &models.Decision{ Scenario: &scenarioDeleted, Scope: decisionsGroup.Scope, @@ -147,6 +154,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m Origin: ptr.Of(types.CAPIOrigin), } } + v2Decisions.Deleted = append(v2Decisions.Deleted, partialDecisions...) } @@ -161,6 +169,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl log.Debugf("Fetching blocklist %s", *blocklist.URL) client := http.Client{} + req, err := http.NewRequest(http.MethodGet, *blocklist.URL, nil) if err != nil { return nil, false, err } @@ -169,9 +178,12 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl if lastPullTimestamp != nil { req.Header.Set("If-Modified-Since", *lastPullTimestamp) } + req = req.WithContext(ctx) log.Debugf("[URL] %s %s", req.Method, req.URL) - // we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc + + // we don't use the client_http Do method because we need the raw body reader, which it does not provide. + // We would otherwise be forced to use a Pipe and a goroutine, etc. resp, err := client.Do(req) if resp != nil && resp.Body != nil { defer resp.Body.Close() @@ -188,6 +200,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl // If the error type is *url.Error, sanitize its URL before returning. 
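+ // (a blocklist URL may embed a signed query token, which should not leak into error logs)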
log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err) + return nil, false, err } @@ -197,13 +210,18 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl } else { log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL) } + return nil, false, nil } + if resp.StatusCode != http.StatusOK { log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL) + return nil, false, nil } + decisions := make([]*models.Decision, 0) + scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { decision := scanner.Text() @@ -227,11 +245,12 @@ func (s *DecisionsService) GetStream(ctx context.Context, opts DecisionsStreamOp if err != nil { return nil, nil, err } - if s.client.URLPrefix == "v3" { - return s.FetchV3Decisions(ctx, u) - } else { + + if s.client.URLPrefix != "v3" { return s.FetchV2Decisions(ctx, u) } + + return s.FetchV3Decisions(ctx, u) } func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStreamOpts) (*modelscapi.GetDecisionsStreamResponse, *Response, error) { @@ -239,13 +258,14 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream if err != nil { return nil, nil, err } - var decisions modelscapi.GetDecisionsStreamResponse req, err := s.client.NewRequest(http.MethodGet, u, nil) if err != nil { return nil, nil, err } + decisions := modelscapi.GetDecisionsStreamResponse{} + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err @@ -255,8 +275,8 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream } func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) { - u := fmt.Sprintf("%s/decisions", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodDelete, u, nil) if err != nil { return nil, err @@ -266,15 +286,16 @@ func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) { if err != nil { return resp, err } + return resp, nil } func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) (*models.DeleteDecisionResponse, *Response, error) { - var deleteDecisionResponse models.DeleteDecisionResponse params, err := qs.Values(opts) if err != nil { return nil, nil, err } + u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode()) req, err := s.client.NewRequest(http.MethodDelete, u, nil) @@ -282,25 +303,30 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) return nil, nil, err } + deleteDecisionResponse := models.DeleteDecisionResponse{} + resp, err := s.client.Do(ctx, req, &deleteDecisionResponse) if err != nil { return nil, resp, err } + return &deleteDecisionResponse, resp, nil } -func (s *DecisionsService) DeleteOne(ctx context.Context, decision_id string) (*models.DeleteDecisionResponse, *Response, error) { - var deleteDecisionResponse models.DeleteDecisionResponse - u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decision_id) +func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*models.DeleteDecisionResponse, *Response, error) { + u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decisionID) req, err := s.client.NewRequest(http.MethodDelete, u, nil) if err != nil { return nil, nil, err } + deleteDecisionResponse := models.DeleteDecisionResponse{} + resp, err := s.client.Do(ctx, req, &deleteDecisionResponse) if err != nil { return nil, resp, err } + return &deleteDecisionResponse, resp, nil } diff --git 
a/pkg/apiclient/decisions_service_test.go b/pkg/apiclient/decisions_service_test.go index ab7e46e644f..54c44f43eda 100644 --- a/pkg/apiclient/decisions_service_test.go +++ b/pkg/apiclient/decisions_service_test.go @@ -2,18 +2,16 @@ package apiclient import ( "context" - "fmt" "net/http" "net/url" - "reflect" "testing" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" - "github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/modelscapi" @@ -27,77 +25,55 @@ func TestDecisionsList(t *testing.T) { mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") + if r.URL.RawQuery == "ip=1.2.3.4" { - assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") - assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") + assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) + assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) w.WriteHeader(http.StatusOK) w.Write([]byte(`[{"duration":"3h59m55.756182786s","id":4,"origin":"cscli","scenario":"manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'","scope":"Ip","type":"ban","value":"1.2.3.4"}]`)) } else { w.WriteHeader(http.StatusOK) w.Write([]byte(`null`)) - //no results + // no results } }) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" - torigin := "cscli" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" - tscope := "Ip" - ttype := "ban" - tvalue := "1.2.3.4" expected := &models.GetDecisionsResponse{ &models.Decision{ - Duration: &tduration, + Duration: ptr.Of("3h59m55.756182786s"), ID: 4, - Origin: &torigin, - Scenario: &tscenario, - Scope: &tscope, - Type: &ttype, - Value: &tvalue, + Origin: ptr.Of("cscli"), + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), + Scope: ptr.Of("Ip"), + Type: ptr.Of("ban"), + Value: ptr.Of("1.2.3.4"), }, } - //OK decisions - decisionsFilter := DecisionsListOpts{IPEquals: new(string)} - *decisionsFilter.IPEquals = "1.2.3.4" + // OK decisions + decisionsFilter := DecisionsListOpts{IPEquals: ptr.Of("1.2.3.4")} decisions, resp, err := newcli.Decisions.List(context.Background(), decisionsFilter) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } - - //Empty return - decisionsFilter = DecisionsListOpts{IPEquals: new(string)} - *decisionsFilter.IPEquals = "1.2.3.5" + // Empty return + decisionsFilter = DecisionsListOpts{IPEquals: ptr.Of("1.2.3.5")} decisions, resp, err = newcli.Decisions.List(context.Background(), decisionsFilter) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned 
status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - assert.Equal(t, len(*decisions), 0) - + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Empty(t, *decisions) } func TestDecisionsStream(t *testing.T) { @@ -107,9 +83,9 @@ func TestDecisionsStream(t *testing.T) { defer teardown() mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { - - assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") + assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) + if r.Method == http.MethodGet { if r.URL.RawQuery == "startup=true" { w.WriteHeader(http.StatusOK) @@ -120,80 +96,57 @@ func TestDecisionsStream(t *testing.T) { } } }) + mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") + assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodDelete) + if r.Method == http.MethodDelete { w.WriteHeader(http.StatusOK) } }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" - torigin := "cscli" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" - tscope := "Ip" - ttype := "ban" - tvalue := "1.2.3.4" expected := &models.DecisionsStreamResponse{ New: models.GetDecisionsResponse{ &models.Decision{ - Duration: &tduration, + Duration: ptr.Of("3h59m55.756182786s"), ID: 4, - Origin: &torigin, - Scenario: &tscenario, - Scope: &tscope, - Type: &ttype, - Value: &tvalue, + Origin: ptr.Of("cscli"), + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), + Scope: ptr.Of("Ip"), + Type: ptr.Of("ban"), + Value: ptr.Of("1.2.3.4"), }, }, } decisions, resp, err := newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: true}) require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } - - //and second call, we get empty lists + // and second call, we get empty lists decisions, resp, err = newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: false}) require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Empty(t, decisions.New) + assert.Empty(t, decisions.Deleted) - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - assert.Equal(t, 0, len(decisions.New)) - assert.Equal(t, 0, len(decisions.Deleted)) - - //delete stream + // delete stream resp, err = newcli.Decisions.StopStream(context.Background()) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) } func TestDecisionsStreamV3Compatibility(t *testing.T) { @@ -203,9 
+156,9 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { defer teardown() mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { - - assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") + assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) + if r.Method == http.MethodGet { if r.URL.RawQuery == "startup=true" { w.WriteHeader(http.StatusOK) @@ -218,48 +171,38 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" torigin := "CAPI" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" tscope := "ip" ttype := "ban" - tvalue := "1.2.3.4" - tvalue1 := "1.2.3.5" - tscenarioDeleted := "deleted" - tdurationDeleted := "1h" expected := &models.DecisionsStreamResponse{ New: models.GetDecisionsResponse{ &models.Decision{ - Duration: &tduration, + Duration: ptr.Of("3h59m55.756182786s"), Origin: &torigin, - Scenario: &tscenario, + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), Scope: &tscope, Type: &ttype, - Value: &tvalue, + Value: ptr.Of("1.2.3.4"), }, }, Deleted: models.GetDecisionsResponse{ &models.Decision{ - Duration: &tdurationDeleted, + Duration: ptr.Of("1h"), Origin: &torigin, - Scenario: &tscenarioDeleted, + Scenario: ptr.Of("deleted"), Scope: &tscope, Type: &ttype, - Value: &tvalue1, + Value: ptr.Of("1.2.3.5"), }, }, } @@ -267,17 +210,8 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { // GetStream is supposed to consume v3 payload and return v2 response decisions, resp, err := newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: true}) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) } func TestDecisionsStreamV3(t *testing.T) { @@ -287,9 +221,9 @@ func TestDecisionsStreamV3(t *testing.T) { defer teardown() mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { - - assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") + assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) + if r.Method == http.MethodGet { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"deleted":[{"scope":"ip","decisions":["1.2.3.5"]}], @@ -299,40 +233,27 @@ func TestDecisionsStreamV3(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" tscope := "ip" - tvalue := "1.2.3.4" - tvalue1 := "1.2.3.5" 
- tdurationBlocklist := "24h" - tnameBlocklist := "blocklist1" - tremediationBlocklist := "ban" - tscopeBlocklist := "ip" - turlBlocklist := "/v3/blocklist" expected := &modelscapi.GetDecisionsStreamResponse{ New: modelscapi.GetDecisionsStreamResponseNew{ &modelscapi.GetDecisionsStreamResponseNewItem{ Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ { - Duration: &tduration, - Value: &tvalue, + Duration: ptr.Of("3h59m55.756182786s"), + Value: ptr.Of("1.2.3.4"), }, }, - Scenario: &tscenario, + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), Scope: &tscope, }, }, @@ -340,18 +261,18 @@ func TestDecisionsStreamV3(t *testing.T) { &modelscapi.GetDecisionsStreamResponseDeletedItem{ Scope: &tscope, Decisions: []string{ - tvalue1, + "1.2.3.5", }, }, }, Links: &modelscapi.GetDecisionsStreamResponseLinks{ Blocklists: []*modelscapi.BlocklistLink{ { - Duration: &tdurationBlocklist, - Name: &tnameBlocklist, - Remediation: &tremediationBlocklist, - Scope: &tscopeBlocklist, - URL: &turlBlocklist, + Duration: ptr.Of("24h"), + Name: ptr.Of("blocklist1"), + Remediation: ptr.Of("ban"), + Scope: ptr.Of("ip"), + URL: ptr.Of("/v3/blocklist"), }, }, }, @@ -360,17 +281,8 @@ func TestDecisionsStreamV3(t *testing.T) { // GetStream is supposed to consume v3 payload and return v2 response decisions, resp, err := newcli.Decisions.GetStreamV3(context.Background(), DecisionsStreamOpts{Startup: true}) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) } func TestDecisionsFromBlocklist(t *testing.T) { @@ -381,10 +293,13 @@ func TestDecisionsFromBlocklist(t *testing.T) { mux.HandleFunc("/blocklist", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, http.MethodGet) + if r.Header.Get("If-Modified-Since") == "Sun, 01 Jan 2023 01:01:01 GMT" { w.WriteHeader(http.StatusNotModified) + return } + if r.Method == http.MethodGet { w.WriteHeader(http.StatusOK) w.Write([]byte("1.2.3.4\r\n1.2.3.5")) @@ -392,22 +307,16 @@ func TestDecisionsFromBlocklist(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tvalue1 := "1.2.3.4" - tvalue2 := "1.2.3.5" tdurationBlocklist := "24h" tnameBlocklist := "blocklist1" tremediationBlocklist := "ban" @@ -417,7 +326,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { expected := []*models.Decision{ { Duration: &tdurationBlocklist, - Value: &tvalue1, + Value: ptr.Of("1.2.3.4"), Scenario: &tnameBlocklist, Scope: &tscopeBlocklist, Type: &tremediationBlocklist, @@ -425,7 +334,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { }, { Duration: &tdurationBlocklist, - Value: &tvalue2, + Value: ptr.Of("1.2.3.5"), Scenario: &tnameBlocklist, Scope: &tscopeBlocklist, Type: &tremediationBlocklist, @@ -448,12 +357,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { log.Infof("expected : %s, %s, %s, %s, %s", *expected[0].Value, 
*expected[0].Duration, *expected[0].Scenario, *expected[0].Scope, *expected[0].Type) log.Infof("decisions: %s, %s, %s, %s, %s", *decisions[1].Value, *decisions[1].Duration, *decisions[1].Scenario, *decisions[1].Scope, *decisions[1].Type) - if err != nil { - t.Fatalf("new api client: %s", err) - } - if !reflect.DeepEqual(decisions, expected) { - t.Fatalf("returned %+v, want %+v", decisions, expected) - } + assert.Equal(t, expected, decisions) // test cache control _, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{ @@ -463,8 +367,10 @@ func TestDecisionsFromBlocklist(t *testing.T) { Name: &tnameBlocklist, Duration: &tdurationBlocklist, }, ptr.Of("Sun, 01 Jan 2023 01:01:01 GMT")) + require.NoError(t, err) assert.False(t, isModified) + _, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{ URL: &turlBlocklist, Scope: &tscopeBlocklist, @@ -472,6 +378,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { Name: &tnameBlocklist, Duration: &tdurationBlocklist, }, ptr.Of("Mon, 02 Jan 2023 01:01:01 GMT")) + require.NoError(t, err) assert.True(t, isModified) } @@ -482,36 +389,33 @@ func TestDeleteDecisions(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") - assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") + assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) w.Write([]byte(`{"nbDeleted":"1"}`)) - //w.Write([]byte(`{"message":"0 deleted alerts"}`)) + // w.Write([]byte(`{"message":"0 deleted alerts"}`)) }) + log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) filters := DecisionsDeleteOpts{IPEquals: new(string)} *filters.IPEquals = "1.2.3.4" + deleted, _, err := client.Decisions.Delete(context.Background(), filters) - if err != nil { - t.Fatalf("unexpected err : %s", err) - } + require.NoError(t, err) assert.Equal(t, "1", deleted.NbDeleted) defer teardown() @@ -519,28 +423,30 @@ func TestDeleteDecisions(t *testing.T) { func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { baseURLString := "http://localhost:8080/v1/decisions/stream" + type fields struct { Startup bool Scopes string ScenariosContaining string ScenariosNotContaining string } + tests := []struct { - name string - fields fields - want string - wantErr bool + name string + fields fields + expected string + expectedErr string }{ { - name: "no filter", - want: baseURLString + "?", + name: "no filter", + expected: baseURLString + "?", }, { name: "startup=true", fields: fields{ Startup: true, }, - want: baseURLString + "?startup=true", + expected: baseURLString + "?startup=true", }, { name: "set all params", @@ -550,11 +456,11 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { ScenariosContaining: "ssh", ScenariosNotContaining: "bf", }, - want: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true", + expected: baseURLString + 
"?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true", }, } + for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { o := &DecisionsStreamOpts{ Startup: tt.fields.Startup, @@ -562,25 +468,21 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { ScenariosContaining: tt.fields.ScenariosContaining, ScenariosNotContaining: tt.fields.ScenariosNotContaining, } + got, err := o.addQueryParamsToURL(baseURLString) - if (err != nil) != tt.wantErr { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() error = %v, wantErr %v", err, tt.wantErr) + cstest.RequireErrorContains(t, err, tt.expectedErr) + + if tt.expectedErr != "" { return } gotURL, err := url.Parse(got) - if err != nil { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() got error while parsing URL: %s", err) - } + require.NoError(t, err) - expectedURL, err := url.Parse(tt.want) - if err != nil { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() got error while parsing URL: %s", err) - } + expectedURL, err := url.Parse(tt.expected) + require.NoError(t, err) - if *gotURL != *expectedURL { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() = %v, want %v", *gotURL, *expectedURL) - } + assert.Equal(t, *expectedURL, *gotURL) }) } } @@ -604,7 +506,6 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { // client, err := NewClient(&Config{ // MachineID: "test_login", // Password: "test_password", -// UserAgent: fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), // URL: apiURL, // VersionPrefix: "v1", // }) diff --git a/pkg/apiclient/decisions_sync_service.go b/pkg/apiclient/decisions_sync_service.go index 57999691f21..25e33a8e29d 100644 --- a/pkg/apiclient/decisions_sync_service.go +++ b/pkg/apiclient/decisions_sync_service.go @@ -14,21 +14,25 @@ type DecisionDeleteService service // DecisionDeleteService purposely reuses AddSignalsRequestItemDecisions model func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *models.DecisionsDeleteRequest) (interface{}, *Response, error) { - var response interface{} u := fmt.Sprintf("%s/decisions/delete", d.client.URLPrefix) + req, err := d.client.NewRequest(http.MethodPost, u, &deletedDecisions) if err != nil { return nil, nil, fmt.Errorf("while building request: %w", err) } + var response interface{} + resp, err := d.client.Do(ctx, req, &response) if err != nil { return nil, resp, fmt.Errorf("while performing request: %w", err) } + if resp.Response.StatusCode != http.StatusOK { - log.Warnf("Decisions delete response : http %s", resp.Response.Status) + log.Warnf("Decisions delete response: http %s", resp.Response.Status) } else { - log.Debugf("Decisions delete response : http %s", resp.Response.Status) + log.Debugf("Decisions delete response: http %s", resp.Response.Status) } + return &response, resp, nil } diff --git a/pkg/apiclient/heartbeat.go b/pkg/apiclient/heartbeat.go index 497ccb7eb32..c6b3d0832ba 100644 --- a/pkg/apiclient/heartbeat.go +++ b/pkg/apiclient/heartbeat.go @@ -9,13 +9,12 @@ import ( log "github.com/sirupsen/logrus" tomb "gopkg.in/tomb.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" ) type HeartBeatService service func (h *HeartBeatService) Ping(ctx context.Context) (bool, *Response, error) { - u := fmt.Sprintf("%s/heartbeat", h.client.URLPrefix) req, err := h.client.NewRequest(http.MethodGet, u, nil) @@ -39,16 +38,19 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) { select { case <-hbTimer.C: 
log.Debug("heartbeat: sending heartbeat") + ok, resp, err := h.Ping(ctx) if err != nil { - log.Errorf("heartbeat error : %s", err) + log.Errorf("heartbeat error: %s", err) continue } + resp.Response.Body.Close() if resp.Response.StatusCode != http.StatusOK { - log.Errorf("heartbeat unexpected return code : %d", resp.Response.StatusCode) + log.Errorf("heartbeat unexpected return code: %d", resp.Response.StatusCode) continue } + if !ok { log.Errorf("heartbeat returned false") continue diff --git a/pkg/apiclient/metrics.go b/pkg/apiclient/metrics.go index ea447280aa5..7f8d095a2df 100644 --- a/pkg/apiclient/metrics.go +++ b/pkg/apiclient/metrics.go @@ -11,17 +11,19 @@ import ( type MetricsService service func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (interface{}, *Response, error) { - var response interface{} - u := fmt.Sprintf("%s/metrics/", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodPost, u, &metrics) if err != nil { return nil, nil, err } + var response interface{} + resp, err := s.client.Do(ctx, req, &response) if err != nil { return nil, resp, err } + return &response, resp, nil } diff --git a/pkg/apiclient/resperr.go b/pkg/apiclient/resperr.go new file mode 100644 index 00000000000..1b0786f9882 --- /dev/null +++ b/pkg/apiclient/resperr.go @@ -0,0 +1,61 @@ +package apiclient + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +type ErrorResponse struct { + models.ErrorResponse +} + +func (e *ErrorResponse) Error() string { + message := ptr.OrEmpty(e.Message) + errors := "" + + if e.Errors != "" { + errors = fmt.Sprintf(" (%s)", e.Errors) + } + + if message == "" && errors == "" { + errors = "(no errors)" + } + + return fmt.Sprintf("API error: %s%s", message, errors) +} + +// CheckResponse verifies the API response and builds an appropriate Go error if necessary. 
+func CheckResponse(r *http.Response) error { + if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { + return nil + } + + ret := &ErrorResponse{} + + data, err := io.ReadAll(r.Body) + if err != nil || len(data) == 0 { + ret.Message = ptr.Of(fmt.Sprintf("http code %d, no response body", r.StatusCode)) + return ret + } + + switch r.StatusCode { + case http.StatusUnprocessableEntity: + ret.Message = ptr.Of(fmt.Sprintf("http code %d, invalid request: %s", r.StatusCode, string(data))) + default: + // try to unmarshal and if there are no 'message' or 'errors' fields, display the body as is, + // the API is following a different convention + err := json.Unmarshal(data, ret) + if err != nil || (ret.Message == nil && ret.Errors == "") { + ret.Message = ptr.Of(fmt.Sprintf("http code %d, response: %s", r.StatusCode, string(data))) + return ret + } + } + + return ret +} diff --git a/pkg/apiclient/retry_config.go b/pkg/apiclient/retry_config.go new file mode 100644 index 00000000000..8a0d1096f84 --- /dev/null +++ b/pkg/apiclient/retry_config.go @@ -0,0 +1,33 @@ +package apiclient + +type StatusCodeConfig struct { + MaxAttempts int + Backoff bool + InvalidateToken bool +} + +type RetryConfig struct { + StatusCodeConfig map[int]StatusCodeConfig +} + +type RetryConfigOption func(*RetryConfig) + +func NewRetryConfig(options ...RetryConfigOption) *RetryConfig { + rc := &RetryConfig{ + StatusCodeConfig: make(map[int]StatusCodeConfig), + } + for _, opt := range options { + opt(rc) + } + return rc +} + +func WithStatusCodeConfig(statusCode int, maxAttempts int, backOff bool, invalidateToken bool) RetryConfigOption { + return func(rc *RetryConfig) { + rc.StatusCodeConfig[statusCode] = StatusCodeConfig{ + MaxAttempts: maxAttempts, + Backoff: backOff, + InvalidateToken: invalidateToken, + } + } +} diff --git a/pkg/apiclient/signal.go b/pkg/apiclient/signal.go index 2dceb815754..613ce70bbfb 100644 --- a/pkg/apiclient/signal.go +++ b/pkg/apiclient/signal.go @@ -13,22 +13,25 @@ import ( type SignalService service func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsRequest) (interface{}, *Response, error) { - var response interface{} - u := fmt.Sprintf("%s/signals", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodPost, u, &signals) if err != nil { return nil, nil, fmt.Errorf("while building request: %w", err) } + var response interface{} + resp, err := s.client.Do(ctx, req, &response) if err != nil { return nil, resp, fmt.Errorf("while performing request: %w", err) } + if resp.Response.StatusCode != http.StatusOK { log.Warnf("Signal push response : http %s", resp.Response.Status) } else { log.Debugf("Signal push response : http %s", resp.Response.Status) } + return &response, resp, nil } diff --git a/pkg/apiclient/usagemetrics.go b/pkg/apiclient/usagemetrics.go new file mode 100644 index 00000000000..1d822bb5c1e --- /dev/null +++ b/pkg/apiclient/usagemetrics.go @@ -0,0 +1,29 @@ +package apiclient + +import ( + "context" + "fmt" + "net/http" + + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +type UsageMetricsService service + +func (s *UsageMetricsService) Add(ctx context.Context, metrics *models.AllMetrics) (interface{}, *Response, error) { + u := fmt.Sprintf("%s/usage-metrics", s.client.URLPrefix) + + req, err := s.client.NewRequest(http.MethodPost, u, &metrics) + if err != nil { + return nil, nil, err + } + + var response interface{} + + resp, err := s.client.Do(ctx, req, &response) + if err != nil { + return nil, resp, err + } + + return &response, resp, nil +} 
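The functional-options constructor in retry_config.go can be exercised as below; the status codes, attempt counts and flags are illustrative assumptions, since the patch only defines the types:

// Hypothetical wiring of RetryConfig; assumes imports of "net/http" and this
// apiclient package. Signature: WithStatusCodeConfig(code, maxAttempts, backoff, invalidateToken).
func exampleRetryConfig() {
	rc := apiclient.NewRetryConfig(
		// on 401: up to 2 attempts, no backoff, drop the cached token
		apiclient.WithStatusCodeConfig(http.StatusUnauthorized, 2, false, true),
		// on 502: up to 3 attempts with backoff, keep the token
		apiclient.WithStatusCodeConfig(http.StatusBadGateway, 3, true, false),
	)
	_ = rc.StatusCodeConfig[http.StatusBadGateway].MaxAttempts // 3
}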
diff --git a/pkg/apiclient/useragent/useragent.go b/pkg/apiclient/useragent/useragent.go new file mode 100644 index 00000000000..5a62ce1ac06 --- /dev/null +++ b/pkg/apiclient/useragent/useragent.go @@ -0,0 +1,9 @@ +package useragent + +import ( + "github.com/crowdsecurity/go-cs-lib/version" +) + +func Default() string { + return "crowdsec/" + version.String() + "-" + version.System +} diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 5fd23d116ad..4cc215c344f 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "fmt" "net/http" @@ -9,34 +10,27 @@ import ( "sync" "testing" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/gin-gonic/gin" - - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" ) type LAPI struct { router *gin.Engine loginResp models.WatcherAuthResponse bouncerKey string - t *testing.T DBConfig *csconfig.DatabaseCfg } -func SetupLAPITest(t *testing.T) LAPI { +func SetupLAPITest(t *testing.T, ctx context.Context) LAPI { t.Helper() - router, loginResp, config, err := InitMachineTest(t) - if err != nil { - t.Fatal(err) - } + router, loginResp, config := InitMachineTest(t, ctx) - APIKey, err := CreateTestBouncer(config.API.Server.DbConfig) - if err != nil { - t.Fatal(err) - } + APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) return LAPI{ router: router, @@ -46,63 +40,51 @@ func SetupLAPITest(t *testing.T) LAPI { } } -func (l *LAPI) InsertAlertFromFile(path string) *httptest.ResponseRecorder { - alertReader := GetAlertReaderFromFile(path) - return l.RecordResponse(http.MethodPost, "/v1/alerts", alertReader, "password") +func (l *LAPI) InsertAlertFromFile(t *testing.T, ctx context.Context, path string) *httptest.ResponseRecorder { + alertReader := GetAlertReaderFromFile(t, path) + return l.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", alertReader, "password") } -func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { +func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { w := httptest.NewRecorder() - req, err := http.NewRequest(verb, url, body) - if err != nil { - l.t.Fatal(err) - } - if authType == "apikey" { + req, err := http.NewRequestWithContext(ctx, verb, url, body) + require.NoError(t, err) + + switch authType { + case "apikey": req.Header.Add("X-Api-Key", l.bouncerKey) - } else if authType == "password" { + case "password": AddAuthHeaders(req, l.loginResp) - } else { - l.t.Fatal("auth type not supported") + default: + t.Fatal("auth type not supported") } + l.router.ServeHTTP(w, req) + return w } -func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config, error) { - router, config, err := NewAPITest(t) - if err != nil { - return nil, models.WatcherAuthResponse{}, config, fmt.Errorf("unable to run local API: %s", err) - } +func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) { + router, config := NewAPITest(t, ctx) + loginResp := LoginToTestAPI(t, ctx, router, config) - loginResp, err := LoginToTestAPI(router, config) - if err != nil { 
- return nil, models.WatcherAuthResponse{}, config, err - } - return router, loginResp, config, nil + return router, loginResp, config } -func LoginToTestAPI(router *gin.Engine, config csconfig.Config) (models.WatcherAuthResponse, error) { - body, err := CreateTestMachine(router) - if err != nil { - return models.WatcherAuthResponse{}, err - } - err = ValidateMachine("test", config.API.Server.DbConfig) - if err != nil { - log.Fatalln(err) - } +func LoginToTestAPI(t *testing.T, ctx context.Context, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse { + body := CreateTestMachine(t, ctx, router, "") + ValidateMachine(t, ctx, "test", config.API.Server.DbConfig) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) loginResp := models.WatcherAuthResponse{} - err = json.NewDecoder(w.Body).Decode(&loginResp) - if err != nil { - return models.WatcherAuthResponse{}, err - } + err := json.NewDecoder(w.Body).Decode(&loginResp) + require.NoError(t, err) - return loginResp, nil + return loginResp } func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthResponse) { @@ -111,298 +93,303 @@ func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthRespon } func TestSimulatedAlert(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_minibulk+simul.json") - alertContent := GetAlertReaderFromFile("./tests/alert_minibulk+simul.json") - //exclude decision in simulation mode + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk+simul.json") + alertContent := GetAlertReaderFromFile(t, "./tests/alert_minibulk+simul.json") + // exclude decision in simulation mode - w := lapi.RecordResponse("GET", "/v1/alerts?simulated=false", alertContent, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=false", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) - //include decision in simulation mode + // include decision in simulation mode - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", alertContent, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=true", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) } func TestCreateAlert(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Alert with invalid format - w := lapi.RecordResponse(http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") + w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") assert.Equal(t, 400, w.Code) - assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String()) + assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create Alert 
with invalid input - alertContent := GetAlertReaderFromFile("./tests/invalidAlert_sample.json") + alertContent := GetAlertReaderFromFile(t, "./tests/invalidAlert_sample.json") - w = lapi.RecordResponse(http.MethodPost, "/v1/alerts", alertContent, "password") + w = lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", alertContent, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"validation failure list:\\n0.scenario in body is required\\n0.scenario_hash in body is required\\n0.scenario_version in body is required\\n0.simulated in body is required\\n0.source in body is required\"}", w.Body.String()) + assert.Equal(t, + `{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`, + w.Body.String()) // Create Valid Alert - w = lapi.InsertAlertFromFile("./tests/alert_sample.json") + w = lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assert.Equal(t, 201, w.Code) - assert.Equal(t, "[\"1\"]", w.Body.String()) + assert.Equal(t, `["1"]`, w.Body.String()) } func TestCreateAlertChannels(t *testing.T) { - - apiServer, config, err := NewAPIServer(t) - if err != nil { - log.Fatalln(err) - } + ctx := context.Background() + apiServer, config := NewAPIServer(t, ctx) apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert) apiServer.InitController() - loginResp, err := LoginToTestAPI(apiServer.router, config) - if err != nil { - log.Fatalln(err) - } + loginResp := LoginToTestAPI(t, ctx, apiServer.router, config) lapi := LAPI{router: apiServer.router, loginResp: loginResp} - var pd csplugin.ProfileAlert - var wg sync.WaitGroup + var ( + pd csplugin.ProfileAlert + wg sync.WaitGroup + ) wg.Add(1) + go func() { pd = <-apiServer.controller.PluginChannel + wg.Done() }() - go lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_ssh-bf.json") wg.Wait() - assert.Equal(t, len(pd.Alert.Decisions), 1) + assert.Len(t, pd.Alert.Decisions, 1) apiServer.Close() } func TestAlertListFilters(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json") - alertContent := GetAlertReaderFromFile("./tests/alert_ssh-bf.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_ssh-bf.json") + alertContent := GetAlertReaderFromFile(t, "./tests/alert_ssh-bf.json") - //bad filter + // bad filter - w := lapi.RecordResponse("GET", "/v1/alerts?test=test", alertContent, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", alertContent, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String()) + assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) - //get without filters + // get without filters - w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) - //check alert and decision + // check alert and decision assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test decision_type filter (ok) + // test decision_type filter (ok) - w = 
lapi.RecordResponse("GET", "/v1/alerts?decision_type=ban", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?decision_type=ban", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test decision_type filter (bad value) + // test decision_type filter (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test scope (ok) + // test scope (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?scope=Ip", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scope=Ip", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test scope (bad value) + // test scope (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?scope=rarara", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scope=rarara", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test scenario (ok) + // test scenario (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test scenario (bad value) + // test scenario (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test ip (ok) + // test ip (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test ip (bad value) + // test ip (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test ip (invalid value) + // test ip (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=gruueq", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String()) - //test range (ok) + // test range (ok) - w = lapi.RecordResponse("GET", 
"/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test range + // test range - w = lapi.RecordResponse("GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test range (invalid value) + // test range (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?range=ratata", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=ratata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String()) - //test since (ok) + // test since (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1h", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1h", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test since (ok but yields no results) + // test since (ok but yields no results) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1ns", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test since (invalid value) + // test since (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1zuzu", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) - //test until (ok) + // test until (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1ns", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test until (ok but no return) + // test until (ok but no return) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1m", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1m", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test until (invalid value) + // test until (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1zuzu", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) - //test simulated (ok) + // test simulated (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, 
"GET", "/v1/alerts?simulated=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test simulated (ok) + // test simulated (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test has active decision + // test has active decision - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test has active decision + // test has active decision - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test has active decision (invalid value) + // test has active decision (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) - } func TestAlertBulkInsert(t *testing.T) { - lapi := SetupLAPITest(t) - //insert a bulk of 20 alerts to trigger bulk insert - lapi.InsertAlertFromFile("./tests/alert_bulk.json") - alertContent := GetAlertReaderFromFile("./tests/alert_bulk.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + // insert a bulk of 20 alerts to trigger bulk insert + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_bulk.json") + alertContent := GetAlertReaderFromFile(t, "./tests/alert_bulk.json") - w := lapi.RecordResponse("GET", "/v1/alerts", alertContent, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", alertContent, "password") assert.Equal(t, 200, w.Code) } func TestListAlert(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_sample.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // List Alert with invalid filter - w := lapi.RecordResponse("GET", "/v1/alerts?test=test", emptyBody, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", emptyBody, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String()) + assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) // List Alert - w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password") + 
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "crowdsecurity/test") } func TestCreateAlertErrors(t *testing.T) { - lapi := SetupLAPITest(t) - alertContent := GetAlertReaderFromFile("./tests/alert_sample.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + alertContent := GetAlertReaderFromFile(t, "./tests/alert_sample.json") - //test invalid bearer + // test invalid bearer w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/alerts", alertContent) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/alerts", alertContent) req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "ratata")) lapi.router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - //test invalid bearer + // test invalid bearer w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/alerts", alertContent) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/alerts", alertContent) req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s")) lapi.router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - } func TestDeleteAlert(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_sample.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Fail Delete Alert w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) @@ -411,7 +398,7 @@ func TestDeleteAlert(t *testing.T) { // Delete Alert w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) @@ -420,12 +407,13 @@ func TestDeleteAlert(t *testing.T) { } func TestDeleteAlertByID(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_sample.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Fail Delete Alert w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) @@ -434,7 +422,7 @@ func TestDeleteAlertByID(t *testing.T) { // Delete Alert w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) @@ -443,36 +431,30 @@ func TestDeleteAlertByID(t *testing.T) { } func TestDeleteAlertTrustedIPS(t *testing.T) { + ctx := context.Background() cfg := LoadTestConfig(t) // IPv6 mocking doesn't seem to work. 
// cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24", "::"} cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"} cfg.API.Server.ListenURI = "::8080" - server, err := NewServer(cfg.API.Server) - if err != nil { - log.Fatal(err) - } + server, err := NewServer(ctx, cfg.API.Server) + require.NoError(t, err) + err = server.InitController() - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) + router, err := server.Router() - if err != nil { - log.Fatal(err) - } - loginResp, err := LoginToTestAPI(router, cfg) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) + + loginResp := LoginToTestAPI(t, ctx, router, cfg) lapi := LAPI{ router: router, loginResp: loginResp, - t: t, } assertAlertDeleteFailedFromIP := func(ip string) { w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, loginResp) req.RemoteAddr = ip + ":1234" @@ -484,7 +466,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assertAlertDeletedFromIP := func(ip string) { w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, loginResp) req.RemoteAddr = ip + ":1234" @@ -493,18 +475,17 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) } - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeleteFailedFromIP("4.3.2.1") assertAlertDeletedFromIP("1.2.3.4") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.0") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.1") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.255") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("127.0.0.1") - } diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index a77ab3f835f..e6ed68a6e0d 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -1,52 +1,47 @@ package apiserver import ( + "context" "net/http" "net/http/httptest" "strings" "testing" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) func TestAPIKey(t *testing.T) { - router, config, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } - - APIKey, err := CreateTestBouncer(config.API.Server.DbConfig) - if err != nil { - log.Fatal(err) - } + ctx := context.Background() + router, config := NewAPITest(t, ctx) + + APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) + // Login with empty token w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 403, w.Code) - assert.Equal(t, "{\"message\":\"access forbidden\"}", w.Body.String()) + 
assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with invalid token w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", "a1b2c3d4e5f6") router.ServeHTTP(w, req) assert.Equal(t, 403, w.Code) - assert.Equal(t, "{\"message\":\"access forbidden\"}", w.Body.String()) + assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with valid token w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", APIKey) router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 718657f3aa0..a2fb0e85749 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -2,25 +2,24 @@ package apiserver import ( "context" + "errors" "fmt" "math/rand" "net" "net/http" "net/url" + "slices" "strconv" "strings" "sync" "time" "github.com/go-openapi/strfmt" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "golang.org/x/exp/slices" "gopkg.in/tomb.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" - "github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -35,28 +34,30 @@ import ( const ( // delta values must be smaller than the interval - pullIntervalDefault = time.Hour * 2 - pullIntervalDelta = 5 * time.Minute - pushIntervalDefault = time.Second * 10 - pushIntervalDelta = time.Second * 7 - metricsIntervalDefault = time.Minute * 30 - metricsIntervalDelta = time.Minute * 15 + pullIntervalDefault = time.Hour * 2 + pullIntervalDelta = 5 * time.Minute + pushIntervalDefault = time.Second * 10 + pushIntervalDelta = time.Second * 7 + metricsIntervalDefault = time.Minute * 30 + metricsIntervalDelta = time.Minute * 15 + usageMetricsInterval = time.Minute * 30 + usageMetricsIntervalDelta = time.Minute * 15 ) -var SCOPE_CAPI_ALIAS_ALIAS string = "crowdsecurity/community-blocklist" //we don't use "CAPI" directly, to make it less confusing for the user - type apic struct { // when changing the intervals in tests, always set *First too // or they can be negative - pullInterval time.Duration - pullIntervalFirst time.Duration - pushInterval time.Duration - pushIntervalFirst time.Duration - metricsInterval time.Duration - metricsIntervalFirst time.Duration - dbClient *database.Client - apiClient *apiclient.ApiClient - AlertsAddChan chan []*models.Alert + pullInterval time.Duration + pullIntervalFirst time.Duration + pushInterval time.Duration + pushIntervalFirst time.Duration + metricsInterval time.Duration + metricsIntervalFirst time.Duration + usageMetricsInterval time.Duration + usageMetricsIntervalFirst time.Duration + dbClient *database.Client + apiClient *apiclient.ApiClient + AlertsAddChan chan []*models.Alert mu sync.Mutex pushTomb tomb.Tomb @@ -77,31 +78,37 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration { if ret <= 0 { return 1 } + return ret } -func (a *apic) 
FetchScenariosListFromDB() ([]string, error) { +func (a *apic) FetchScenariosListFromDB(ctx context.Context) ([]string, error) { scenarios := make([]string, 0) - machines, err := a.dbClient.ListMachines() + + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, fmt.Errorf("while listing machines: %w", err) } - //merge all scenarios together + // merge all scenarios together for _, v := range machines { machineScenarios := strings.Split(v.Scenarios, ",") log.Debugf("%d scenarios for machine %d", len(machineScenarios), v.ID) + for _, sv := range machineScenarios { if !slices.Contains(scenarios, sv) && sv != "" { scenarios = append(scenarios, sv) } } } + log.Debugf("Returning list of scenarios : %+v", scenarios) + return scenarios, nil } func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequestItemDecisions { apiDecisions := models.AddSignalsRequestItemDecisions{} + for _, decision := range decisions { x := &models.AddSignalsRequestItemDecisionsItem{ Duration: ptr.Of(*decision.Duration), @@ -109,18 +116,21 @@ func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequ Origin: ptr.Of(*decision.Origin), Scenario: ptr.Of(*decision.Scenario), Scope: ptr.Of(*decision.Scope), - //Simulated: *decision.Simulated, + // Simulated: *decision.Simulated, Type: ptr.Of(*decision.Type), Until: decision.Until, Value: ptr.Of(*decision.Value), UUID: decision.UUID, } *x.ID = decision.ID + if decision.Simulated != nil { x.Simulated = *decision.Simulated } + apiDecisions = append(apiDecisions, x) } + return apiDecisions } @@ -151,6 +161,7 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool) } if shareContext { signal.Context = make([]*models.AddSignalsRequestItemContextItems0, 0) + for _, meta := range alert.Meta { contextItem := models.AddSignalsRequestItemContextItems0{ Key: meta.Key, @@ -159,51 +170,56 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool) signal.Context = append(signal.Context, &contextItem) } } + return signal } -func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { +func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { var err error - ret := &apic{ - AlertsAddChan: make(chan []*models.Alert), - dbClient: dbClient, - mu: sync.Mutex{}, - startup: true, - credentials: config.Credentials, - pullTomb: tomb.Tomb{}, - pushTomb: tomb.Tomb{}, - metricsTomb: tomb.Tomb{}, - scenarioList: make([]string, 0), - consoleConfig: consoleConfig, - pullInterval: pullIntervalDefault, - pullIntervalFirst: randomDuration(pullIntervalDefault, pullIntervalDelta), - pushInterval: pushIntervalDefault, - pushIntervalFirst: randomDuration(pushIntervalDefault, pushIntervalDelta), - metricsInterval: metricsIntervalDefault, - metricsIntervalFirst: randomDuration(metricsIntervalDefault, metricsIntervalDelta), - isPulling: make(chan bool, 1), - whitelists: apicWhitelist, + ret := &apic{ + AlertsAddChan: make(chan []*models.Alert), + dbClient: dbClient, + mu: sync.Mutex{}, + startup: true, + credentials: config.Credentials, + pullTomb: tomb.Tomb{}, + pushTomb: tomb.Tomb{}, + metricsTomb: tomb.Tomb{}, + scenarioList: make([]string, 0), + consoleConfig: consoleConfig, + pullInterval: pullIntervalDefault, + pullIntervalFirst: 
randomDuration(pullIntervalDefault, pullIntervalDelta), + pushInterval: pushIntervalDefault, + pushIntervalFirst: randomDuration(pushIntervalDefault, pushIntervalDelta), + metricsInterval: metricsIntervalDefault, + metricsIntervalFirst: randomDuration(metricsIntervalDefault, metricsIntervalDelta), + usageMetricsInterval: usageMetricsInterval, + usageMetricsIntervalFirst: randomDuration(usageMetricsInterval, usageMetricsIntervalDelta), + isPulling: make(chan bool, 1), + whitelists: apicWhitelist, } password := strfmt.Password(config.Credentials.Password) + apiURL, err := url.Parse(config.Credentials.URL) if err != nil { return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.URL, err) } + papiURL, err := url.Parse(config.Credentials.PapiURL) if err != nil { return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err) } - ret.scenarioList, err = ret.FetchScenariosListFromDB() + ret.scenarioList, err = ret.FetchScenariosListFromDB(ctx) if err != nil { return nil, fmt.Errorf("while fetching scenarios from db: %w", err) } + ret.apiClient, err = apiclient.NewClient(&apiclient.Config{ MachineID: config.Credentials.Login, Password: password, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, PapiURL: papiURL, VersionPrefix: "v3", @@ -215,13 +231,13 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con } // The watcher will be authenticated by the RoundTripper the first time it will call CAPI - // Explicit authentication will provoke an useless supplementary call to CAPI - scenarios, err := ret.FetchScenariosListFromDB() + // Explicit authentication will provoke a useless supplementary call to CAPI + scenarios, err := ret.FetchScenariosListFromDB(ctx) if err != nil { return ret, fmt.Errorf("get scenario in db: %w", err) } - authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{ MachineID: &config.Credentials.Login, Password: &password, Scenarios: scenarios, @@ -230,7 +246,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con return ret, fmt.Errorf("authenticate watcher (%s): %w", config.Credentials.Login, err) } - if err := ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { + if err = ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { return ret, fmt.Errorf("unable to parse jwt expiration: %w", err) } @@ -240,10 +256,11 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con } // keep track of all alerts in cache and push it to CAPI every PushInterval. 
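The loop below follows an accumulate-and-flush shape; a condensed sketch, with alertsIn and flush as simplified stand-ins for a.AlertsAddChan (plus the alertToSignal conversion) and a.Send, and without the mutex and tomb handling of the real code:

// Condensed sketch of the Push loop; pushFirst/pushEvery stand in for the
// jittered a.pushIntervalFirst and the regular a.pushInterval.
func pushLoop(ctx context.Context, pushFirst, pushEvery time.Duration,
	alertsIn <-chan []*models.AddSignalsRequestItem,
	flush func(context.Context, []*models.AddSignalsRequestItem)) {
	var cache []*models.AddSignalsRequestItem

	ticker := time.NewTicker(pushFirst) // jittered first tick

	for {
		select {
		case <-ticker.C:
			ticker.Reset(pushEvery) // back to the regular cadence after the first tick
			if len(cache) > 0 {
				toSend := cache
				cache = nil
				go flush(ctx, toSend) // the real Send pages this out in bulks of 50 signals
			}
		case items := <-alertsIn:
			cache = append(cache, items...)
		}
	}
}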
-func (a *apic) Push() error { +func (a *apic) Push(ctx context.Context) error { defer trace.CatchPanic("lapi/pushToAPIC") var cache models.AddSignalsRequest + ticker := time.NewTicker(a.pushIntervalFirst) log.Infof("Start push to CrowdSec Central API (interval: %s once, then %s)", a.pushIntervalFirst.Round(time.Second), a.pushInterval) @@ -254,28 +271,35 @@ func (a *apic) Push() error { a.pullTomb.Kill(nil) a.metricsTomb.Kill(nil) log.Infof("push tomb is dying, sending cache (%d elements) before exiting", len(cache)) + if len(cache) == 0 { return nil } - go a.Send(&cache) + + go a.Send(ctx, &cache) + return nil case <-ticker.C: ticker.Reset(a.pushInterval) + if len(cache) > 0 { a.mu.Lock() cacheCopy := cache cache = make(models.AddSignalsRequest, 0) a.mu.Unlock() log.Infof("Signal push: %d signals to push", len(cacheCopy)) - go a.Send(&cacheCopy) + + go a.Send(ctx, &cacheCopy) } case alerts := <-a.AlertsAddChan: var signals []*models.AddSignalsRequestItem + for _, alert := range alerts { if ok := shouldShareAlert(alert, a.consoleConfig); ok { signals = append(signals, alertToSignal(alert, getScenarioTrustOfAlert(alert), *a.consoleConfig.ShareContext)) } } + a.mu.Lock() cache = append(cache, signals...) a.mu.Unlock() @@ -290,11 +314,13 @@ func getScenarioTrustOfAlert(alert *models.Alert) string { } else if alert.ScenarioVersion == nil || *alert.ScenarioVersion == "" || *alert.ScenarioVersion == "?" { scenarioTrust = "tainted" } + if len(alert.Decisions) > 0 { if *alert.Decisions[0].Origin == types.CscliOrigin { scenarioTrust = "manual" } } + return scenarioTrust } @@ -303,6 +329,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig log.Debugf("simulation enabled for alert (id:%d), will not be sent to CAPI", alert.ID) return false } + switch scenarioTrust := getScenarioTrustOfAlert(alert); scenarioTrust { case "manual": if !*consoleConfig.ShareManualDecisions { @@ -320,10 +347,11 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig return false } } + return true } -func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { +func (a *apic) Send(ctx context.Context, cacheOrig *models.AddSignalsRequest) { /*we do have a problem with this : The apic.Push background routine reads from alertToPush chan. This chan is filled by Controller.CreateAlert @@ -335,124 +363,142 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { I don't know enough about gin to tell how much of an issue it can be. 
*/ - var cache []*models.AddSignalsRequestItem = *cacheOrig - var send models.AddSignalsRequest + var ( + cache []*models.AddSignalsRequestItem = *cacheOrig + send models.AddSignalsRequest + ) bulkSize := 50 pageStart := 0 pageEnd := bulkSize for { - if pageEnd >= len(cache) { send = cache[pageStart:] - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + _, _, err := a.apiClient.Signal.Add(ctx, &send) if err != nil { log.Errorf("sending signal to central API: %s", err) return } + break } + send = cache[pageStart:pageEnd] - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + _, _, err := a.apiClient.Signal.Add(ctx, &send) if err != nil { - //we log it here as well, because the return value of func might be discarded + // we log it here as well, because the return value of func might be discarded log.Errorf("sending signal to central API: %s", err) } + pageStart += bulkSize pageEnd += bulkSize } } -func (a *apic) CAPIPullIsOld() (bool, error) { +func (a *apic) CAPIPullIsOld(ctx context.Context) (bool, error) { /*only pull community blocklist if it's older than 1h30 */ alerts := a.dbClient.Ent.Alert.Query() alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID))) alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert - count, err := alerts.Count(a.dbClient.CTX) + + count, err := alerts.Count(ctx) if err != nil { return false, fmt.Errorf("while looking for CAPI alert: %w", err) } + if count > 0 { log.Printf("last CAPI pull is newer than 1h30, skip.") return false, nil } + return true, nil } -func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delete_counters map[string]map[string]int) (int, error) { - var filter map[string][]string - var nbDeleted int +func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) { + ctx := context.TODO() + nbDeleted := 0 + for _, decision := range deletedDecisions { - if strings.ToLower(*decision.Scope) == "ip" { - filter = make(map[string][]string, 1) - filter["value"] = []string{*decision.Value} - } else { - filter = make(map[string][]string, 3) - filter["value"] = []string{*decision.Value} + filter := map[string][]string{ + "value": {*decision.Value}, + "origin": {*decision.Origin}, + } + if strings.ToLower(*decision.Scope) != "ip" { filter["type"] = []string{*decision.Type} filter["scopes"] = []string{*decision.Scope} } - filter["origin"] = []string{*decision.Origin} - dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter) + dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { - return 0, fmt.Errorf("deleting decisions error: %w", err) + return 0, fmt.Errorf("expiring decisions error: %w", err) } + dbCliDel, err := strconv.Atoi(dbCliRet) if err != nil { return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err) } - updateCounterForDecision(delete_counters, decision.Origin, decision.Scenario, dbCliDel) + + updateCounterForDecision(deleteCounters, decision.Origin, decision.Scenario, dbCliDel) nbDeleted += dbCliDel } + return nbDeleted, nil } -func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, delete_counters map[string]map[string]int) (int, error) { - var filter 
map[string][]string +func (a *apic) HandleDeletedDecisionsV3(ctx context.Context, deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) { var nbDeleted int + for _, decisions := range deletedDecisions { scope := decisions.Scope + for _, decision := range decisions.Decisions { - if strings.ToLower(*scope) == "ip" { - filter = make(map[string][]string, 1) - filter["value"] = []string{decision} - } else { - filter = make(map[string][]string, 2) - filter["value"] = []string{decision} + filter := map[string][]string{ + "value": {decision}, + "origin": {types.CAPIOrigin}, + } + if strings.ToLower(*scope) != "ip" { filter["scopes"] = []string{*scope} } - filter["origin"] = []string{types.CAPIOrigin} - dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter) + dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { - return 0, fmt.Errorf("deleting decisions error: %w", err) + return 0, fmt.Errorf("expiring decisions error: %w", err) } + dbCliDel, err := strconv.Atoi(dbCliRet) if err != nil { return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err) } - updateCounterForDecision(delete_counters, ptr.Of(types.CAPIOrigin), nil, dbCliDel) + + updateCounterForDecision(deleteCounters, ptr.Of(types.CAPIOrigin), nil, dbCliDel) nbDeleted += dbCliDel } } + return nbDeleted, nil } func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert { newAlerts := make([]*models.Alert, 0) + for _, decision := range decisions { found := false + for _, sub := range newAlerts { if sub.Source.Scope == nil { log.Warningf("nil scope in %+v", sub) continue } + if *decision.Origin == types.CAPIOrigin { if *sub.Source.Scope == types.CAPIOrigin { found = true @@ -463,6 +509,7 @@ func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert { if sub.Scenario == nil { log.Warningf("nil scenario in %+v", sub) } + if *sub.Scenario == *decision.Scenario { found = true break @@ -472,46 +519,60 @@ func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert { log.Warningf("unknown origin %s : %+v", *decision.Origin, decision) } } + if !found { log.Debugf("Create entry for origin:%s scenario:%s", *decision.Origin, *decision.Scenario) newAlerts = append(newAlerts, createAlertForDecision(decision)) } } + return newAlerts } func createAlertForDecision(decision *models.Decision) *models.Alert { - newAlert := &models.Alert{} - newAlert.Source = &models.Source{} - newAlert.Source.Scope = ptr.Of("") - if *decision.Origin == types.CAPIOrigin { //to make things more user friendly, we replace CAPI with community-blocklist - newAlert.Scenario = ptr.Of(types.CAPIOrigin) - newAlert.Source.Scope = ptr.Of(types.CAPIOrigin) - } else if *decision.Origin == types.ListOrigin { - newAlert.Scenario = ptr.Of(*decision.Scenario) - newAlert.Source.Scope = ptr.Of(types.ListOrigin) - } else { + var ( + scenario string + scope string + ) + + switch *decision.Origin { + case types.CAPIOrigin: + scenario = types.CAPIOrigin + scope = types.CAPIOrigin + case types.ListOrigin: + scenario = *decision.Scenario + scope = types.ListOrigin + default: + scenario = "" + scope = "" + log.Warningf("unknown origin %s", *decision.Origin) } - newAlert.Message = ptr.Of("") - newAlert.Source.Value = ptr.Of("") - newAlert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) - newAlert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) - newAlert.Capacity = ptr.Of(int32(0)) - newAlert.Simulated = ptr.Of(false) - 
newAlert.EventsCount = ptr.Of(int32(0))
- newAlert.Leakspeed = ptr.Of("")
- newAlert.ScenarioHash = ptr.Of("")
- newAlert.ScenarioVersion = ptr.Of("")
- newAlert.MachineID = database.CapiMachineID
- return newAlert
+
+ return &models.Alert{
+ Source: &models.Source{
+ Scope: ptr.Of(scope),
+ Value: ptr.Of(""),
+ },
+ Scenario: ptr.Of(scenario),
+ Message: ptr.Of(""),
+ StartAt: ptr.Of(time.Now().UTC().Format(time.RFC3339)),
+ StopAt: ptr.Of(time.Now().UTC().Format(time.RFC3339)),
+ Capacity: ptr.Of(int32(0)),
+ Simulated: ptr.Of(false),
+ EventsCount: ptr.Of(int32(0)),
+ Leakspeed: ptr.Of(""),
+ ScenarioHash: ptr.Of(""),
+ ScenarioVersion: ptr.Of(""),
+ MachineID: database.CapiMachineID,
+ }
 }

 // This function takes in a list of parent alerts and decisions and pairs them up.
-func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decision, add_counters map[string]map[string]int) []*models.Alert {
+func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decision, addCounters map[string]map[string]int) []*models.Alert {
 for _, decision := range decisions {
- //count and create separate alerts for each list
- updateCounterForDecision(add_counters, decision.Origin, decision.Scenario, 1)
+ // count and create separate alerts for each list
+ updateCounterForDecision(addCounters, decision.Origin, decision.Scenario, 1)

 /*CAPI might send lower case scopes, unify it.*/
 switch strings.ToLower(*decision.Scope) {
@@ -520,40 +581,45 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
 case "range":
 *decision.Scope = types.Range
 }
+
 found := false
- //add the individual decisions to the right list
+ // add the individual decisions to the right list
 for idx, alert := range alerts {
 if *decision.Origin == types.CAPIOrigin {
 if *alert.Source.Scope == types.CAPIOrigin {
 alerts[idx].Decisions = append(alerts[idx].Decisions, decision)
 found = true
+ break
 }
 } else if *decision.Origin == types.ListOrigin {
 if *alert.Source.Scope == types.ListOrigin && *alert.Scenario == *decision.Scenario {
 alerts[idx].Decisions = append(alerts[idx].Decisions, decision)
 found = true
+ break
 }
 } else {
 log.Warningf("unknown origin %s", *decision.Origin)
 }
 }
+
 if !found {
 log.Warningf("Orphaned decision for %s - %s", *decision.Origin, *decision.Scenario)
 }
 }
+
 return alerts
 }

 // we receive a list of decisions and blocklist links, and we need to create a list of alerts :
 // one alert for "community blocklist"
 // one alert per list we're subscribed to
-func (a *apic) PullTop(forcePull bool) error {
+func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
 var err error

- //A mutex with TryLock would be a bit simpler
- //But go does not guarantee that TryLock will be able to acquire the lock even if it is available
+ // A mutex with TryLock would be a bit simpler
+ // But Go does not guarantee that TryLock will acquire the lock even if it is available
 select {
 case a.isPulling <- true:
 defer func() {
@@ -564,36 +630,57 @@ func (a *apic) PullTop(forcePull bool) error {
 }

 if !forcePull {
- if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil {
+ if lastPullIsOld, err := a.CAPIPullIsOld(ctx); err != nil {
 return err
 } else if !lastPullIsOld {
 return nil
 }
 }
+
+ log.Debug("Acquiring lock for pullCAPI")
+
+ err = a.dbClient.AcquirePullCAPILock(ctx)
+ if a.dbClient.IsLocked(err) {
+ log.Info("PullCAPI is already running, skipping")
+ return nil
+ }
+
+ /*defer lock release*/
+ defer func() {
+ log.Debug("Releasing lock for pullCAPI")
+
+ if err := a.dbClient.ReleasePullCAPILock(ctx); err != nil {
+ log.Errorf("while releasing lock: %v", err)
+ }
+ }()
+
 log.Infof("Starting community-blocklist update")

- data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup})
+ data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup})
 if err != nil {
 return fmt.Errorf("get stream: %w", err)
 }
+
 a.startup = false
 /*to count additions/deletions across lists*/

 log.Debugf("Received %d new decisions", len(data.New))
 log.Debugf("Received %d deleted decisions", len(data.Deleted))
+
 if data.Links != nil {
 log.Debugf("Received %d blocklists links", len(data.Links.Blocklists))
 }

- add_counters, delete_counters := makeAddAndDeleteCounters()
+ addCounters, deleteCounters := makeAddAndDeleteCounters()
+
 // process deleted decisions
- if nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, delete_counters); err != nil {
+ nbDeleted, err := a.HandleDeletedDecisionsV3(ctx, data.Deleted, deleteCounters)
+ if err != nil {
 return err
- } else {
- log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted)
 }
+ log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted)
+
 if len(data.New) == 0 {
 log.Infof("capi/community-blocklist : received 0 new entries (expected if you just installed crowdsec)")
 return nil
@@ -601,117 +688,216 @@ func (a *apic) PullTop(forcePull bool) error {
 // create one alert for community blocklist using the first decision
 decisions := a.apiClient.Decisions.GetDecisionsFromGroups(data.New)
- //apply APIC specific whitelists
+ // apply APIC specific whitelists
 decisions = a.ApplyApicWhitelists(decisions)
 alert := createAlertForDecision(decisions[0])
 alertsFromCapi := []*models.Alert{alert}
- alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, add_counters)
+ alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters)

- err = a.SaveAlerts(alertsFromCapi, add_counters, delete_counters)
+ err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, deleteCounters)
 if err != nil {
 return fmt.Errorf("while saving alerts: %w", err)
 }

 // update blocklists
- if err := a.UpdateBlocklists(data.Links, add_counters); err != nil {
+ if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil {
 return fmt.Errorf("while updating blocklists: %w", err)
 }
+
+ return nil
+}
+
+// we receive a link to a blocklist: we pull its content and create one alert from it
+func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error {
+ addCounters, _ := makeAddAndDeleteCounters()
+ if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{
+ Blocklists: []*modelscapi.BlocklistLink{blocklist},
+ }, addCounters, forcePull); err != nil {
+ return fmt.Errorf("while pulling blocklist: %w", err)
+ }
+
+ return nil
+}
+
+// if the decision is whitelisted: return a representation of the whitelist IP or CIDR
+// if not whitelisted: return an empty string
+func (a *apic) whitelistedBy(decision *models.Decision) string {
+ if decision.Value == nil {
+ return ""
+ }
+
+ ipval := net.ParseIP(*decision.Value)
+ for _, cidr := range a.whitelists.Cidrs {
+ if cidr.Contains(ipval) {
+ return cidr.String()
+ }
+ }
+
+ for _, ip := range a.whitelists.Ips {
+ if ip != nil && ip.Equal(ipval) {
+ return ip.String()
+ }
+ }
+
+ return ""
+}
+
 func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decision {
- if
a.whitelists == nil { + if a.whitelists == nil || len(a.whitelists.Cidrs) == 0 && len(a.whitelists.Ips) == 0 { return decisions } - //deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place + // deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place outIdx := 0 + for _, decision := range decisions { - if decision.Value == nil { + whitelister := a.whitelistedBy(decision) + if whitelister != "" { + log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, whitelister) continue } - skip := false - ipval := net.ParseIP(*decision.Value) - for _, cidr := range a.whitelists.Cidrs { - if skip { - break - } - if cidr.Contains(ipval) { - log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, cidr.String()) - skip = true - } - } - for _, ip := range a.whitelists.Ips { - if skip { - break - } - if ip != nil && ip.Equal(ipval) { - log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, ip.String()) - skip = true - } - } - if !skip { - decisions[outIdx] = decision - outIdx++ - } + decisions[outIdx] = decision + outIdx++ } - //shrink the list, those are deleted items - decisions = decisions[:outIdx] - return decisions + // shrink the list, those are deleted items + return decisions[:outIdx] } -func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, add_counters map[string]map[string]int, delete_counters map[string]map[string]int) error { - for idx, alert := range alertsFromCapi { - alertsFromCapi[idx] = setAlertScenario(add_counters, delete_counters, alert) - log.Debugf("%s has %d decisions", *alertsFromCapi[idx].Source.Scope, len(alertsFromCapi[idx].Decisions)) +func (a *apic) SaveAlerts(ctx context.Context, alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error { + for _, alert := range alertsFromCapi { + setAlertScenario(alert, addCounters, deleteCounters) + log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions)) + if a.dbClient.Type == "sqlite" && (a.dbClient.WalMode == nil || !*a.dbClient.WalMode) { log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist") } - alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alertsFromCapi[idx]) + + alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(ctx, alert) if err != nil { - return fmt.Errorf("while saving alert from %s: %w", *alertsFromCapi[idx].Source.Scope, err) + return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err) } - log.Printf("%s : added %d entries, deleted %d entries (alert:%d)", *alertsFromCapi[idx].Source.Scope, inserted, deleted, alertID) + + log.Printf("%s : added %d entries, deleted %d entries (alert:%d)", *alert.Source.Scope, inserted, deleted, alertID) } return nil } -func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bool, error) { +func (a *apic) ShouldForcePullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink) (bool, error) { // we should force pull if the blocklist decisions are about to expire or there's no decision in the db alertQuery := a.dbClient.Ent.Alert.Query() alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name))) alertQuery.Order(ent.Desc(alert.FieldCreatedAt)) - alertInstance, err := alertQuery.First(context.Background()) + + alertInstance, err := alertQuery.First(ctx) if err != nil { 
if ent.IsNotFound(err) { log.Debugf("no alert found for %s, force refresh", *blocklist.Name) return true, nil } + return false, fmt.Errorf("while getting alert: %w", err) } + decisionQuery := a.dbClient.Ent.Decision.Query() decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID))) - firstDecision, err := decisionQuery.First(context.Background()) + + firstDecision, err := decisionQuery.First(ctx) if err != nil { if ent.IsNotFound(err) { log.Debugf("no decision found for %s, force refresh", *blocklist.Name) return true, nil } + return false, fmt.Errorf("while getting decision: %w", err) } + if firstDecision == nil || firstDecision.Until == nil || firstDecision.Until.Sub(time.Now().UTC()) < (a.pullInterval+15*time.Minute) { log.Debugf("at least one decision found for %s, expire soon, force refresh", *blocklist.Name) return true, nil } + return false, nil } -func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int) error { +func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { + if blocklist.Scope == nil { + log.Warningf("blocklist has no scope") + return nil + } + + if blocklist.Duration == nil { + log.Warningf("blocklist has no duration") + return nil + } + + if !forcePull { + _forcePull, err := a.ShouldForcePullBlocklist(ctx, blocklist) + if err != nil { + return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err) + } + + forcePull = _forcePull + } + + blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name) + + var ( + lastPullTimestamp *string + err error + ) + + if !forcePull { + lastPullTimestamp, err = a.dbClient.GetConfigItem(ctx, blocklistConfigItemName) + if err != nil { + return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) + } + } + + decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(ctx, blocklist, lastPullTimestamp) + if err != nil { + return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err) + } + + if !hasChanged { + if lastPullTimestamp == nil { + log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name) + } else { + log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp) + } + + return nil + } + + err = a.dbClient.SetConfigItem(ctx, blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) + if err != nil { + return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) + } + + if len(decisions) == 0 { + log.Infof("blocklist %s has no decisions", *blocklist.Name) + return nil + } + // apply APIC specific whitelists + decisions = a.ApplyApicWhitelists(decisions) + alert := createAlertForDecision(decisions[0]) + alertsFromCapi := []*models.Alert{alert} + alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters) + + err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, nil) + if err != nil { + return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err) + } + + return nil +} + +func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { if links == nil { return nil } + if links.Blocklists == nil { return nil } @@ -721,91 
+907,56 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink if err != nil { return fmt.Errorf("while creating default client: %w", err) } - for _, blocklist := range links.Blocklists { - if blocklist.Scope == nil { - log.Warningf("blocklist has no scope") - continue - } - if blocklist.Duration == nil { - log.Warningf("blocklist has no duration") - continue - } - forcePull, err := a.ShouldForcePullBlocklist(blocklist) - if err != nil { - return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err) - } - blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name) - var lastPullTimestamp *string - if !forcePull { - lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) - if err != nil { - return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) - } - } - decisions, has_changed, err := defaultClient.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp) - if err != nil { - return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err) - } - if !has_changed { - if lastPullTimestamp == nil { - log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name) - } else { - log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp) - } - continue - } - err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) - if err != nil { - return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) - } - if len(decisions) == 0 { - log.Infof("blocklist %s has no decisions", *blocklist.Name) - continue - } - //apply APIC specific whitelists - decisions = a.ApplyApicWhitelists(decisions) - alert := createAlertForDecision(decisions[0]) - alertsFromCapi := []*models.Alert{alert} - alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, add_counters) - err = a.SaveAlerts(alertsFromCapi, add_counters, nil) - if err != nil { - return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err) + for _, blocklist := range links.Blocklists { + if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil { + return err } } + return nil } -func setAlertScenario(add_counters map[string]map[string]int, delete_counters map[string]map[string]int, alert *models.Alert) *models.Alert { - if *alert.Source.Scope == types.CAPIOrigin { - *alert.Source.Scope = SCOPE_CAPI_ALIAS_ALIAS - alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", add_counters[types.CAPIOrigin]["all"], delete_counters[types.CAPIOrigin]["all"])) - } else if *alert.Source.Scope == types.ListOrigin { +func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) { + switch *alert.Source.Scope { + case types.CAPIOrigin: + *alert.Source.Scope = types.CommunityBlocklistPullSourceScope + alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", + addCounters[types.CAPIOrigin]["all"], + deleteCounters[types.CAPIOrigin]["all"])) + case types.ListOrigin: *alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario) - alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", add_counters[types.ListOrigin][*alert.Scenario], delete_counters[types.ListOrigin][*alert.Scenario])) + alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", + 
addCounters[types.ListOrigin][*alert.Scenario], + deleteCounters[types.ListOrigin][*alert.Scenario])) } - return alert } -func (a *apic) Pull() error { +func (a *apic) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/pullFromAPIC") toldOnce := false + for { - scenario, err := a.FetchScenariosListFromDB() + scenario, err := a.FetchScenariosListFromDB(ctx) if err != nil { log.Errorf("unable to fetch scenarios from db: %s", err) } + if len(scenario) > 0 { break } + if !toldOnce { log.Warning("scenario list is empty, will not pull yet") + toldOnce = true } + time.Sleep(1 * time.Second) } - if err := a.PullTop(false); err != nil { + + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) } @@ -816,13 +967,15 @@ func (a *apic) Pull() error { select { case <-ticker.C: ticker.Reset(a.pullInterval) - if err := a.PullTop(false); err != nil { + + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) continue } case <-a.pullTomb.Dying(): // if one apic routine is dying, do we kill the others? a.metricsTomb.Kill(nil) a.pushTomb.Kill(nil) + return nil } } @@ -835,23 +988,24 @@ func (a *apic) Shutdown() { } func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[string]int) { - add_counters := make(map[string]map[string]int) - add_counters[types.CAPIOrigin] = make(map[string]int) - add_counters[types.ListOrigin] = make(map[string]int) + addCounters := make(map[string]map[string]int) + addCounters[types.CAPIOrigin] = make(map[string]int) + addCounters[types.ListOrigin] = make(map[string]int) - delete_counters := make(map[string]map[string]int) - delete_counters[types.CAPIOrigin] = make(map[string]int) - delete_counters[types.ListOrigin] = make(map[string]int) + deleteCounters := make(map[string]map[string]int) + deleteCounters[types.CAPIOrigin] = make(map[string]int) + deleteCounters[types.ListOrigin] = make(map[string]int) - return add_counters, delete_counters + return addCounters, deleteCounters } func updateCounterForDecision(counter map[string]map[string]int, origin *string, scenario *string, totalDecisions int) { - if *origin == types.CAPIOrigin { + switch *origin { + case types.CAPIOrigin: counter[*origin]["all"] += totalDecisions - } else if *origin == types.ListOrigin { + case types.ListOrigin: counter[*origin][*scenario] += totalDecisions - } else { + default: log.Warningf("Unknown origin %s", *origin) } } diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 7305dfcd6bc..3d9e7b28a79 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -2,20 +2,191 @@ package apiserver import ( "context" + "encoding/json" + "net/http" + "slices" + "strings" "time" log "github.com/sirupsen/logrus" - "golang.org/x/exp/slices" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" - "github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/go-cs-lib/trace" + "github.com/crowdsecurity/go-cs-lib/version" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/fflag" "github.com/crowdsecurity/crowdsec/pkg/models" ) -func (a *apic) GetMetrics() (*models.Metrics, error) { - machines, err := a.dbClient.ListMachines() +type dbPayload struct { + Metrics []*models.DetailedMetrics `json:"metrics"` +} + +func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, error) { + allMetrics := &models.AllMetrics{} + metricsIds := 
make([]int, 0)
+
+ lps, err := a.dbClient.ListMachines(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ bouncers, err := a.dbClient.ListBouncers(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ for _, bouncer := range bouncers {
+ dbMetrics, err := a.dbClient.GetBouncerUsageMetricsByName(ctx, bouncer.Name)
+ if err != nil {
+ log.Errorf("unable to get bouncer usage metrics: %s", err)
+ continue
+ }
+
+ rcMetrics := models.RemediationComponentsMetrics{}
+
+ rcMetrics.Os = &models.OSversion{
+ Name: ptr.Of(bouncer.Osname),
+ Version: ptr.Of(bouncer.Osversion),
+ }
+ rcMetrics.Type = bouncer.Type
+ rcMetrics.FeatureFlags = strings.Split(bouncer.Featureflags, ",")
+ rcMetrics.Version = ptr.Of(bouncer.Version)
+ rcMetrics.Name = bouncer.Name
+
+ rcMetrics.LastPull = 0
+ if bouncer.LastPull != nil {
+ rcMetrics.LastPull = bouncer.LastPull.UTC().Unix()
+ }
+
+ rcMetrics.Metrics = make([]*models.DetailedMetrics, 0)
+
+ // Might seem weird, but we duplicate the bouncers if we have multiple unsent metrics
+ for _, dbMetric := range dbMetrics {
+ dbPayload := &dbPayload{}
+ // Append no matter what; if we cannot unmarshal it, there's no way we'll be able to fix it automatically
+ metricsIds = append(metricsIds, dbMetric.ID)
+
+ err := json.Unmarshal([]byte(dbMetric.Payload), dbPayload)
+ if err != nil {
+ log.Errorf("unable to parse bouncer metric (%s)", err)
+ continue
+ }
+
+ rcMetrics.Metrics = append(rcMetrics.Metrics, dbPayload.Metrics...)
+ }
+
+ allMetrics.RemediationComponents = append(allMetrics.RemediationComponents, &rcMetrics)
+ }
+
+ for _, lp := range lps {
+ dbMetrics, err := a.dbClient.GetLPUsageMetricsByMachineID(ctx, lp.MachineId)
+ if err != nil {
+ log.Errorf("unable to get LP usage metrics: %s", err)
+ continue
+ }
+
+ lpMetrics := models.LogProcessorsMetrics{}
+
+ lpMetrics.Os = &models.OSversion{
+ Name: ptr.Of(lp.Osname),
+ Version: ptr.Of(lp.Osversion),
+ }
+ lpMetrics.FeatureFlags = strings.Split(lp.Featureflags, ",")
+ lpMetrics.Version = ptr.Of(lp.Version)
+ lpMetrics.Name = lp.MachineId
+
+ lpMetrics.LastPush = 0
+ if lp.LastPush != nil {
+ lpMetrics.LastPush = lp.LastPush.UTC().Unix()
+ }
+
+ lpMetrics.LastUpdate = lp.UpdatedAt.UTC().Unix()
+ lpMetrics.Datasources = lp.Datasources
+
+ hubItems := models.HubItems{}
+
+ if lp.Hubstate != nil {
+ // must carry over the hub state even if nothing is installed
+ for itemType, items := range lp.Hubstate {
+ hubItems[itemType] = []models.HubItem{}
+ for _, item := range items {
+ hubItems[itemType] = append(hubItems[itemType], models.HubItem{
+ Name: item.Name,
+ Status: item.Status,
+ Version: item.Version,
+ })
+ }
+ }
+ }
+
+ lpMetrics.HubItems = hubItems
+
+ lpMetrics.Metrics = make([]*models.DetailedMetrics, 0)
+
+ for _, dbMetric := range dbMetrics {
+ dbPayload := &dbPayload{}
+ // Append no matter what; if we cannot unmarshal it, there's no way we'll be able to fix it automatically
+ metricsIds = append(metricsIds, dbMetric.ID)
+
+ err := json.Unmarshal([]byte(dbMetric.Payload), dbPayload)
+ if err != nil {
+ log.Errorf("unable to parse log processor metric (%s)", err)
+ continue
+ }
+
+ lpMetrics.Metrics = append(lpMetrics.Metrics, dbPayload.Metrics...)
+ } + + allMetrics.LogProcessors = append(allMetrics.LogProcessors, &lpMetrics) + } + + // FIXME: all of this should only be done once on startup/reload + consoleOptions := strings.Join(csconfig.GetConfig().API.Server.ConsoleConfig.EnabledOptions(), ",") + allMetrics.Lapi = &models.LapiMetrics{ + ConsoleOptions: models.ConsoleOptions{ + consoleOptions, + }, + } + + osName, osVersion := version.DetectOS() + + allMetrics.Lapi.Os = &models.OSversion{ + Name: ptr.Of(osName), + Version: ptr.Of(osVersion), + } + allMetrics.Lapi.Version = ptr.Of(version.String()) + allMetrics.Lapi.FeatureFlags = fflag.Crowdsec.GetEnabledFeatures() + + allMetrics.Lapi.Metrics = make([]*models.DetailedMetrics, 0) + + allMetrics.Lapi.Metrics = append(allMetrics.Lapi.Metrics, &models.DetailedMetrics{ + Meta: &models.MetricsMeta{ + UtcNowTimestamp: ptr.Of(time.Now().UTC().Unix()), + WindowSizeSeconds: ptr.Of(int64(a.metricsInterval.Seconds())), + }, + Items: make([]*models.MetricsDetailItem, 0), + }) + + // Force an actual slice to avoid non existing fields in the json + if allMetrics.RemediationComponents == nil { + allMetrics.RemediationComponents = make([]*models.RemediationComponentsMetrics, 0) + } + + if allMetrics.LogProcessors == nil { + allMetrics.LogProcessors = make([]*models.LogProcessorsMetrics, 0) + } + + return allMetrics, metricsIds, nil +} + +func (a *apic) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { + return a.dbClient.MarkUsageMetricsAsSent(ctx, ids) +} + +func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -26,12 +197,12 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { machinesInfo[i] = &models.MetricsAgentInfo{ Version: machine.Version, Name: machine.MachineId, - LastUpdate: machine.UpdatedAt.String(), - LastPush: ptr.OrEmpty(machine.LastPush).String(), + LastUpdate: machine.UpdatedAt.Format(time.RFC3339), + LastPush: ptr.OrEmpty(machine.LastPush).Format(time.RFC3339), } } - bouncers, err := a.dbClient.ListBouncers() + bouncers, err := a.dbClient.ListBouncers(ctx) if err != nil { return nil, err } @@ -39,11 +210,16 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { bouncersInfo := make([]*models.MetricsBouncerInfo, len(bouncers)) for i, bouncer := range bouncers { + lastPull := "" + if bouncer.LastPull != nil { + lastPull = bouncer.LastPull.Format(time.RFC3339) + } + bouncersInfo[i] = &models.MetricsBouncerInfo{ Version: bouncer.Version, CustomName: bouncer.Name, Name: bouncer.Type, - LastPull: bouncer.LastPull.String(), + LastPull: lastPull, } } @@ -54,8 +230,8 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { }, nil } -func (a *apic) fetchMachineIDs() ([]string, error) { - machines, err := a.dbClient.ListMachines() +func (a *apic) fetchMachineIDs(ctx context.Context) ([]string, error) { + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -66,6 +242,7 @@ func (a *apic) fetchMachineIDs() ([]string, error) { } // sorted slices are required for the slices.Equal comparison slices.Sort(ret) + return ret, nil } @@ -74,16 +251,16 @@ func (a *apic) fetchMachineIDs() ([]string, error) { // Metrics are sent at start, then at the randomized metricsIntervalFirst, // then at regular metricsInterval. If a change is detected in the list // of machines, the next metrics are sent immediately. 
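+// If the metrics cannot be retrieved, the tick is skipped and we retry at the next interval.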
-func (a *apic) SendMetrics(stop chan (bool)) { +func (a *apic) SendMetrics(ctx context.Context, stop chan (bool)) { defer trace.CatchPanic("lapi/metricsToAPIC") // verify the list of machines every interval const checkInt = 20 * time.Second // intervals must always be > 0 - metInts := []time.Duration{1*time.Millisecond, a.metricsIntervalFirst, a.metricsInterval} + metInts := []time.Duration{1 * time.Millisecond, a.metricsIntervalFirst, a.metricsInterval} - log.Infof("Start send metrics to CrowdSec Central API (interval: %s once, then %s)", + log.Infof("Start sending metrics to CrowdSec Central API (interval: %s once, then %s)", metInts[1].Round(time.Second), metInts[2]) count := -1 @@ -91,17 +268,20 @@ func (a *apic) SendMetrics(stop chan (bool)) { if count < len(metInts)-1 { count++ } + return metInts[count] } machineIDs := []string{} reloadMachineIDs := func() { - ids, err := a.fetchMachineIDs() + ids, err := a.fetchMachineIDs(ctx) if err != nil { log.Debugf("unable to get machines (%s), will retry", err) + return } + machineIDs = ids } @@ -117,32 +297,91 @@ func (a *apic) SendMetrics(stop chan (bool)) { case <-stop: checkTicker.Stop() metTicker.Stop() + return case <-checkTicker.C: oldIDs := machineIDs + reloadMachineIDs() + if !slices.Equal(oldIDs, machineIDs) { log.Infof("capi metrics: machines changed, immediate send") - metTicker.Reset(1*time.Millisecond) + metTicker.Reset(1 * time.Millisecond) } case <-metTicker.C: metTicker.Stop() - metrics, err := a.GetMetrics() + + metrics, err := a.GetMetrics(ctx) if err != nil { - log.Errorf("unable to get metrics (%s), will retry", err) + log.Errorf("unable to get metrics (%s)", err) } - log.Info("capi metrics: sending") - _, _, err = a.apiClient.Metrics.Add(context.Background(), metrics) - if err != nil { - log.Errorf("capi metrics: failed: %s", err) + // metrics are nil if they could not be retrieved + if metrics != nil { + log.Info("capi metrics: sending") + + _, _, err = a.apiClient.Metrics.Add(ctx, metrics) + if err != nil { + log.Errorf("capi metrics: failed: %s", err) + } } + metTicker.Reset(nextMetInt()) case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others? checkTicker.Stop() metTicker.Stop() a.pullTomb.Kill(nil) a.pushTomb.Kill(nil) + return } } } + +func (a *apic) SendUsageMetrics(ctx context.Context) { + defer trace.CatchPanic("lapi/usageMetricsToAPIC") + + firstRun := true + + log.Debugf("Start sending usage metrics to CrowdSec Central API (interval: %s once, then %s)", a.usageMetricsIntervalFirst, a.usageMetricsInterval) + ticker := time.NewTicker(a.usageMetricsIntervalFirst) + + for { + select { + case <-a.metricsTomb.Dying(): + // The normal metrics routine also kills push/pull tombs, does that make sense ? 
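+ // Here we only stop our own ticker; the push/pull routines keep running until their own tombs die.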
+ ticker.Stop() + return + case <-ticker.C: + if firstRun { + firstRun = false + + ticker.Reset(a.usageMetricsInterval) + } + + metrics, metricsId, err := a.GetUsageMetrics(ctx) + if err != nil { + log.Errorf("unable to get usage metrics: %s", err) + continue + } + + _, resp, err := a.apiClient.UsageMetrics.Add(ctx, metrics) + if err != nil { + log.Errorf("unable to send usage metrics: %s", err) + + if resp.Response.StatusCode >= http.StatusBadRequest && resp.Response.StatusCode != http.StatusUnprocessableEntity { + // In case of 422, mark the metrics as sent anyway, the API did not like what we sent, + // and it's unlikely we'll be able to fix it + continue + } + } + + err = a.MarkUsageMetricsAsSent(ctx, metricsId) + if err != nil { + log.Errorf("unable to mark usage metrics as sent: %s", err) + continue + } + + log.Infof("Sent %d usage metrics", len(metricsId)) + } + } +} diff --git a/pkg/apiserver/apic_metrics_test.go b/pkg/apiserver/apic_metrics_test.go index fdb94f4e2b1..d81af03f710 100644 --- a/pkg/apiserver/apic_metrics_test.go +++ b/pkg/apiserver/apic_metrics_test.go @@ -2,7 +2,6 @@ package apiserver import ( "context" - "fmt" "net/url" "testing" "time" @@ -11,12 +10,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/version" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" ) func TestAPICSendMetrics(t *testing.T) { + ctx := context.Background() + tests := []struct { name string duration time.Duration @@ -26,18 +25,18 @@ func TestAPICSendMetrics(t *testing.T) { }{ { name: "basic", - duration: time.Millisecond * 60, - metricsInterval: time.Millisecond * 10, + duration: time.Millisecond * 120, + metricsInterval: time.Millisecond * 20, expectedCalls: 5, setUp: func(api *apic) {}, }, { name: "with some metrics", - duration: time.Millisecond * 60, - metricsInterval: time.Millisecond * 10, + duration: time.Millisecond * 120, + metricsInterval: time.Millisecond * 20, expectedCalls: 5, setUp: func(api *apic) { - api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) + api.dbClient.Ent.Machine.Delete().ExecX(ctx) api.dbClient.Ent.Machine.Create(). SetMachineId("1234"). SetPassword(testPassword.String()). @@ -45,26 +44,26 @@ func TestAPICSendMetrics(t *testing.T) { SetScenarios("crowdsecurity/test"). SetLastPush(time.Time{}). SetUpdatedAt(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) - api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) + api.dbClient.Ent.Bouncer.Delete().ExecX(ctx) api.dbClient.Ent.Bouncer.Create(). SetIPAddress("1.2.3.6"). SetName("someBouncer"). SetAPIKey("foobar"). SetRevoked(false). SetLastPull(time.Time{}). 
- ExecX(context.Background()) + ExecX(ctx) }, }, } httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/metrics/", httpmock.NewBytesResponder(200, []byte{})) httpmock.Activate() + defer httpmock.Deactivate() for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -72,12 +71,12 @@ func TestAPICSendMetrics(t *testing.T) { apiClient, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) - api := getAPIC(t) + api := getAPIC(t, ctx) api.pushInterval = time.Millisecond api.pushIntervalFirst = time.Millisecond api.apiClient = apiClient @@ -86,8 +85,11 @@ func TestAPICSendMetrics(t *testing.T) { tc.setUp(api) stop := make(chan bool) + httpmock.ZeroCallCounters() - go api.SendMetrics(stop) + + go api.SendMetrics(ctx, stop) + time.Sleep(tc.duration) stop <- true diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 8aeb092cd42..b52dc9e44cc 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -20,9 +20,9 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" - "github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/go-cs-lib/version" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -34,25 +34,28 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -func getDBClient(t *testing.T) *database.Client { +func getDBClient(t *testing.T, ctx context.Context) *database.Client { t.Helper() + dbPath, err := os.CreateTemp("", "*sqlite") require.NoError(t, err) - dbClient, err := database.NewClient(&csconfig.DatabaseCfg{ + dbClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{ Type: "sqlite", DbName: "crowdsec", DbPath: dbPath.Name(), }) require.NoError(t, err) + return dbClient } -func getAPIC(t *testing.T) *apic { +func getAPIC(t *testing.T, ctx context.Context) *apic { t.Helper() - dbClient := getDBClient(t) + dbClient := getDBClient(t, ctx) + return &apic{ AlertsAddChan: make(chan []*models.Alert), - //DecisionDeleteChan: make(chan []*models.Decision), + // DecisionDeleteChan: make(chan []*models.Decision), dbClient: dbClient, mu: sync.Mutex{}, startup: true, @@ -70,15 +73,17 @@ func getAPIC(t *testing.T) *apic { } } -func absDiff(a int, b int) (c int) { - if c = a - b; c < 0 { +func absDiff(a int, b int) int { + c := a - b + if c < 0 { return -1 * c } + return c } -func assertTotalDecisionCount(t *testing.T, dbClient *database.Client, count int) { - d := dbClient.Ent.Decision.Query().AllX(context.Background()) +func assertTotalDecisionCount(t *testing.T, ctx context.Context, dbClient *database.Client, count int) { + d := dbClient.Ent.Decision.Query().AllX(ctx) assert.Len(t, d, count) } @@ -94,6 +99,7 @@ func jsonMarshalX(v interface{}) []byte { if err != nil { panic(err) } + return data } @@ -103,9 +109,10 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) { } func TestAPICCAPIPullIsOld(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) - isOld, err := api.CAPIPullIsOld() + isOld, err := api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.True(t, isOld) @@ -116,7 +123,7 @@ func TestAPICCAPIPullIsOld(t *testing.T) { 
SetScope("Country"). SetValue("Blah"). SetOrigin(types.CAPIOrigin). - SaveX(context.Background()) + SaveX(ctx) api.dbClient.Ent.Alert.Create(). SetCreatedAt(time.Now()). @@ -124,15 +131,17 @@ func TestAPICCAPIPullIsOld(t *testing.T) { AddDecisions( decision, ). - SaveX(context.Background()) + SaveX(ctx) - isOld, err = api.CAPIPullIsOld() + isOld, err = api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.False(t, isOld) } func TestAPICFetchScenariosListFromDB(t *testing.T) { + ctx := context.Background() + tests := []struct { name string machineIDsWithScenarios map[string]string @@ -156,32 +165,34 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - api := getAPIC(t) + api := getAPIC(t, ctx) for machineID, scenarios := range tc.machineIDsWithScenarios { api.dbClient.Ent.Machine.Create(). SetMachineId(machineID). SetPassword(testPassword.String()). SetIpAddress("1.2.3.4"). SetScenarios(scenarios). - ExecX(context.Background()) + ExecX(ctx) } - scenarios, err := api.FetchScenariosListFromDB() + scenarios, err := api.FetchScenariosListFromDB(ctx) + require.NoError(t, err) + for machineID := range tc.machineIDsWithScenarios { - api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background()) + api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(ctx) } - require.NoError(t, err) assert.ElementsMatch(t, tc.expectedScenarios, scenarios) }) - } } func TestNewAPIC(t *testing.T) { + ctx := context.Background() + var testConfig *csconfig.OnlineApiClientCfg + setConfig := func() { testConfig = &csconfig.OnlineApiClientCfg{ Credentials: &csconfig.ApiCredentialsCfg{ @@ -196,6 +207,7 @@ func TestNewAPIC(t *testing.T) { dbClient *database.Client consoleConfig *csconfig.ConsoleConfig } + tests := []struct { name string args args @@ -206,7 +218,7 @@ func TestNewAPIC(t *testing.T) { name: "simple", action: func() {}, args: args{ - dbClient: getDBClient(t), + dbClient: getDBClient(t, ctx), consoleConfig: LoadTestConfig(t).API.Server.ConsoleConfig, }, }, @@ -214,17 +226,18 @@ func TestNewAPIC(t *testing.T) { name: "error in parsing URL", action: func() { testConfig.Credentials.URL = "foobar http://" }, args: args{ - dbClient: getDBClient(t), + dbClient: getDBClient(t, ctx), consoleConfig: LoadTestConfig(t).API.Server.ConsoleConfig, }, expectedErr: "first path segment in URL cannot contain colon", }, } + for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { setConfig() httpmock.Activate() + defer httpmock.DeactivateAndReset() httpmock.RegisterResponder("POST", "http://foobar/v3/watchers/login", httpmock.NewBytesResponder( 200, jsonMarshalX( @@ -236,14 +249,15 @@ func TestNewAPIC(t *testing.T) { ), )) tc.action() - _, err := NewAPIC(testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) + _, err := NewAPIC(ctx, testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } func TestAPICHandleDeletedDecisions(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) _, deleteCounters := makeAddAndDeleteCounters() decision1 := api.dbClient.Ent.Decision.Create(). @@ -264,7 +278,7 @@ func TestAPICHandleDeletedDecisions(t *testing.T) { SetOrigin(types.CAPIOrigin). 
SaveX(context.Background()) - assertTotalDecisionCount(t, api.dbClient, 2) + assertTotalDecisionCount(t, ctx, api.dbClient, 2) nbDeleted, err := api.HandleDeletedDecisions([]*models.Decision{{ Value: ptr.Of("1.2.3.4"), @@ -274,15 +288,17 @@ func TestAPICHandleDeletedDecisions(t *testing.T) { Scope: ptr.Of("IP"), }}, deleteCounters) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 2, nbDeleted) assert.Equal(t, 2, deleteCounters[types.CAPIOrigin]["all"]) } func TestAPICGetMetrics(t *testing.T) { + ctx := context.Background() + cleanUp := func(api *apic) { - api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) - api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) + api.dbClient.Ent.Bouncer.Delete().ExecX(ctx) + api.dbClient.Ent.Machine.Delete().ExecX(ctx) } tests := []struct { name string @@ -309,40 +325,41 @@ func TestAPICGetMetrics(t *testing.T) { Bouncers: []*models.MetricsBouncerInfo{ { CustomName: "1", - LastPull: time.Time{}.String(), + LastPull: time.Time{}.Format(time.RFC3339), }, { CustomName: "2", - LastPull: time.Time{}.String(), + LastPull: time.Time{}.Format(time.RFC3339), }, { CustomName: "3", - LastPull: time.Time{}.String(), + LastPull: time.Time{}.Format(time.RFC3339), }, }, Machines: []*models.MetricsAgentInfo{ { Name: "a", - LastPush: time.Time{}.String(), - LastUpdate: time.Time{}.String(), + LastPush: time.Time{}.Format(time.RFC3339), + LastUpdate: time.Time{}.Format(time.RFC3339), }, { Name: "b", - LastPush: time.Time{}.String(), - LastUpdate: time.Time{}.String(), + LastPush: time.Time{}.Format(time.RFC3339), + LastUpdate: time.Time{}.Format(time.RFC3339), }, { Name: "c", - LastPush: time.Time{}.String(), - LastUpdate: time.Time{}.String(), + LastPush: time.Time{}.Format(time.RFC3339), + LastUpdate: time.Time{}.Format(time.RFC3339), }, }, }, }, } + for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - apiClient := getAPIC(t) + apiClient := getAPIC(t, ctx) cleanUp(apiClient) + for i, machineID := range tc.machineIDs { apiClient.dbClient.Ent.Machine.Create(). SetMachineId(machineID). @@ -351,7 +368,7 @@ func TestAPICGetMetrics(t *testing.T) { SetScenarios("crowdsecurity/test"). SetLastPush(time.Time{}). SetUpdatedAt(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) } for i, bouncerName := range tc.bouncers { @@ -361,15 +378,14 @@ func TestAPICGetMetrics(t *testing.T) { SetAPIKey("foobar"). SetRevoked(false). SetLastPull(time.Time{}). 
- ExecX(context.Background()) + ExecX(ctx) } - foundMetrics, err := apiClient.GetMetrics() + foundMetrics, err := apiClient.GetMetrics(ctx) require.NoError(t, err) assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers) assert.Equal(t, tc.expectedMetric.Machines, foundMetrics.Machines) - }) } } @@ -394,9 +410,11 @@ func TestCreateAlertsForDecision(t *testing.T) { Origin: ptr.Of(types.CAPIOrigin), Scenario: ptr.Of("crowdsecurity/ssh-bf"), } + type args struct { decisions []*models.Decision } + tests := []struct { name string args args @@ -443,8 +461,8 @@ func TestCreateAlertsForDecision(t *testing.T) { }, }, } + for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { if got := createAlertsForDecisions(tc.args.decisions); !reflect.DeepEqual(got, tc.want) { t.Errorf("createAlertsForDecisions() = %v, want %v", got, tc.want) @@ -477,10 +495,12 @@ func TestFillAlertsWithDecisions(t *testing.T) { Scenario: ptr.Of("crowdsecurity/ssh-bf"), Scope: ptr.Of("ip"), } + type args struct { alerts []*models.Alert decisions []*models.Decision } + tests := []struct { name string args args @@ -520,8 +540,8 @@ func TestFillAlertsWithDecisions(t *testing.T) { }, }, } + for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { addCounters, _ := makeAddAndDeleteCounters() if got := fillAlertsWithDecisions(tc.args.alerts, tc.args.decisions, addCounters); !reflect.DeepEqual(got, tc.want) { @@ -532,27 +552,22 @@ func TestFillAlertsWithDecisions(t *testing.T) { } func TestAPICWhitelists(t *testing.T) { - api := getAPIC(t) - //one whitelist on IP, one on CIDR + ctx := context.Background() + api := getAPIC(t, ctx) + // one whitelist on IP, one on CIDR api.whitelists = &csconfig.CapiWhitelist{} - ipwl1 := "9.2.3.4" - ip := net.ParseIP(ipwl1) - api.whitelists.Ips = append(api.whitelists.Ips, ip) - ipwl1 = "7.2.3.4" - ip = net.ParseIP(ipwl1) - api.whitelists.Ips = append(api.whitelists.Ips, ip) - cidrwl1 := "13.2.3.0/24" - _, tnet, err := net.ParseCIDR(cidrwl1) - if err != nil { - t.Fatalf("unable to parse cidr : %s", err) - } + api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4")) + + _, tnet, err := net.ParseCIDR("13.2.3.0/24") + require.NoError(t, err) + api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet) - cidrwl1 = "11.2.3.0/24" - _, tnet, err = net.ParseCIDR(cidrwl1) - if err != nil { - t.Fatalf("unable to parse cidr : %s", err) - } + + _, tnet, err = net.ParseCIDR("11.2.3.0/24") + require.NoError(t, err) + api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet) + api.dbClient.Ent.Decision.Create(). SetOrigin(types.CAPIOrigin). SetType("ban"). @@ -561,9 +576,10 @@ func TestAPICWhitelists(t *testing.T) { SetScenario("crowdsecurity/ssh-bf"). SetUntil(time.Now().Add(time.Hour)). 
ExecX(context.Background()) - assertTotalDecisionCount(t, api.dbClient, 1) + assertTotalDecisionCount(t, ctx, api.dbClient, 1) assertTotalValidDecisionCount(t, api.dbClient, 1) httpmock.Activate() + defer httpmock.DeactivateAndReset() httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( 200, jsonMarshalX( @@ -572,7 +588,7 @@ func TestAPICWhitelists(t *testing.T) { &modelscapi.GetDecisionsStreamResponseDeletedItem{ Decisions: []string{ "9.9.9.9", // This is already present in DB - "9.1.9.9", // This not present in DB + "9.1.9.9", // This is not present in DB }, Scope: ptr.Of("Ip"), }, // This is already present in DB @@ -583,7 +599,7 @@ func TestAPICWhitelists(t *testing.T) { Scope: ptr.Of("Ip"), Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ { - Value: ptr.Of("13.2.3.4"), //wl by cidr + Value: ptr.Of("13.2.3.4"), // wl by cidr Duration: ptr.Of("24h"), }, }, @@ -604,7 +620,7 @@ func TestAPICWhitelists(t *testing.T) { Scope: ptr.Of("Ip"), Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ { - Value: ptr.Of("13.2.3.5"), //wl by cidr + Value: ptr.Of("13.2.3.5"), // wl by cidr Duration: ptr.Of("24h"), }, }, @@ -624,7 +640,7 @@ func TestAPICWhitelists(t *testing.T) { Scope: ptr.Of("Ip"), Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ { - Value: ptr.Of("9.2.3.4"), //wl by ip + Value: ptr.Of("9.2.3.4"), // wl by ip Duration: ptr.Of("24h"), }, }, @@ -651,28 +667,31 @@ func TestAPICWhitelists(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", httpmock.NewStringResponder( 200, "1.2.3.6", )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist2", httpmock.NewStringResponder( 200, "1.2.3.7", )) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) - assertTotalDecisionCount(t, api.dbClient, 5) //2 from FIRE + 2 from bl + 1 existing + assertTotalDecisionCount(t, ctx, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing assertTotalValidDecisionCount(t, api.dbClient, 4) assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list. 
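+ // whitelisted IPs (9.2.3.4, 13.2.3.4, 13.2.3.5) must not appear in the stored decisions checked below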
alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background()) @@ -681,40 +700,47 @@ func TestAPICWhitelists(t *testing.T) { AllX(context.Background()) decisionScenarioFreq := make(map[string]int) - decisionIp := make(map[string]int) + decisionIP := make(map[string]int) alertScenario := make(map[string]int) for _, alert := range alerts { alertScenario[alert.SourceScope]++ } - assert.Equal(t, 3, len(alertScenario)) - assert.Equal(t, 1, alertScenario[SCOPE_CAPI_ALIAS_ALIAS]) + + assert.Len(t, alertScenario, 3) + assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope]) assert.Equal(t, 1, alertScenario["lists:blocklist1"]) assert.Equal(t, 1, alertScenario["lists:blocklist2"]) for _, decisions := range validDecisions { decisionScenarioFreq[decisions.Scenario]++ - decisionIp[decisions.Value]++ + decisionIP[decisions.Value]++ } - assert.Equal(t, 1, decisionIp["2.2.3.4"], 1) - assert.Equal(t, 1, decisionIp["6.2.3.4"], 1) - if _, ok := decisionIp["13.2.3.4"]; ok { + + assert.Equal(t, 1, decisionIP["2.2.3.4"], 1) + assert.Equal(t, 1, decisionIP["6.2.3.4"], 1) + + if _, ok := decisionIP["13.2.3.4"]; ok { t.Errorf("13.2.3.4 is whitelisted") } - if _, ok := decisionIp["13.2.3.5"]; ok { + + if _, ok := decisionIP["13.2.3.5"]; ok { t.Errorf("13.2.3.5 is whitelisted") } - if _, ok := decisionIp["9.2.3.4"]; ok { + + if _, ok := decisionIP["9.2.3.4"]; ok { t.Errorf("9.2.3.4 is whitelisted") } + assert.Equal(t, 1, decisionScenarioFreq["blocklist1"], 1) assert.Equal(t, 1, decisionScenarioFreq["blocklist2"], 1) assert.Equal(t, 2, decisionScenarioFreq["crowdsecurity/test1"], 2) } func TestAPICPullTop(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) api.dbClient.Ent.Decision.Create(). SetOrigin(types.CAPIOrigin). SetType("ban"). @@ -722,10 +748,11 @@ func TestAPICPullTop(t *testing.T) { SetScope("Ip"). SetScenario("crowdsecurity/ssh-bf"). SetUntil(time.Now().Add(time.Hour)). - ExecX(context.Background()) - assertTotalDecisionCount(t, api.dbClient, 1) + ExecX(ctx) + assertTotalDecisionCount(t, ctx, api.dbClient, 1) assertTotalValidDecisionCount(t, api.dbClient, 1) httpmock.Activate() + defer httpmock.DeactivateAndReset() httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( 200, jsonMarshalX( @@ -734,7 +761,7 @@ func TestAPICPullTop(t *testing.T) { &modelscapi.GetDecisionsStreamResponseDeletedItem{ Decisions: []string{ "9.9.9.9", // This is already present in DB - "9.1.9.9", // This not present in DB + "9.1.9.9", // This is not present in DB }, Scope: ptr.Of("Ip"), }, // This is already present in DB @@ -782,28 +809,31 @@ func TestAPICPullTop(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", httpmock.NewStringResponder( 200, "1.2.3.6", )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist2", httpmock.NewStringResponder( 200, "1.2.3.7", )) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) - assertTotalDecisionCount(t, api.dbClient, 5) + assertTotalDecisionCount(t, ctx, api.dbClient, 5) assertTotalValidDecisionCount(t, api.dbClient, 4) assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list. 
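// The assertTotalDecisionCount-style helpers used in these tests are thin
// wrappers around the generated ent client; a sketch under the assumption
// that they simply count rows (the actual helpers are defined elsewhere in
// this test package):
//
//	func assertTotalDecisionCount(t *testing.T, ctx context.Context, dbClient *database.Client, expected int) {
//		t.Helper()
//		count := dbClient.Ent.Decision.Query().CountX(ctx)
//		assert.Equal(t, expected, count)
//	}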
alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background()) @@ -817,8 +847,9 @@ func TestAPICPullTop(t *testing.T) { for _, alert := range alerts { alertScenario[alert.SourceScope]++ } - assert.Equal(t, 3, len(alertScenario)) - assert.Equal(t, 1, alertScenario[SCOPE_CAPI_ALIAS_ALIAS]) + + assert.Len(t, alertScenario, 3) + assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope]) assert.Equal(t, 1, alertScenario["lists:blocklist1"]) assert.Equal(t, 1, alertScenario["lists:blocklist2"]) @@ -833,10 +864,13 @@ func TestAPICPullTop(t *testing.T) { } func TestAPICPullTopBLCacheFirstCall(t *testing.T) { + ctx := context.Background() // no decision in db, no last modified parameter. - api := getAPIC(t) + api := getAPIC(t, ctx) + httpmock.Activate() defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( 200, jsonMarshalX( modelscapi.GetDecisionsStreamResponse{ @@ -866,27 +900,29 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { assert.Equal(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(200, "1.2.3.4"), nil }) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) blocklistConfigItemName := "blocklist:blocklist1:last_pull" - lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.NotEqual(t, "", *lastPullTimestamp) @@ -895,17 +931,21 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { assert.NotEqual(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(304, ""), nil }) - err = api.PullTop(false) + + err = api.PullTop(ctx, false) require.NoError(t, err) - secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp) } func TestAPICPullTopBLCacheForceCall(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) + httpmock.Activate() defer httpmock.DeactivateAndReset() + // create a decision about to expire. It should force fetch alertInstance := api.dbClient.Ent.Alert. Create(). 
@@ -953,27 +993,64 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { assert.Equal(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(304, ""), nil }) + + url, err := url.ParseRequestURI("http://api.crowdsec.net/") + require.NoError(t, err) + + apic, err := apiclient.NewDefaultClient( + url, + "/api", + "", + nil, + ) + require.NoError(t, err) + + api.apiClient = apic + err = api.PullTop(ctx, false) + require.NoError(t, err) +} + +func TestAPICPullBlocklistCall(t *testing.T) { + ctx := context.Background() + api := getAPIC(t, ctx) + + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { + assert.Equal(t, "", req.Header.Get("If-Modified-Since")) + return httpmock.NewStringResponse(200, "1.2.3.4"), nil + }) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullBlocklist(ctx, &modelscapi.BlocklistLink{ + URL: ptr.Of("http://api.crowdsec.net/blocklist1"), + Name: ptr.Of("blocklist1"), + Scope: ptr.Of("Ip"), + Remediation: ptr.Of("ban"), + Duration: ptr.Of("24h"), + }, true) require.NoError(t, err) } func TestAPICPush(t *testing.T) { + ctx := context.Background() tests := []struct { name string alerts []*models.Alert @@ -1010,7 +1087,7 @@ func TestAPICPush(t *testing.T) { expectedCalls: 2, alerts: func() []*models.Alert { alerts := make([]*models.Alert, 100) - for i := 0; i < 100; i++ { + for i := range 100 { alerts[i] = &models.Alert{ Scenario: ptr.Of("crowdsec/test"), ScenarioHash: ptr.Of("certified"), @@ -1019,15 +1096,15 @@ func TestAPICPush(t *testing.T) { Source: &models.Source{}, } } + return alerts }(), }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - api := getAPIC(t) + api := getAPIC(t, ctx) api.pushInterval = time.Millisecond api.pushIntervalFirst = time.Millisecond url, err := url.ParseRequestURI("http://api.crowdsec.net/") @@ -1035,22 +1112,29 @@ func TestAPICPush(t *testing.T) { httpmock.Activate() defer httpmock.DeactivateAndReset() + apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) api.apiClient = apic + httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/signals", httpmock.NewBytesResponder(200, []byte{})) + + // capture the alerts to avoid datarace + alerts := tc.alerts go func() { - api.AlertsAddChan <- tc.alerts + api.AlertsAddChan <- alerts + time.Sleep(time.Second) api.Shutdown() }() - err = api.Push() + + err = api.Push(ctx) require.NoError(t, err) assert.Equal(t, tc.expectedCalls, httpmock.GetTotalCallCount()) }) @@ -1058,7 +1142,8 @@ func TestAPICPush(t *testing.T) { } func TestAPICPull(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) tests := []struct { name string setUp func() @@ -1085,23 +1170,26 @@ func TestAPICPull(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - api = getAPIC(t) + api = getAPIC(t, ctx) api.pullInterval = time.Millisecond api.pullIntervalFirst = time.Millisecond url, err := 
url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) httpmock.Activate() + defer httpmock.DeactivateAndReset() + apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) + api.apiClient = apic + httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, jsonMarshalX( modelscapi.GetDecisionsStreamResponse{ New: modelscapi.GetDecisionsStreamResponseNew{ @@ -1119,18 +1207,22 @@ func TestAPICPull(t *testing.T) { }, ))) tc.setUp() + var buf bytes.Buffer + go func() { logrus.SetOutput(&buf) - if err := api.Pull(); err != nil { + + if err := api.Pull(ctx); err != nil { panic(err) } }() - //Slightly long because the CI runner for windows are slow, and this can lead to random failure + + // Slightly long because the CI runners for Windows are slow, and this can lead to random failures time.Sleep(time.Millisecond * 500) logrus.SetOutput(os.Stderr) assert.Contains(t, buf.String(), tc.logContains) - assertTotalDecisionCount(t, api.dbClient, tc.expectedDecisionCount) + assertTotalDecisionCount(t, ctx, api.dbClient, tc.expectedDecisionCount) }) } } @@ -1212,7 +1304,6 @@ func TestShouldShareAlert(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { ret := shouldShareAlert(tc.alert, tc.consoleConfig) assert.Equal(t, tc.expectedRet, ret) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index d802822f84e..35f9beaf635 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -2,8 +2,7 @@ package apiserver import ( "context" - "crypto/tls" - "crypto/x509" + "errors" "fmt" "io" "net" @@ -15,30 +14,25 @@ import ( "github.com/gin-gonic/gin" "github.com/go-co-op/gocron" - "github.com/golang-jwt/jwt/v4" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/natefinch/lumberjack.v2" "gopkg.in/tomb.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers" v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/fflag" "github.com/crowdsecurity/crowdsec/pkg/types" ) -var ( - keyLength = 32 -) +const keyLength = 32 type APIServer struct { URL string + UnixSocket string TLS *csconfig.TLSCfg dbClient *database.Client logFile string @@ -50,129 +44,163 @@ type APIServer struct { papi *Papi httpServerTomb tomb.Tomb consoleConfig *csconfig.ConsoleConfig - isEnrolled bool } -// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one. -func CustomRecoveryWithWriter() gin.HandlerFunc { - return func(c *gin.Context) { - defer func() { - if err := recover(); err != nil { - // Check for a broken connection, as it is not really a - // condition that warrants a panic stack trace.
- var brokenPipe bool - if ne, ok := err.(*net.OpError); ok { - if se, ok := ne.Err.(*os.SyscallError); ok { - if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { - brokenPipe = true - } - } - } +func recoverFromPanic(c *gin.Context) { + err := recover() + if err == nil { + return + } - // because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go - // and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them - if strErr, ok := err.(error); ok { - //stolen from http2/server.go in x/net - var ( - errClientDisconnected = errors.New("client disconnected") - errClosedBody = errors.New("body closed by handler") - errHandlerComplete = errors.New("http2: request body closed due to handler exiting") - errStreamClosed = errors.New("http2: stream closed") - ) - if errors.Is(strErr, errClientDisconnected) || - errors.Is(strErr, errClosedBody) || - errors.Is(strErr, errHandlerComplete) || - errors.Is(strErr, errStreamClosed) { - brokenPipe = true - } - } + // Check for a broken connection, as it is not really a + // condition that warrants a panic stack trace. + brokenPipe := false - if brokenPipe { - log.Warningf("client %s disconnected : %s", c.ClientIP(), err) - c.Abort() - } else { - filename := trace.WriteStackTrace(err) - log.Warningf("client %s error : %s", c.ClientIP(), err) - log.Warningf("stacktrace written to %s, please join to your issue", filename) - c.AbortWithStatus(http.StatusInternalServerError) - } + if ne, ok := err.(*net.OpError); ok { + if se, ok := ne.Err.(*os.SyscallError); ok { + if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { + brokenPipe = true } - }() - c.Next() + } } -func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { - var flushScheduler *gocron.Scheduler - dbClient, err := database.NewClient(config.DbConfig) - if err != nil { - return &APIServer{}, fmt.Errorf("unable to init database client: %w", err) - } + // because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go + // and because it seems gin doesn't handle these either, we need to "hand define" some errors to properly catch them + if strErr, ok := err.(error); ok { + // stolen from http2/server.go in x/net + var ( + errClientDisconnected = errors.New("client disconnected") + errClosedBody = errors.New("body closed by handler") + errHandlerComplete = errors.New("http2: request body closed due to handler exiting") + errStreamClosed = errors.New("http2: stream closed") + ) - if config.DbConfig.Flush != nil { - flushScheduler, err = dbClient.StartFlushScheduler(config.DbConfig.Flush) - if err != nil { - return &APIServer{}, err + if errors.Is(strErr, errClientDisconnected) || + errors.Is(strErr, errClosedBody) || + errors.Is(strErr, errHandlerComplete) || + errors.Is(strErr, errStreamClosed) { + brokenPipe = true } } - logFile := "" - if config.LogMedia == "file" { - logFile = filepath.Join(config.LogDir, "crowdsec_api.log") - } + if brokenPipe { + log.Warningf("client %s disconnected: %s", c.ClientIP(), err) + c.Abort() + } else { + log.Warningf("client %s error: %s", c.ClientIP(), err) - if log.GetLevel() < log.DebugLevel { - gin.SetMode(gin.ReleaseMode) + filename, err := trace.WriteStackTrace(err) + if err != nil { + log.Errorf("also while writing stacktrace: %s",
err) + } + + log.Warningf("stacktrace written to %s, please join to your issue", filename) + c.AbortWithStatus(http.StatusInternalServerError) } - log.Debugf("starting router, logging to %s", logFile) - router := gin.New() +} - if config.TrustedProxies != nil && config.UseForwardedForHeaders { - if err := router.SetTrustedProxies(*config.TrustedProxies); err != nil { - return &APIServer{}, fmt.Errorf("while setting trusted_proxies: %w", err) - } - router.ForwardedByClientIP = true - } else { - router.ForwardedByClientIP = false +// CustomRecoveryWithWriter returns a middleware for a writer that recovers from any panics and writes a 500 if there was one. +func CustomRecoveryWithWriter() gin.HandlerFunc { + return func(c *gin.Context) { + defer recoverFromPanic(c) + c.Next() } +} - /*The logger that will be used by handlers*/ +// XXX: could be a method of LocalApiServerCfg +func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, error) { clog := log.New() if err := types.ConfigureLogger(clog); err != nil { - return nil, fmt.Errorf("while configuring gin logger: %w", err) + return nil, "", fmt.Errorf("while configuring gin logger: %w", err) } + if config.LogLevel != nil { clog.SetLevel(*config.LogLevel) } - /*Configure logs*/ - if logFile != "" { - _maxsize := 500 - if config.LogMaxSize != 0 { - _maxsize = config.LogMaxSize - } - _maxfiles := 3 - if config.LogMaxFiles != 0 { - _maxfiles = config.LogMaxFiles - } - _maxage := 28 - if config.LogMaxAge != 0 { - _maxage = config.LogMaxAge + if config.LogMedia != "file" { + return clog, "", nil + } + + // Log rotation + + logFile := filepath.Join(config.LogDir, "crowdsec_api.log") + log.Debugf("starting router, logging to %s", logFile) + + logger := &lumberjack.Logger{ + Filename: logFile, + MaxSize: 500, // megabytes + MaxBackups: 3, + MaxAge: 28, // days + Compress: true, // disabled by default + } + + if config.LogMaxSize != 0 { + logger.MaxSize = config.LogMaxSize + } + + if config.LogMaxFiles != 0 { + logger.MaxBackups = config.LogMaxFiles + } + + if config.LogMaxAge != 0 { + logger.MaxAge = config.LogMaxAge + } + + if config.CompressLogs != nil { + logger.Compress = *config.CompressLogs + } + + clog.SetOutput(logger) + + return clog, logFile, nil +} + +// NewServer creates a LAPI server. +// It sets up a gin router, a database client, and a controller. 
+func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg) (*APIServer, error) { + var flushScheduler *gocron.Scheduler + + dbClient, err := database.NewClient(ctx, config.DbConfig) + if err != nil { + return nil, fmt.Errorf("unable to init database client: %w", err) + } + + if config.DbConfig.Flush != nil { + flushScheduler, err = dbClient.StartFlushScheduler(ctx, config.DbConfig.Flush) + if err != nil { + return nil, err } - _compress := true - if config.CompressLogs != nil { - _compress = *config.CompressLogs + } + + if log.GetLevel() < log.DebugLevel { + gin.SetMode(gin.ReleaseMode) + } + + router := gin.New() + + router.ForwardedByClientIP = false + + // set the remote address of the request to 127.0.0.1 if it comes from a unix socket + router.Use(func(c *gin.Context) { + if c.Request.RemoteAddr == "@" { + c.Request.RemoteAddr = "127.0.0.1:65535" } + }) - LogOutput := &lumberjack.Logger{ - Filename: logFile, - MaxSize: _maxsize, //megabytes - MaxBackups: _maxfiles, - MaxAge: _maxage, //days - Compress: _compress, //disabled by default + if config.TrustedProxies != nil && config.UseForwardedForHeaders { + if err = router.SetTrustedProxies(*config.TrustedProxies); err != nil { + return nil, fmt.Errorf("while setting trusted_proxies: %w", err) } - clog.SetOutput(LogOutput) + + router.ForwardedByClientIP = true + } + + // The logger that will be used by handlers + clog, logFile, err := newGinLogger(config) + if err != nil { + return nil, err } gin.DefaultErrorWriter = clog.WriterLevel(log.ErrorLevel) @@ -199,53 +227,60 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { controller := &controllers.Controller{ DBClient: dbClient, - Ectx: context.Background(), Router: router, Profiles: config.Profiles, Log: clog, ConsoleConfig: config.ConsoleConfig, DisableRemoteLapiRegistration: config.DisableRemoteLapiRegistration, + AutoRegisterCfg: config.AutoRegister, } - var apiClient *apic - var papiClient *Papi - var isMachineEnrolled = false + var ( + apiClient *apic + papiClient *Papi + ) + + controller.AlertsAddChan = nil + controller.DecisionDeleteChan = nil if config.OnlineClient != nil && config.OnlineClient.Credentials != nil { log.Printf("Loading CAPI manager") + + apiClient, err = NewAPIC(ctx, config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) if err != nil { - return &APIServer{}, err + return nil, err } + log.Infof("CAPI manager configured successfully") - isMachineEnrolled = isEnrolled(apiClient.apiClient) + controller.AlertsAddChan = apiClient.AlertsAddChan - if fflag.PapiClient.IsEnabled() { - if isMachineEnrolled { - log.Infof("Machine is enrolled in the console, Loading PAPI Client") + + if config.ConsoleConfig.IsPAPIEnabled() && config.OnlineClient.Credentials.PapiURL != "" { + if apiClient.apiClient.IsEnrolled() { + log.Info("Machine is enrolled in the console, Loading PAPI Client") + papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel) if err != nil { - return &APIServer{}, err + return nil, err } + controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel } else { - log.Errorf("Machine is not enrolled in the console, can't synchronize with the console") + log.Error("Machine is not enrolled in the console, can't synchronize with the console") } } - } else { - apiClient = nil - controller.AlertsAddChan = nil - controller.DecisionDeleteChan = nil } - if trustedIPs,
err := config.GetTrustedIPs(); err == nil { - controller.TrustedIPs = trustedIPs - } else { - return &APIServer{}, err + trustedIPs, err := config.GetTrustedIPs() + if err != nil { + return nil, err } + controller.TrustedIPs = trustedIPs + return &APIServer{ URL: config.ListenURI, + UnixSocket: config.ListenSocket, TLS: config.TLS, logFile: logFile, dbClient: dbClient, @@ -256,164 +291,202 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { papi: papiClient, httpServerTomb: tomb.Tomb{}, consoleConfig: config.ConsoleConfig, - isEnrolled: isMachineEnrolled, }, nil - } -func isEnrolled(client *apiclient.ApiClient) bool { - apiHTTPClient := client.GetClient() - jwtTransport := apiHTTPClient.Transport.(*apiclient.JWTTransport) - tokenStr := jwtTransport.Token +func (s *APIServer) Router() (*gin.Engine, error) { + return s.router, nil +} - token, _ := jwt.Parse(tokenStr, nil) - if token == nil { - return false +func (s *APIServer) apicPush(ctx context.Context) error { + if err := s.apic.Push(ctx); err != nil { + log.Errorf("capi push: %s", err) + return err } - claims := token.Claims.(jwt.MapClaims) - _, ok := claims["organization_id"] - return ok + return nil } -func (s *APIServer) Router() (*gin.Engine, error) { - return s.router, nil -} +func (s *APIServer) apicPull(ctx context.Context) error { + if err := s.apic.Pull(ctx); err != nil { + log.Errorf("capi pull: %s", err) + return err + } -func (s *APIServer) GetTLSConfig() (*tls.Config, error) { - var caCert []byte - var err error - var caCertPool *x509.CertPool - var clientAuthType tls.ClientAuthType + return nil +} - if s.TLS == nil { - return &tls.Config{}, nil +func (s *APIServer) papiPull(ctx context.Context) error { + if err := s.papi.Pull(ctx); err != nil { + log.Errorf("papi pull: %s", err) + return err } - if s.TLS.ClientVerification == "" { - //sounds like a sane default : verify client cert if given, but don't make it mandatory - clientAuthType = tls.VerifyClientCertIfGiven - } else { - clientAuthType, err = getTLSAuthType(s.TLS.ClientVerification) - if err != nil { - return nil, err - } + return nil +} + +func (s *APIServer) papiSync() error { + if err := s.papi.SyncDecisions(); err != nil { + log.Errorf("capi decisions sync: %s", err) + return err } - if s.TLS.CACertPath != "" { - if clientAuthType > tls.RequestClientCert { - log.Infof("(tls) Client Auth Type set to %s", clientAuthType.String()) - caCert, err = os.ReadFile(s.TLS.CACertPath) - if err != nil { - return nil, fmt.Errorf("while opening cert file: %w", err) - } - caCertPool, err = x509.SystemCertPool() - if err != nil { - log.Warnf("Error loading system CA certificates: %s", err) - } - if caCertPool == nil { - caCertPool = x509.NewCertPool() + return nil +} + +func (s *APIServer) initAPIC(ctx context.Context) { + s.apic.pushTomb.Go(func() error { return s.apicPush(ctx) }) + s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) }) + + // csConfig.API.Server.ConsoleConfig.ShareCustomScenarios + if s.apic.apiClient.IsEnrolled() { + if s.consoleConfig.IsPAPIEnabled() && s.papi != nil { + if s.papi.URL != "" { + log.Info("Starting PAPI decision receiver") + s.papi.pullTomb.Go(func() error { return s.papiPull(ctx) }) + s.papi.syncTomb.Go(s.papiSync) + } else { + log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. 
Run cscli console enable console_management to add it.") } - caCertPool.AppendCertsFromPEM(caCert) + } else { + log.Warningf("Machine is not allowed to synchronize decisions, you can enable it with `cscli console enable console_management`") } } - return &tls.Config{ - ServerName: s.TLS.ServerName, //should it be removed ? - ClientAuth: clientAuthType, - ClientCAs: caCertPool, - MinVersion: tls.VersionTLS12, // TLS versions below 1.2 are considered insecure - see https://www.rfc-editor.org/rfc/rfc7525.txt for details - }, nil + s.apic.metricsTomb.Go(func() error { + s.apic.SendMetrics(ctx, make(chan bool)) + return nil + }) + + s.apic.metricsTomb.Go(func() error { + s.apic.SendUsageMetrics(ctx) + return nil + }) } func (s *APIServer) Run(apiReady chan bool) error { defer trace.CatchPanic("lapi/runServer") - tlsCfg, err := s.GetTLSConfig() + + tlsCfg, err := s.TLS.GetTLSConfig() if err != nil { return fmt.Errorf("while creating TLS config: %w", err) } + s.httpServer = &http.Server{ Addr: s.URL, Handler: s.router, TLSConfig: tlsCfg, } + ctx := context.TODO() + if s.apic != nil { - s.apic.pushTomb.Go(func() error { - if err := s.apic.Push(); err != nil { - log.Errorf("capi push: %s", err) - return err - } - return nil - }) + s.initAPIC(ctx) + } + + s.httpServerTomb.Go(func() error { + return s.listenAndServeLAPI(apiReady) + }) + + if err := s.httpServerTomb.Wait(); err != nil { + return fmt.Errorf("local API server stopped with error: %w", err) + } + + return nil +} - s.apic.pullTomb.Go(func() error { - if err := s.apic.Pull(); err != nil { - log.Errorf("capi pull: %s", err) - return err +// listenAndServeLAPI starts the http server and blocks until it's closed +// it also updates the URL field with the actual address the server is listening on +// it's meant to be run in a separate goroutine +func (s *APIServer) listenAndServeLAPI(apiReady chan bool) error { + var ( + tcpListener net.Listener + unixListener net.Listener + err error + serverError = make(chan error, 2) + listenerClosed = make(chan struct{}) + ) + + startServer := func(listener net.Listener, canTLS bool) { + if canTLS && s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") { + if s.TLS.KeyFilePath == "" { + serverError <- errors.New("missing TLS key file") + return } - return nil - }) - - //csConfig.API.Server.ConsoleConfig.ShareCustomScenarios - if s.isEnrolled { - if fflag.PapiClient.IsEnabled() { - if s.consoleConfig.ConsoleManagement != nil && *s.consoleConfig.ConsoleManagement { - if s.papi.URL != "" { - log.Infof("Starting PAPI decision receiver") - s.papi.pullTomb.Go(func() error { - if err := s.papi.Pull(); err != nil { - log.Errorf("papi pull: %s", err) - return err - } - return nil - }) - - s.papi.syncTomb.Go(func() error { - if err := s.papi.SyncDecisions(); err != nil { - log.Errorf("capi decisions sync: %s", err) - return err - } - return nil - }) - } else { - log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. 
Run cscli console enable console_management to add it.") - } - } else { - log.Warningf("Machine is not allowed to synchronize decisions, you can enable it with `cscli console enable console_management`") - } + + if s.TLS.CertFilePath == "" { + serverError <- errors.New("missing TLS cert file") + return } + + err = s.httpServer.ServeTLS(listener, s.TLS.CertFilePath, s.TLS.KeyFilePath) + } else { + err = s.httpServer.Serve(listener) } - s.apic.metricsTomb.Go(func() error { - s.apic.SendMetrics(make(chan bool)) - return nil - }) + switch { + case errors.Is(err, http.ErrServerClosed): + break + case err != nil: + serverError <- err + } } - s.httpServerTomb.Go(func() error { - go func() { - apiReady <- true - log.Infof("CrowdSec Local API listening on %s", s.URL) - if s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") { - if s.TLS.KeyFilePath == "" { - log.Fatalf("while serving local API: %v", errors.New("missing TLS key file")) - } else if s.TLS.CertFilePath == "" { - log.Fatalf("while serving local API: %v", errors.New("missing TLS cert file")) - } + // Starting TCP listener + go func() { + if s.URL == "" { + return + } - if err := s.httpServer.ListenAndServeTLS(s.TLS.CertFilePath, s.TLS.KeyFilePath); err != nil { - log.Fatalf("while serving local API: %v", err) - } - } else { - if err := s.httpServer.ListenAndServe(); err != http.ErrServerClosed { - log.Fatalf("while serving local API: %v", err) - } - } - }() - <-s.httpServerTomb.Dying() - return nil - }) + tcpListener, err = net.Listen("tcp", s.URL) + if err != nil { + serverError <- fmt.Errorf("listening on %s: %w", s.URL, err) + return + } + + log.Infof("CrowdSec Local API listening on %s", s.URL) + startServer(tcpListener, true) + }() + + // Starting Unix socket listener + go func() { + if s.UnixSocket == "" { + return + } + + _ = os.RemoveAll(s.UnixSocket) + + unixListener, err = net.Listen("unix", s.UnixSocket) + if err != nil { + serverError <- fmt.Errorf("while creating unix listener: %w", err) + return + } + + log.Infof("CrowdSec Local API listening on Unix socket %s", s.UnixSocket) + startServer(unixListener, false) + }() + + apiReady <- true + + select { + case err := <-serverError: + return err + case <-s.httpServerTomb.Dying(): + log.Info("Shutting down API server") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := s.httpServer.Shutdown(ctx); err != nil { + log.Errorf("while shutting down http server: %v", err) + } + + close(listenerClosed) + case <-listenerClosed: + if s.UnixSocket != "" { + _ = os.RemoveAll(s.UnixSocket) + } + } return nil } @@ -422,10 +495,13 @@ func (s *APIServer) Close() { if s.apic != nil { s.apic.Shutdown() // stop apic first since it use dbClient } + if s.papi != nil { s.papi.Shutdown() // papi also uses the dbClient } + s.dbClient.Ent.Close() + if s.flushScheduler != nil { s.flushScheduler.Stop() } @@ -433,23 +509,28 @@ func (s *APIServer) Close() { func (s *APIServer) Shutdown() error { s.Close() + if s.httpServer != nil { if err := s.httpServer.Shutdown(context.TODO()); err != nil { return err } } - //close io.writer logger given to gin + // close io.writer logger given to gin if pipe, ok := gin.DefaultErrorWriter.(*io.PipeWriter); ok { pipe.Close() } + if pipe, ok := gin.DefaultWriter.(*io.PipeWriter); ok { pipe.Close() } + s.httpServerTomb.Kill(nil) + if err := s.httpServerTomb.Wait(); err != nil { return fmt.Errorf("while waiting on httpServerTomb: %w", err) } + return nil } @@ -458,36 +539,41 @@ func (s *APIServer) 
AttachPluginBroker(broker *csplugin.PluginBroker) { } func (s *APIServer) InitController() error { - err := s.controller.Init() if err != nil { return fmt.Errorf("controller init: %w", err) } - if s.TLS != nil { - var cacheExpiration time.Duration - if s.TLS.CacheExpiration != nil { - cacheExpiration = *s.TLS.CacheExpiration - } else { - cacheExpiration = time.Hour - } - s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath, - cacheExpiration, - log.WithFields(log.Fields{ - "component": "tls-auth", - "type": "agent", - })) - if err != nil { - return fmt.Errorf("while creating TLS auth for agents: %w", err) - } - s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath, - cacheExpiration, - log.WithFields(log.Fields{ - "component": "tls-auth", - "type": "bouncer", - })) - if err != nil { - return fmt.Errorf("while creating TLS auth for bouncers: %w", err) - } + + if s.TLS == nil { + return nil + } + + // TLS is configured: create the TLSAuth middleware for agents and bouncers + + cacheExpiration := time.Hour + if s.TLS.CacheExpiration != nil { + cacheExpiration = *s.TLS.CacheExpiration + } + + s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath, + cacheExpiration, + log.WithFields(log.Fields{ + "component": "tls-auth", + "type": "agent", + })) + if err != nil { + return fmt.Errorf("while creating TLS auth for agents: %w", err) + } + + s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath, + cacheExpiration, + log.WithFields(log.Fields{ + "component": "tls-auth", + "type": "bouncer", + })) + if err != nil { + return fmt.Errorf("while creating TLS auth for bouncers: %w", err) } - return err + + return nil } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 464c93f83fd..cdf99462c35 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -1,8 +1,8 @@ package apiserver import ( + "context" "encoding/json" - "fmt" "net/http" "net/http/httptest" "os" @@ -11,31 +11,38 @@ import ( "testing" "time" - "github.com/crowdsecurity/go-cs-lib/pkg/version" - - middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" - "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/crowdsecurity/crowdsec/pkg/types" + "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" - "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/go-cs-lib/version" + + middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/gin-gonic/gin" - - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" ) -var testMachineID = "test" -var testPassword = strfmt.Password("test") -var MachineTest = models.WatcherAuthRequest{ - MachineID: &testMachineID, - Password: &testPassword, -} +const ( + validRegistrationToken = "igheethauCaeteSaiyee3LosohPhahze" + invalidRegistrationToken = "vohl1feibechieG5coh8musheish2auj" +) -var UserAgent = fmt.Sprintf("crowdsec-test/%s", version.Version) -var emptyBody = 
strings.NewReader("") +var ( + testMachineID = "test" + testPassword = strfmt.Password("test") + MachineTest = models.WatcherRegistrationRequest{ + MachineID: &testMachineID, + Password: &testPassword, + } + UserAgent = "crowdsec-test/" + version.Version + emptyBody = strings.NewReader("") +) func LoadTestConfig(t *testing.T) csconfig.Config { config := csconfig.Config{} @@ -45,6 +52,7 @@ func LoadTestConfig(t *testing.T) csconfig.Config { } tempDir, _ := os.MkdirTemp("", "crowdsec_tests") + t.Cleanup(func() { os.RemoveAll(tempDir) }) dbconfig := csconfig.DatabaseCfg{ @@ -61,14 +69,27 @@ func LoadTestConfig(t *testing.T) csconfig.Config { ShareTaintedScenarios: new(bool), ShareCustomScenarios: new(bool), }, + AutoRegister: &csconfig.LocalAPIAutoRegisterCfg{ + Enable: ptr.Of(true), + Token: validRegistrationToken, + AllowedRanges: []string{ + "127.0.0.1/8", + "::1/128", + }, + }, } + apiConfig := csconfig.APICfg{ Server: &apiServerConfig, } + config.API = &apiConfig - if err := config.API.Server.LoadProfiles(); err != nil { - log.Fatalf("failed to load profiles: %s", err) - } + err := config.API.Server.LoadProfiles() + require.NoError(t, err) + + err = config.API.Server.LoadAutoRegister() + require.NoError(t, err) + return config } @@ -80,6 +101,7 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config { } tempDir, _ := os.MkdirTemp("", "crowdsec_tests") + t.Cleanup(func() { os.RemoveAll(tempDir) }) dbconfig := csconfig.DatabaseCfg{ @@ -103,100 +125,94 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config { Server: &apiServerConfig, } config.API = &apiConfig - if err := config.API.Server.LoadProfiles(); err != nil { - log.Fatalf("failed to load profiles: %s", err) - } + err := config.API.Server.LoadProfiles() + require.NoError(t, err) + + err = config.API.Server.LoadAutoRegister() + require.NoError(t, err) + return config } -func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config, error) { +func NewAPIServer(t *testing.T, ctx context.Context) (*APIServer, csconfig.Config) { config := LoadTestConfig(t) + os.Remove("./ent") - apiServer, err := NewServer(config.API.Server) - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + + apiServer, err := NewServer(ctx, config.API.Server) + require.NoError(t, err) + log.Printf("Creating new API server") gin.SetMode(gin.TestMode) - return apiServer, config, nil + + return apiServer, config } -func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config, error) { - apiServer, config, err := NewAPIServer(t) - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } - err = apiServer.InitController() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } +func NewAPITest(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) { + apiServer, config := NewAPIServer(t, ctx) + + err := apiServer.InitController() + require.NoError(t, err) + router, err := apiServer.Router() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } - return router, config, nil + require.NoError(t, err) + + return router, config } -func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config, error) { +func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) { config := LoadTestConfigForwardedFor(t) os.Remove("./ent") - apiServer, err := NewServer(config.API.Server) - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + + 
apiServer, err := NewServer(ctx, config.API.Server) + require.NoError(t, err) + err = apiServer.InitController() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + require.NoError(t, err) + log.Printf("Creating new API server") gin.SetMode(gin.TestMode) + router, err := apiServer.Router() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } - return router, config, nil + require.NoError(t, err) + + return router, config } -func ValidateMachine(machineID string, config *csconfig.DatabaseCfg) error { - dbClient, err := database.NewClient(config) - if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) - } - if err := dbClient.ValidateMachine(machineID); err != nil { - return fmt.Errorf("unable to validate machine: %s", err) - } - return nil +func ValidateMachine(t *testing.T, ctx context.Context, machineID string, config *csconfig.DatabaseCfg) { + dbClient, err := database.NewClient(ctx, config) + require.NoError(t, err) + + err = dbClient.ValidateMachine(ctx, machineID) + require.NoError(t, err) } -func GetMachineIP(machineID string, config *csconfig.DatabaseCfg) (string, error) { - dbClient, err := database.NewClient(config) - if err != nil { - return "", fmt.Errorf("unable to create new database client: %s", err) - } - machines, err := dbClient.ListMachines() - if err != nil { - return "", fmt.Errorf("Unable to list machines: %s", err) - } +func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string { + ctx := context.Background() + + dbClient, err := database.NewClient(ctx, config) + require.NoError(t, err) + + machines, err := dbClient.ListMachines(ctx) + require.NoError(t, err) + for _, machine := range machines { if machine.MachineId == machineID { - return machine.IpAddress, nil + return machine.IpAddress } } - return "", nil -} -func GetAlertReaderFromFile(path string) *strings.Reader { + return "" +} +func GetAlertReaderFromFile(t *testing.T, path string) *strings.Reader { alertContentBytes, err := os.ReadFile(path) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) alerts := make([]*models.Alert, 0) - if err := json.Unmarshal(alertContentBytes, &alerts); err != nil { - log.Fatal(err) - } + err = json.Unmarshal(alertContentBytes, &alerts) + require.NoError(t, err) for _, alert := range alerts { *alert.StartAt = time.Now().UTC().Format(time.RFC3339) @@ -204,124 +220,113 @@ func GetAlertReaderFromFile(path string) *strings.Reader { } alertContent, err := json.Marshal(alerts) - if err != nil { - log.Fatal(err) - } - return strings.NewReader(string(alertContent)) + require.NoError(t, err) + return strings.NewReader(string(alertContent)) } -func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) { +func readDecisionsGetResp(t *testing.T, resp *httptest.ResponseRecorder) ([]*models.Decision, int) { var response []*models.Decision - if resp == nil { - return nil, 0, errors.New("response is nil") - } + + require.NotNil(t, resp) + err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } - return response, resp.Code, nil + require.NoError(t, err) + + return response, resp.Code } -func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) { +func readDecisionsErrorResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string]string, int) { var response map[string]string - if resp == nil { - return nil, 0, errors.New("response is 
nil") - } + + require.NotNil(t, resp) + err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } - return response, resp.Code, nil + require.NoError(t, err) + + return response, resp.Code } -func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) { +func readDecisionsDeleteResp(t *testing.T, resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int) { var response models.DeleteDecisionResponse - if resp == nil { - return nil, 0, errors.New("response is nil") - } + + require.NotNil(t, resp) err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } - return &response, resp.Code, nil + require.NoError(t, err) + + return &response, resp.Code } -func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) { +func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int) { response := make(map[string][]*models.Decision) - if resp == nil { - return nil, 0, errors.New("response is nil") - } + + require.NotNil(t, resp) err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } - return response, resp.Code, nil + require.NoError(t, err) + + return response, resp.Code } -func CreateTestMachine(router *gin.Engine) (string, error) { - b, err := json.Marshal(MachineTest) - if err != nil { - return "", fmt.Errorf("unable to marshal MachineTest") - } +func CreateTestMachine(t *testing.T, ctx context.Context, router *gin.Engine, token string) string { + regReq := MachineTest + regReq.RegistrationToken = token + b, err := json.Marshal(regReq) + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) - return body, nil + + return body } -func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) { - dbClient, err := database.NewClient(config) - if err != nil { - log.Fatalf("unable to create new database client: %s", err) - } +func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.DatabaseCfg) string { + dbClient, err := database.NewClient(ctx, config) + require.NoError(t, err) + apiKey, err := middlewares.GenerateAPIKey(keyLength) - if err != nil { - return "", fmt.Errorf("unable to generate api key: %s", err) - } - _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) - if err != nil { - return "", fmt.Errorf("unable to create blocker: %s", err) - } + require.NoError(t, err) + + _, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) + require.NoError(t, err) - return apiKey, nil + return apiKey } func TestWithWrongDBConfig(t *testing.T) { + ctx := context.Background() config := LoadTestConfig(t) config.API.Server.DbConfig.Type = "test" - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) - assert.Equal(t, apiServer, &APIServer{}) - assert.Equal(t, "unable to init database client: unknown database type 'test'", err.Error()) + cstest.RequireErrorContains(t, err, "unable to init database client: unknown database type 'test'") + assert.Nil(t, apiServer) } func 
TestWithWrongFlushConfig(t *testing.T) { + ctx := context.Background() config := LoadTestConfig(t) maxItems := -1 config.API.Server.DbConfig.Flush.MaxItems = &maxItems - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) - assert.Equal(t, apiServer, &APIServer{}) - assert.Equal(t, "max_items can't be zero or negative number", err.Error()) + cstest.RequireErrorContains(t, err, "max_items can't be zero or negative") + assert.Nil(t, apiServer) } func TestUnknownPath(t *testing.T) { - router, _, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + ctx := context.Background() + router, _ := NewAPITest(t, ctx) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test", nil) req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) - + assert.Equal(t, http.StatusNotFound, w.Code) } /* @@ -340,6 +345,8 @@ ListenURI string `yaml:"listen_uri,omitempty"` //127.0 */ func TestLoggingDebugToFileConfig(t *testing.T) { + ctx := context.Background() + /*declare settings*/ maxAge := "1h" flushConfig := csconfig.FlushDBCfg{ @@ -347,6 +354,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) { } tempDir, _ := os.MkdirTemp("", "crowdsec_tests") + t.Cleanup(func() { os.RemoveAll(tempDir) }) dbconfig := csconfig.DatabaseCfg{ @@ -360,46 +368,37 @@ func TestLoggingDebugToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - lvl := log.DebugLevel - expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) + expectedFile := filepath.Join(tempDir, "crowdsec_api.log") expectedLines := []string{"/test42"} - cfg.LogLevel = &lvl + cfg.LogLevel = ptr.Of(log.DebugLevel) // Configure logging - if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil { - t.Fatal(err) - } - api, err := NewServer(&cfg) - if err != nil { - t.Fatalf("failed to create api : %s", err) - } - if api == nil { - t.Fatalf("failed to create api #2 is nbill") - } + err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) + require.NoError(t, err) + + api, err := NewServer(ctx, &cfg) + require.NoError(t, err) + require.NotNil(t, api) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test42", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) - //wait for the request to happen + assert.Equal(t, http.StatusNotFound, w.Code) + // wait for the request to happen time.Sleep(500 * time.Millisecond) - //check file content + // check file content data, err := os.ReadFile(expectedFile) - if err != nil { - t.Fatalf("failed to read file : %s", err) - } + require.NoError(t, err) for _, expectedStr := range expectedLines { - if !strings.Contains(string(data), expectedStr) { - t.Fatalf("expected %s in %s", expectedStr, string(data)) - } + assert.Contains(t, string(data), expectedStr) } - } func TestLoggingErrorToFileConfig(t *testing.T) { + ctx := context.Background() /*declare settings*/ maxAge := "1h" @@ -408,6 +407,7 @@ func TestLoggingErrorToFileConfig(t *testing.T) { } tempDir, _ := os.MkdirTemp("", "crowdsec_tests") + t.Cleanup(func() { os.RemoveAll(tempDir) }) dbconfig := 
csconfig.DatabaseCfg{ @@ -421,34 +421,29 @@ func TestLoggingErrorToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - lvl := log.ErrorLevel - expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) - cfg.LogLevel = &lvl + expectedFile := filepath.Join(tempDir, "crowdsec_api.log") + cfg.LogLevel = ptr.Of(log.ErrorLevel) // Configure logging - if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil { - t.Fatal(err) - } - api, err := NewServer(&cfg) - if err != nil { - t.Fatalf("failed to create api : %s", err) - } - if api == nil { - t.Fatalf("failed to create api #2 is nbill") - } + err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) + require.NoError(t, err) + + api, err := NewServer(ctx, &cfg) + require.NoError(t, err) + require.NotNil(t, api) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test42", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) - //wait for the request to happen + assert.Equal(t, http.StatusNotFound, w.Code) + // wait for the request to happen time.Sleep(500 * time.Millisecond) - //check file content + // check file content x, err := os.ReadFile(expectedFile) - if err == nil && len(x) > 0 { - t.Fatalf("file should be empty, got '%s'", x) + if err == nil { + require.Empty(t, x) } os.Remove("./crowdsec.log") diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go index e0a1656e792..719bb231006 100644 --- a/pkg/apiserver/controllers/controller.go +++ b/pkg/apiserver/controllers/controller.go @@ -1,22 +1,22 @@ package controllers import ( - "context" "net" "net/http" + "strings" "github.com/alexliesenfeld/health" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" ) type Controller struct { - Ectx context.Context DBClient *database.Client Router *gin.Engine Profiles []*csconfig.ProfileCfg @@ -27,6 +27,7 @@ type Controller struct { ConsoleConfig *csconfig.ConsoleConfig TrustedIPs []net.IPNet HandlerV1 *v1.Controller + AutoRegisterCfg *csconfig.LocalAPIAutoRegisterCfg DisableRemoteLapiRegistration bool } @@ -54,27 +55,46 @@ func serveHealth() http.HandlerFunc { // no caching required health.WithDisabledCache(), ) + return health.NewHandler(checker) } +func eitherAuthMiddleware(jwtMiddleware gin.HandlerFunc, apiKeyMiddleware gin.HandlerFunc) gin.HandlerFunc { + return func(c *gin.Context) { + switch { + case c.GetHeader("X-Api-Key") != "": + apiKeyMiddleware(c) + case c.GetHeader("Authorization") != "": + jwtMiddleware(c) + // no auth header; this may be TLS with mutual authentication
+ case strings.HasPrefix(c.Request.UserAgent(), "crowdsec/"): + // guess log processors by sniffing user-agent + jwtMiddleware(c) + default: + apiKeyMiddleware(c) + } + } +} + func (c *Controller) NewV1() error { var err error v1Config := v1.ControllerV1Config{ DbClient: c.DBClient, - Ctx: c.Ectx, ProfilesCfg: c.Profiles, DecisionDeleteChan: c.DecisionDeleteChan, AlertsAddChan: c.AlertsAddChan, PluginChannel: c.PluginChannel, ConsoleConfig: *c.ConsoleConfig, TrustedIPs: c.TrustedIPs, + AutoRegisterCfg: c.AutoRegisterCfg, } c.HandlerV1, err = v1.New(&v1Config) if err != nil { return err } + c.Router.GET("/health", gin.WrapF(serveHealth())) c.Router.Use(v1.PrometheusMiddleware()) c.Router.HandleMethodNotAllowed = true @@ -103,7 +123,6 @@ func (c *Controller) NewV1() error { jwtAuth.DELETE("/decisions", c.HandlerV1.DeleteDecisions) jwtAuth.DELETE("/decisions/:decision_id", c.HandlerV1.DeleteDecisionById) jwtAuth.GET("/heartbeat", c.HandlerV1.HeartBeat) - } apiKeyAuth := groupV1.Group("") @@ -115,6 +134,12 @@ func (c *Controller) NewV1() error { apiKeyAuth.HEAD("/decisions/stream", c.HandlerV1.StreamDecision) } + eitherAuth := groupV1.Group("") + eitherAuth.Use(eitherAuthMiddleware(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), c.HandlerV1.Middlewares.APIKey.MiddlewareFunc())) + { + eitherAuth.POST("/usage-metrics", c.HandlerV1.UsageMetrics) + } + return nil } diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index 66d19288d74..d1f93228512 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -6,23 +6,20 @@ import ( "net" "net/http" "strconv" - "strings" "time" - jwt "github.com/appleboy/gin-jwt/v2" + "github.com/gin-gonic/gin" + "github.com/go-openapi/strfmt" "github.com/google/uuid" + log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/gin-gonic/gin" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" ) func FormatOneAlert(alert *ent.Alert) *models.Alert { - var outputAlert models.Alert startAt := alert.StartedAt.String() StopAt := alert.StoppedAt.String() @@ -31,7 +28,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { machineID = alert.Edges.Owner.MachineId } - outputAlert = models.Alert{ + outputAlert := models.Alert{ ID: int64(alert.ID), MachineID: machineID, CreatedAt: alert.CreatedAt.Format(time.RFC3339), @@ -45,6 +42,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { Capacity: &alert.Capacity, Leakspeed: &alert.LeakSpeed, Simulated: &alert.Simulated, + Remediation: alert.Remediation, UUID: alert.UUID, Source: &models.Source{ Scope: &alert.SourceScope, @@ -58,25 +56,31 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { Longitude: alert.SourceLongitude, }, } + for _, eventItem := range alert.Edges.Events { - var Metas models.Meta timestamp := eventItem.Time.String() + + var Metas models.Meta + if err := json.Unmarshal([]byte(eventItem.Serialized), &Metas); err != nil { - log.Errorf("unable to unmarshall events meta '%s' : %s", eventItem.Serialized, err) + log.Errorf("unable to parse events meta '%s' : %s", eventItem.Serialized, err) } + outputAlert.Events = append(outputAlert.Events, &models.Event{ Timestamp: ×tamp, Meta: Metas, }) } + for _, metaItem := range alert.Edges.Metas { outputAlert.Meta = append(outputAlert.Meta, &models.MetaItems0{ Key: metaItem.Key, 
Value: metaItem.Value, }) } + for _, decisionItem := range alert.Edges.Decisions { - duration := decisionItem.Until.Sub(time.Now().UTC()).String() + duration := decisionItem.Until.Sub(time.Now().UTC()).Round(time.Second).String() outputAlert.Decisions = append(outputAlert.Decisions, &models.Decision{ Duration: &duration, // transform into time.Time ? Scenario: &decisionItem.Scenario, @@ -88,6 +92,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { ID: int64(decisionItem.ID), }) } + return &outputAlert } @@ -97,16 +102,18 @@ func FormatAlerts(result []*ent.Alert) models.AddAlertsRequest { for _, alertItem := range result { data = append(data, FormatOneAlert(alertItem)) } + return data } func (c *Controller) sendAlertToPluginChannel(alert *models.Alert, profileID uint) { if c.PluginChannel != nil { RETRY: - for try := 0; try < 3; try++ { + for try := range 3 { select { case c.PluginChannel <- csplugin.ProfileAlert{ProfileID: profileID, Alert: alert}: log.Debugf("alert sent to Plugin channel") + break RETRY default: log.Warningf("Cannot send alert to Plugin channel (try: %d)", try) @@ -116,88 +123,84 @@ func (c *Controller) sendAlertToPluginChannel(alert *models.Alert, profileID uin } } -func normalizeScope(scope string) string { - switch strings.ToLower(scope) { - case "ip": - return types.Ip - case "range": - return types.Range - case "as": - return types.AS - case "country": - return types.Country - default: - return scope - } -} - // CreateAlert writes the alerts received in the body to the database func (c *Controller) CreateAlert(gctx *gin.Context) { - var input models.AddAlertsRequest - claims := jwt.ExtractClaims(gctx) - // TBD: use defined rather than hardcoded key to find back owner - machineID := claims["id"].(string) + ctx := gctx.Request.Context() + machineID, _ := getMachineIDFromContext(gctx) if err := gctx.ShouldBindJSON(&input); err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return } + if err := input.Validate(strfmt.Default); err != nil { c.HandleDBErrors(gctx, err) return } + stopFlush := false + for _, alert := range input { - //normalize scope for alert.Source and decisions + // normalize scope for alert.Source and decisions if alert.Source.Scope != nil { - *alert.Source.Scope = normalizeScope(*alert.Source.Scope) + *alert.Source.Scope = types.NormalizeScope(*alert.Source.Scope) } + for _, decision := range alert.Decisions { if decision.Scope != nil { - *decision.Scope = normalizeScope(*decision.Scope) + *decision.Scope = types.NormalizeScope(*decision.Scope) } } alert.MachineID = machineID - //generate uuid here for alert + // generate uuid here for alert alert.UUID = uuid.NewString() - //if coming from cscli, alert already has decisions + // if coming from cscli, alert already has decisions if len(alert.Decisions) != 0 { - //alert already has a decision (cscli decisions add etc.), generate uuid here + // alert already has a decision (cscli decisions add etc.), generate uuid here for _, decision := range alert.Decisions { decision.UUID = uuid.NewString() } + for pIdx, profile := range c.Profiles { _, matched, err := profile.EvaluateProfile(alert) if err != nil { profile.Logger.Warningf("error while evaluating profile %s : %v", profile.Cfg.Name, err) + continue } + if !matched { continue } + c.sendAlertToPluginChannel(alert, uint(pIdx)) + if profile.Cfg.OnSuccess == "break" { break } } + decision := alert.Decisions[0] if decision.Origin != nil && *decision.Origin == types.CscliImportOrigin { stopFlush = true } + continue } for pIdx, 
profile := range c.Profiles { profileDecisions, matched, err := profile.EvaluateProfile(alert) forceBreak := false + if err != nil { switch profile.Cfg.OnError { case "apply": profile.Logger.Warningf("applying profile %s despite error: %s", profile.Cfg.Name, err) + matched = true case "continue": profile.Logger.Warningf("skipping %s profile due to error: %s", profile.Cfg.Name, err) @@ -210,18 +213,23 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { return } } + if !matched { continue } + for _, decision := range profileDecisions { decision.UUID = uuid.NewString() } - //generate uuid here for alert + + // generate uuid here for alert if len(alert.Decisions) == 0 { // non manual decision alert.Decisions = append(alert.Decisions, profileDecisions...) } + profileAlert := *alert c.sendAlertToPluginChannel(&profileAlert, uint(pIdx)) + if profile.Cfg.OnSuccess == "break" || forceBreak { break } @@ -232,7 +240,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { c.DBClient.CanFlush = false } - alerts, err := c.DBClient.CreateAlert(machineID, input) + alerts, err := c.DBClient.CreateAlert(ctx, machineID, input) c.DBClient.CanFlush = true if err != nil { @@ -254,7 +262,9 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { // FindAlerts: returns alerts from the database based on the specified filter func (c *Controller) FindAlerts(gctx *gin.Context) { - result, err := c.DBClient.QueryAlertWithFilter(gctx.Request.URL.Query()) + ctx := gctx.Request.Context() + + result, err := c.DBClient.QueryAlertWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return @@ -266,28 +276,34 @@ func (c *Controller) FindAlerts(gctx *gin.Context) { gctx.String(http.StatusOK, "") return } + gctx.JSON(http.StatusOK, data) } // FindAlertByID returns the alert associated with the ID func (c *Controller) FindAlertByID(gctx *gin.Context) { + ctx := gctx.Request.Context() alertIDStr := gctx.Param("alert_id") + alertID, err := strconv.Atoi(alertIDStr) if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) return } - result, err := c.DBClient.GetAlertByID(alertID) + + result, err := c.DBClient.GetAlertByID(ctx, alertID) if err != nil { c.HandleDBErrors(gctx, err) return } + data := FormatOneAlert(result) if gctx.Request.Method == http.MethodHead { gctx.String(http.StatusOK, "") return } + gctx.JSON(http.StatusOK, data) } @@ -295,47 +311,53 @@ func (c *Controller) FindAlertByID(gctx *gin.Context) { func (c *Controller) DeleteAlertByID(gctx *gin.Context) { var err error + ctx := gctx.Request.Context() + incomingIP := gctx.ClientIP() - if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) { + if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) return } decisionIDStr := gctx.Param("alert_id") + decisionID, err := strconv.Atoi(decisionIDStr) if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) return } - err = c.DBClient.DeleteAlertByID(decisionID) + + err = c.DBClient.DeleteAlertByID(ctx, decisionID) if err != nil { c.HandleDBErrors(gctx, err) return } - deleteAlertResp := models.DeleteAlertsResponse{ - NbDeleted: "1", - } + deleteAlertResp := models.DeleteAlertsResponse{NbDeleted: "1"} gctx.JSON(http.StatusOK, deleteAlertResp) } // DeleteAlerts deletes alerts 
from the database based on the specified filter func (c *Controller) DeleteAlerts(gctx *gin.Context) { + ctx := gctx.Request.Context() + incomingIP := gctx.ClientIP() - if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) { + if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) return } - var err error - nbDeleted, err := c.DBClient.DeleteAlertWithFilter(gctx.Request.URL.Query()) + + nbDeleted, err := c.DBClient.DeleteAlertWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return } + deleteAlertsResp := models.DeleteAlertsResponse{ NbDeleted: strconv.Itoa(nbDeleted), } + gctx.JSON(http.StatusOK, deleteAlertsResp) } @@ -346,5 +368,6 @@ func networksContainIP(networks []net.IPNet, ip string) bool { return true } } + return false } diff --git a/pkg/apiserver/controllers/v1/controller.go b/pkg/apiserver/controllers/v1/controller.go index 60da83d7dcb..f8b6aa76ea5 100644 --- a/pkg/apiserver/controllers/v1/controller.go +++ b/pkg/apiserver/controllers/v1/controller.go @@ -1,7 +1,6 @@ package v1 import ( - "context" "fmt" "net" @@ -14,7 +13,6 @@ import ( ) type Controller struct { - Ectx context.Context DBClient *database.Client APIKeyHeader string Middlewares *middlewares.Middlewares @@ -23,22 +21,23 @@ type Controller struct { AlertsAddChan chan []*models.Alert DecisionDeleteChan chan []*models.Decision - PluginChannel chan csplugin.ProfileAlert - ConsoleConfig csconfig.ConsoleConfig - TrustedIPs []net.IPNet + PluginChannel chan csplugin.ProfileAlert + ConsoleConfig csconfig.ConsoleConfig + TrustedIPs []net.IPNet + AutoRegisterCfg *csconfig.LocalAPIAutoRegisterCfg } type ControllerV1Config struct { DbClient *database.Client - Ctx context.Context ProfilesCfg []*csconfig.ProfileCfg AlertsAddChan chan []*models.Alert DecisionDeleteChan chan []*models.Decision - PluginChannel chan csplugin.ProfileAlert - ConsoleConfig csconfig.ConsoleConfig - TrustedIPs []net.IPNet + PluginChannel chan csplugin.ProfileAlert + ConsoleConfig csconfig.ConsoleConfig + TrustedIPs []net.IPNet + AutoRegisterCfg *csconfig.LocalAPIAutoRegisterCfg } func New(cfg *ControllerV1Config) (*Controller, error) { @@ -50,7 +49,6 @@ func New(cfg *ControllerV1Config) (*Controller, error) { } v1 := &Controller{ - Ectx: cfg.Ctx, DBClient: cfg.DbClient, APIKeyHeader: middlewares.APIKeyHeader, Profiles: profiles, @@ -59,10 +57,13 @@ func New(cfg *ControllerV1Config) (*Controller, error) { PluginChannel: cfg.PluginChannel, ConsoleConfig: cfg.ConsoleConfig, TrustedIPs: cfg.TrustedIPs, + AutoRegisterCfg: cfg.AutoRegisterCfg, } + v1.Middlewares, err = middlewares.NewMiddlewares(cfg.DbClient) if err != nil { return v1, err } + return v1, nil } diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 8ea3798731a..ffefffc226b 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -1,17 +1,18 @@ package v1 import ( + "context" "encoding/json" - "fmt" "net/http" "strconv" "time" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/fflag" "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" ) // Format decisions for the bouncers @@ 
-19,7 +20,7 @@ func FormatDecisions(decisions []*ent.Decision) []*models.Decision { var results []*models.Decision for _, dbDecision := range decisions { - duration := dbDecision.Until.Sub(time.Now().UTC()).String() + duration := dbDecision.Until.Sub(time.Now().UTC()).Round(time.Second).String() decision := models.Decision{ ID: int64(dbDecision.ID), Duration: &duration, @@ -32,23 +33,29 @@ func FormatDecisions(decisions []*ent.Decision) []*models.Decision { } results = append(results, &decision) } + return results } func (c *Controller) GetDecision(gctx *gin.Context) { - var err error - var results []*models.Decision - var data []*ent.Decision + var ( + results []*models.Decision + data []*ent.Decision + ) + + ctx := gctx.Request.Context() bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) + return } - data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.URL.Query()) + data, err = c.DBClient.QueryDecisionWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) + return } @@ -63,11 +70,12 @@ func (c *Controller) GetDecision(gctx *gin.Context) { if gctx.Request.Method == http.MethodHead { gctx.String(http.StatusOK, "") + return } - if time.Now().UTC().Sub(bouncerInfo.LastPull) >= time.Minute { - if err := c.DBClient.UpdateBouncerLastPull(time.Now().UTC(), bouncerInfo.ID); err != nil { + if bouncerInfo.LastPull == nil || time.Now().UTC().Sub(*bouncerInfo.LastPull) >= time.Minute { + if err := c.DBClient.UpdateBouncerLastPull(ctx, time.Now().UTC(), bouncerInfo.ID); err != nil { log.Errorf("failed to update bouncer last pull: %v", err) } } @@ -76,20 +84,25 @@ func (c *Controller) GetDecision(gctx *gin.Context) { } func (c *Controller) DeleteDecisionById(gctx *gin.Context) { - var err error - decisionIDStr := gctx.Param("decision_id") + decisionID, err := strconv.Atoi(decisionIDStr) if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "decision_id must be valid integer"}) + return } - nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionByID(decisionID) + + ctx := gctx.Request.Context() + + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(ctx, decisionID) if err != nil { c.HandleDBErrors(gctx, err) + return } - //transform deleted decisions to be sendable to capi + + // transform deleted decisions to be sendable to capi deletedDecisions := FormatDecisions(deletedFromDB) if c.DecisionDeleteChan != nil { @@ -104,13 +117,16 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { } func (c *Controller) DeleteDecisions(gctx *gin.Context) { - var err error - nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionsWithFilter(gctx.Request.URL.Query()) + ctx := gctx.Request.Context() + + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) + return } - //transform deleted decisions to be sendable to capi + + // transform deleted decisions to be sendable to capi deletedDecisions := FormatDecisions(deletedFromDB) if c.DecisionDeleteChan != nil { @@ -120,35 +136,42 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) { deleteDecisionResp := models.DeleteDecisionResponse{ NbDeleted: nbDeleted, } + gctx.JSON(http.StatusOK, deleteDecisionResp) } -func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(map[string][]string) ([]*ent.Decision, error)) error { +func writeStartupDecisions(gctx *gin.Context, filters 
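// Illustrative sketch, not part of the patch: why FormatDecisions now calls
// Round(time.Second) before String(). Without rounding, the serialized
// duration carries sub-second noise that only adds churn for bouncers.
package main

import (
	"fmt"
	"time"
)

func main() {
	until := time.Now().UTC().Add(4*time.Hour - 123*time.Millisecond)
	d := until.Sub(time.Now().UTC())

	fmt.Println(d.String())                    // e.g. "3h59m59.876953042s"
	fmt.Println(d.Round(time.Second).String()) // "4h0m0s"
}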
map[string][]string, dbFunc func(context.Context, map[string][]string) ([]*ent.Decision, error)) error { // respBuffer := bytes.NewBuffer([]byte{}) - limit := 30000 //FIXME : make it configurable + limit := 30000 // FIXME : make it configurable needComma := false lastId := 0 - limitStr := fmt.Sprintf("%d", limit) + ctx := gctx.Request.Context() + + limitStr := strconv.Itoa(limit) filters["limit"] = []string{limitStr} + for { if lastId > 0 { - lastIdStr := fmt.Sprintf("%d", lastId) + lastIdStr := strconv.Itoa(lastId) filters["id_gt"] = []string{lastIdStr} } - data, err := dbFunc(filters) + data, err := dbFunc(ctx, filters) if err != nil { return err } + if len(data) > 0 { lastId = data[len(data)-1].ID + results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) + if needComma { - //respBuffer.Write([]byte(",")) - gctx.Writer.Write([]byte(",")) + // respBuffer.Write([]byte(",")) + gctx.Writer.WriteString(",") } else { needComma = true } @@ -157,46 +180,57 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun _, err := gctx.Writer.Write(decisionJSON) if err != nil { gctx.Writer.Flush() + return err } - //respBuffer.Reset() + // respBuffer.Reset() } } + log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) + if len(data) < limit { gctx.Writer.Flush() + break } } + return nil } -func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull time.Time, dbFunc func(time.Time, map[string][]string) ([]*ent.Decision, error)) error { - //respBuffer := bytes.NewBuffer([]byte{}) - limit := 30000 //FIXME : make it configurable +func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull *time.Time, dbFunc func(context.Context, *time.Time, map[string][]string) ([]*ent.Decision, error)) error { + // respBuffer := bytes.NewBuffer([]byte{}) + limit := 30000 // FIXME : make it configurable needComma := false lastId := 0 - limitStr := fmt.Sprintf("%d", limit) + ctx := gctx.Request.Context() + + limitStr := strconv.Itoa(limit) filters["limit"] = []string{limitStr} + for { if lastId > 0 { - lastIdStr := fmt.Sprintf("%d", lastId) + lastIdStr := strconv.Itoa(lastId) filters["id_gt"] = []string{lastIdStr} } - data, err := dbFunc(lastPull, filters) + data, err := dbFunc(ctx, lastPull, filters) if err != nil { return err } + if len(data) > 0 { lastId = data[len(data)-1].ID + results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) + if needComma { - //respBuffer.Write([]byte(",")) - gctx.Writer.Write([]byte(",")) + // respBuffer.Write([]byte(",")) + gctx.Writer.WriteString(",") } else { needComma = true } @@ -205,17 +239,22 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul _, err := gctx.Writer.Write(decisionJSON) if err != nil { gctx.Writer.Flush() + return err } - //respBuffer.Reset() + // respBuffer.Reset() } } + log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) + if len(data) < limit { gctx.Writer.Flush() + break } } + return nil } @@ -225,127 +264,152 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B gctx.Writer.Header().Set("Content-Type", "application/json") gctx.Writer.Header().Set("Transfer-Encoding", "chunked") gctx.Writer.WriteHeader(http.StatusOK) - gctx.Writer.Write([]byte(`{"new": [`)) //No need to check for errors, the doc says it always returns nil + 
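// Illustrative sketch, not part of the patch: writeStartupDecisions and
// writeDeltaDecisions stream one JSON object, {"new": [...], "deleted": [...]},
// assembled piecewise while paginating the database with the limit/id_gt
// filters set above. A bouncer can decode the chunked body like this:
package main

import (
	"encoding/json"
	"fmt"
)

type streamResponse struct {
	New     []json.RawMessage `json:"new"`
	Deleted []json.RawMessage `json:"deleted"`
}

func main() {
	body := []byte(`{"new": [{"id": 1}], "deleted": []}`)

	var sr streamResponse
	if err := json.Unmarshal(body, &sr); err != nil {
		panic(err)
	}

	fmt.Printf("%d new, %d deleted\n", len(sr.New), len(sr.Deleted))
}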
gctx.Writer.WriteString(`{"new": [`) // No need to check for errors, the doc says it always returns nil // if the blocker just started, return all decisions if val, ok := gctx.Request.URL.Query()["startup"]; ok && val[0] == "true" { - //Active decisions - + // Active decisions err := writeStartupDecisions(gctx, filters, c.DBClient.QueryAllDecisionsWithFilters) - if err != nil { log.Errorf("failed sending new decisions for startup: %v", err) - gctx.Writer.Write([]byte(`], "deleted": []}`)) + gctx.Writer.WriteString(`], "deleted": []}`) gctx.Writer.Flush() + return err } - gctx.Writer.Write([]byte(`], "deleted": [`)) - //Expired decisions + gctx.Writer.WriteString(`], "deleted": [`) + // Expired decisions err = writeStartupDecisions(gctx, filters, c.DBClient.QueryExpiredDecisionsWithFilters) if err != nil { log.Errorf("failed sending expired decisions for startup: %v", err) - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString(`]}`) gctx.Writer.Flush() + return err } - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString(`]}`) gctx.Writer.Flush() } else { err = writeDeltaDecisions(gctx, filters, bouncerInfo.LastPull, c.DBClient.QueryNewDecisionsSinceWithFilters) if err != nil { log.Errorf("failed sending new decisions for delta: %v", err) - gctx.Writer.Write([]byte(`], "deleted": []}`)) + gctx.Writer.WriteString(`], "deleted": []}`) gctx.Writer.Flush() + return err } - gctx.Writer.Write([]byte(`], "deleted": [`)) + gctx.Writer.WriteString(`], "deleted": [`) err = writeDeltaDecisions(gctx, filters, bouncerInfo.LastPull, c.DBClient.QueryExpiredDecisionsSinceWithFilters) - if err != nil { log.Errorf("failed sending expired decisions for delta: %v", err) - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString("]}") gctx.Writer.Flush() + return err } - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString("]}") gctx.Writer.Flush() } + return nil } func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error { - var data []*ent.Decision - var err error + var ( + data []*ent.Decision + err error + ) + + ctx := gctx.Request.Context() + ret := make(map[string][]*models.Decision, 0) ret["new"] = []*models.Decision{} ret["deleted"] = []*models.Decision{} if val, ok := gctx.Request.URL.Query()["startup"]; ok { if val[0] == "true" { - data, err = c.DBClient.QueryAllDecisionsWithFilters(filters) + data, err = c.DBClient.QueryAllDecisionsWithFilters(ctx, filters) if err != nil { log.Errorf("failed querying decisions: %v", err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return err } - //data = KeepLongestDecision(data) + // data = KeepLongestDecision(data) ret["new"] = FormatDecisions(data) // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisionsWithFilters(filters) + data, err = c.DBClient.QueryExpiredDecisionsWithFilters(ctx, filters) if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return err } + ret["deleted"] = FormatDecisions(data) gctx.JSON(http.StatusOK, ret) + return nil } } // getting new decisions - data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(bouncerInfo.LastPull, filters) + data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(ctx, bouncerInfo.LastPull, filters) if err != nil { log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, 
gin.H{"message": err.Error()}) + return err } - //data = KeepLongestDecision(data) + // data = KeepLongestDecision(data) ret["new"] = FormatDecisions(data) + since := time.Time{} + if bouncerInfo.LastPull != nil { + since = bouncerInfo.LastPull.Add(-2 * time.Second) + } + // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(bouncerInfo.LastPull.Add((-2 * time.Second)), filters) // do we want to give exactly lastPull time ? + data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(ctx, &since, filters) // do we want to give exactly lastPull time ? if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return err } + ret["deleted"] = FormatDecisions(data) gctx.JSON(http.StatusOK, ret) + return nil } func (c *Controller) StreamDecision(gctx *gin.Context) { var err error + ctx := gctx.Request.Context() + streamStartTime := time.Now().UTC() + bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) + return } if gctx.Request.Method == http.MethodHead { - //For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db - //We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true) + // For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db + // We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true) gctx.String(http.StatusOK, "") + return } @@ -361,8 +425,8 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { } if err == nil { - //Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions - if err := c.DBClient.UpdateBouncerLastPull(streamStartTime, bouncerInfo.ID); err != nil { + // Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions + if err := c.DBClient.UpdateBouncerLastPull(ctx, streamStartTime, bouncerInfo.ID); err != nil { log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err) } } diff --git a/pkg/apiserver/controllers/v1/errors.go b/pkg/apiserver/controllers/v1/errors.go index 5edf0d6bfbf..d661de44b0e 100644 --- a/pkg/apiserver/controllers/v1/errors.go +++ b/pkg/apiserver/controllers/v1/errors.go @@ -1,34 +1,36 @@ package v1 import ( + "errors" "net/http" + "strings" - "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/gin-gonic/gin" - "github.com/pkg/errors" + + "github.com/crowdsecurity/crowdsec/pkg/database" ) func (c *Controller) HandleDBErrors(gctx *gin.Context, err error) { - switch errors.Cause(err) { - case database.ItemNotFound: + switch { + case errors.Is(err, database.ItemNotFound): gctx.JSON(http.StatusNotFound, gin.H{"message": err.Error()}) return - case database.UserExists: + case errors.Is(err, database.UserExists): gctx.JSON(http.StatusForbidden, gin.H{"message": err.Error()}) return - case database.HashError: + case errors.Is(err, database.HashError): gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return - case database.InsertFail: + case errors.Is(err, database.InsertFail): gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return - case database.QueryFail: + case errors.Is(err, database.QueryFail): gctx.JSON(http.StatusInternalServerError, gin.H{"message": 
err.Error()}) return - case database.ParseTimeFail: + case errors.Is(err, database.ParseTimeFail): gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return - case database.ParseDurationFail: + case errors.Is(err, database.ParseDurationFail): gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return default: @@ -36,3 +38,32 @@ func (c *Controller) HandleDBErrors(gctx *gin.Context, err error) { return } } + +// collapseRepeatedPrefix collapses repeated occurrences of a given prefix in the text +func collapseRepeatedPrefix(text string, prefix string) string { + count := 0 + for strings.HasPrefix(text, prefix) { + count++ + text = strings.TrimPrefix(text, prefix) + } + + if count > 0 { + return prefix + text + } + + return text +} + +// RepeatedPrefixError wraps an error and removes the repeating prefix from its message +type RepeatedPrefixError struct { + OriginalError error + Prefix string +} + +func (e RepeatedPrefixError) Error() string { + return collapseRepeatedPrefix(e.OriginalError.Error(), e.Prefix) +} + +func (e RepeatedPrefixError) Unwrap() error { + return e.OriginalError +} diff --git a/pkg/apiserver/controllers/v1/errors_test.go b/pkg/apiserver/controllers/v1/errors_test.go new file mode 100644 index 00000000000..89c561f83bd --- /dev/null +++ b/pkg/apiserver/controllers/v1/errors_test.go @@ -0,0 +1,57 @@ +package v1 + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCollapseRepeatedPrefix(t *testing.T) { + tests := []struct { + input string + prefix string + want string + }{ + { + input: "aaabbbcccaaa", + prefix: "aaa", + want: "aaabbbcccaaa", + }, { + input: "hellohellohello world", + prefix: "hello", + want: "hello world", + }, { + input: "ababababxyz", + prefix: "ab", + want: "abxyz", + }, { + input: "xyzxyzxyzxyzxyz", + prefix: "xyz", + want: "xyz", + }, { + input: "123123123456", + prefix: "456", + want: "123123123456", + }, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + assert.Equal(t, tt.want, collapseRepeatedPrefix(tt.input, tt.prefix)) + }) + } +} + +func TestRepeatedPrefixError(t *testing.T) { + originalErr := errors.New("hellohellohello world") + wrappedErr := RepeatedPrefixError{OriginalError: originalErr, Prefix: "hello"} + + want := "hello world" + + assert.Equal(t, want, wrappedErr.Error()) + + assert.Equal(t, originalErr, errors.Unwrap(wrappedErr)) + require.ErrorIs(t, wrappedErr, originalErr) +} diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index bf6fd578195..799b736ccfe 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -3,19 +3,18 @@ package v1 import ( "net/http" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" ) func (c *Controller) HeartBeat(gctx *gin.Context) { + machineID, _ := getMachineIDFromContext(gctx) - claims := jwt.ExtractClaims(gctx) - // TBD: use defined rather than hardcoded key to find back owner - machineID := claims["id"].(string) + ctx := gctx.Request.Context() - if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil { + if err := c.DBClient.UpdateMachineLastHeartBeat(ctx, machineID); err != nil { c.HandleDBErrors(gctx, err) return } + gctx.Status(http.StatusOK) } diff --git a/pkg/apiserver/controllers/v1/machines.go b/pkg/apiserver/controllers/v1/machines.go index b4f28d94fd0..ff59e389cb1 100644 --- a/pkg/apiserver/controllers/v1/machines.go +++ 
b/pkg/apiserver/controllers/v1/machines.go @@ -1,31 +1,82 @@ package v1 import ( + "errors" + "net" "net/http" - "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" ) +func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool, error) { + if !*c.AutoRegisterCfg.Enable { + return false, nil + } + + clientIP := net.ParseIP(gctx.ClientIP()) + + // Can probably happen if using a unix socket? + if clientIP == nil { + log.Warnf("Failed to parse client IP for watcher self registration: %s", gctx.ClientIP()) + return false, nil + } + + if token == "" || c.AutoRegisterCfg == nil { + return false, nil + } + + // Check the token + if token != c.AutoRegisterCfg.Token { + return false, errors.New("invalid token for auto registration") + } + + // Check the source IP + for _, ipRange := range c.AutoRegisterCfg.AllowedRangesParsed { + if ipRange.Contains(clientIP) { + return true, nil + } + } + + return false, errors.New("IP not in allowed range for auto registration") +} + func (c *Controller) CreateMachine(gctx *gin.Context) { - var err error + ctx := gctx.Request.Context() + var input models.WatcherRegistrationRequest - if err = gctx.ShouldBindJSON(&input); err != nil { + + if err := gctx.ShouldBindJSON(&input); err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return } - if err = input.Validate(strfmt.Default); err != nil { - c.HandleDBErrors(gctx, err) + + if err := input.Validate(strfmt.Default); err != nil { + gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": err.Error()}) return } - _, err = c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType) + autoRegister, err := c.shouldAutoRegister(input.RegistrationToken, gctx) if err != nil { + log.WithFields(log.Fields{"ip": gctx.ClientIP(), "machine_id": *input.MachineID}).Errorf("Auto-register failed: %s", err) + gctx.JSON(http.StatusUnauthorized, gin.H{"message": err.Error()}) + + return + } + + if _, err := c.DBClient.CreateMachine(ctx, input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil { c.HandleDBErrors(gctx, err) return } - gctx.Status(http.StatusCreated) + if autoRegister { + log.WithFields(log.Fields{"ip": gctx.ClientIP(), "machine_id": *input.MachineID}).Info("Auto-registered machine") + gctx.Status(http.StatusAccepted) + } else { + gctx.Status(http.StatusCreated) + } } diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index 0f3bdb6d125..4f6ee0986eb 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -3,7 +3,6 @@ package v1 import ( "time" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus" ) @@ -35,8 +34,11 @@ var LapiBouncerHits = prometheus.NewCounterVec( []string{"bouncer", "route", "method"}, ) -/* keep track of the number of calls (per bouncer) that lead to nil/non-nil responses. -while it's not exact, it's a good way to know - when you have a rutpure bouncer - what is the rate of ok/ko answers you got from lapi*/ +/* + keep track of the number of calls (per bouncer) that lead to nil/non-nil responses. 
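// Illustrative sketch, not part of the patch: a watcher exercising the new
// auto-registration path added to CreateMachine above. The JSON field name
// "registration_token" is an assumption based on
// models.WatcherRegistrationRequest; the 202/201 split mirrors the handler.
package main

import (
	"fmt"
	"net/http"
	"strings"
)

func registerWatcher(lapiURL string) error {
	body := `{"machine_id": "mymachine", "password": "s3cret!", "registration_token": "token-from-lapi-config"}`

	resp, err := http.Post(lapiURL+"/v1/watchers", "application/json", strings.NewReader(body))
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusAccepted: // 202: token and source IP accepted, machine validated immediately
		fmt.Println("auto-registered")
	case http.StatusCreated: // 201: registered, pending manual validation
		fmt.Println("pending validation")
	default:
		fmt.Println("registration failed with status", resp.StatusCode)
	}

	return nil
}

func main() {
	if err := registerWatcher("http://127.0.0.1:8080"); err != nil {
		fmt.Println(err)
	}
}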
+ +while it's not exact, it's a good way to know - when you have a broken bouncer - what is the rate of ok/ko answers you got from lapi +*/ var LapiNilDecisions = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "cs_lapi_decisions_ko_total", @@ -63,46 +65,49 @@ var LapiResponseTime = prometheus.NewHistogramVec( []string{"endpoint", "method"}) func PrometheusBouncersHasEmptyDecision(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiNilDecisions.With(prometheus.Labels{ - "bouncer": name.(string)}).Inc() + "bouncer": bouncer.Name, + }).Inc() } } func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiNonNilDecisions.With(prometheus.Labels{ - "bouncer": name.(string)}).Inc() + "bouncer": bouncer.Name, + }).Inc() } } func PrometheusMachinesMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - claims := jwt.ExtractClaims(c) - if claims != nil { - if rawID, ok := claims["id"]; ok { - machineID := rawID.(string) - LapiMachineHits.With(prometheus.Labels{ - "machine": machineID, - "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() - } + machineID, _ := getMachineIDFromContext(c) + if machineID != "" { + LapiMachineHits.With(prometheus.Labels{ + "machine": machineID, + "route": c.Request.URL.Path, + "method": c.Request.Method, + }).Inc() } + c.Next() } } func PrometheusBouncersMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiBouncerHits.With(prometheus.Labels{ - "bouncer": name.(string), + "bouncer": bouncer.Name, "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() + "method": c.Request.Method, + }).Inc() } + c.Next() } } @@ -110,10 +115,13 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc { func PrometheusMiddleware() gin.HandlerFunc { return func(c *gin.Context) { startTime := time.Now() + LapiRouteHits.With(prometheus.Labels{ "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() + "method": c.Request.Method, + }).Inc() c.Next() + elapsed := time.Since(startTime) LapiResponseTime.With(prometheus.Labels{"method": c.Request.Method, "endpoint": c.Request.URL.Path}).Observe(elapsed.Seconds()) } diff --git a/pkg/apiserver/controllers/v1/usagemetrics.go b/pkg/apiserver/controllers/v1/usagemetrics.go new file mode 100644 index 00000000000..5b2c3e3b1a9 --- /dev/null +++ b/pkg/apiserver/controllers/v1/usagemetrics.go @@ -0,0 +1,205 @@ +package v1 + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +// updateBaseMetrics updates the base metrics for a machine or bouncer +func (c *Controller) updateBaseMetrics(ctx context.Context, machineID string, bouncer *ent.Bouncer, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { + switch { + case machineID != "": + return c.DBClient.MachineUpdateBaseMetrics(ctx, machineID, baseMetrics, hubItems, datasources) + case bouncer != nil: + return c.DBClient.BouncerUpdateBaseMetrics(ctx, bouncer.Name, bouncer.Type, 
baseMetrics) + default: + return errors.New("no machineID or bouncerName set") + } +} + +// UsageMetrics receives metrics from log processors and remediation components +func (c *Controller) UsageMetrics(gctx *gin.Context) { + var input models.AllMetrics + + logger := log.WithField("func", "UsageMetrics") + + // parse the payload + + if err := gctx.ShouldBindJSON(&input); err != nil { + logger.Errorf("Failed to bind json: %s", err) + gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + + return + } + + if err := input.Validate(strfmt.Default); err != nil { + // work around a nuisance in the generated code + cleanErr := RepeatedPrefixError{ + OriginalError: err, + Prefix: "validation failure list:\n", + } + logger.Errorf("Failed to validate usage metrics: %s", cleanErr) + gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": cleanErr.Error()}) + + return + } + + var ( + generatedType metric.GeneratedType + generatedBy string + ) + + bouncer, _ := getBouncerFromContext(gctx) + if bouncer != nil { + logger.Tracef("Received usage metrics for bouncer: %s", bouncer.Name) + + generatedType = metric.GeneratedTypeRC + generatedBy = bouncer.Name + } + + machineID, _ := getMachineIDFromContext(gctx) + if machineID != "" { + logger.Tracef("Received usage metrics for log processor: %s", machineID) + + generatedType = metric.GeneratedTypeLP + generatedBy = machineID + } + + if generatedBy == "" { + // how did we get here? + logger.Error("No machineID or bouncer in request context after authentication") + gctx.JSON(http.StatusInternalServerError, gin.H{"message": "No machineID or bouncer in request context after authentication"}) + + return + } + + if machineID != "" && bouncer != nil { + logger.Errorf("Payload has both machineID and bouncer") + gctx.JSON(http.StatusBadRequest, gin.H{"message": "Payload has both LP and RC data"}) + + return + } + + var ( + payload map[string]any + baseMetrics models.BaseMetrics + hubItems models.HubItems + datasources map[string]int64 + ) + + switch len(input.LogProcessors) { + case 0: + if machineID != "" { + logger.Errorf("Missing log processor data") + gctx.JSON(http.StatusBadRequest, gin.H{"message": "Missing log processor data"}) + + return + } + case 1: + // the final slice can't have more than one item, + // guaranteed by the swagger schema + item0 := input.LogProcessors[0] + + err := item0.Validate(strfmt.Default) + if err != nil { + logger.Errorf("Failed to validate log processor data: %s", err) + gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": err.Error()}) + + return + } + + payload = map[string]any{ + "metrics": item0.Metrics, + } + baseMetrics = item0.BaseMetrics + hubItems = item0.HubItems + datasources = item0.Datasources + default: + logger.Errorf("Payload has more than one log processor") + // this is not checked in the swagger schema + gctx.JSON(http.StatusBadRequest, gin.H{"message": "Payload has more than one log processor"}) + + return + } + + switch len(input.RemediationComponents) { + case 0: + if bouncer != nil { + logger.Errorf("Missing remediation component data") + gctx.JSON(http.StatusBadRequest, gin.H{"message": "Missing remediation component data"}) + + return + } + case 1: + item0 := input.RemediationComponents[0] + + err := item0.Validate(strfmt.Default) + if err != nil { + logger.Errorf("Failed to validate remediation component data: %s", err) + gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": err.Error()}) + + return + } + + payload = map[string]any{ + "type": item0.Type, + "metrics": item0.Metrics, 
+ } + baseMetrics = item0.BaseMetrics + default: + gctx.JSON(http.StatusBadRequest, gin.H{"message": "Payload has more than one remediation component"}) + return + } + + if baseMetrics.Os == nil { + baseMetrics.Os = &models.OSversion{ + Name: ptr.Of(""), + Version: ptr.Of(""), + } + } + + ctx := gctx.Request.Context() + + err := c.updateBaseMetrics(ctx, machineID, bouncer, baseMetrics, hubItems, datasources) + if err != nil { + logger.Errorf("Failed to update base metrics: %s", err) + c.HandleDBErrors(gctx, err) + + return + } + + jsonPayload, err := json.Marshal(payload) + if err != nil { + logger.Errorf("Failed to serialize usage metrics: %s", err) + c.HandleDBErrors(gctx, err) + + return + } + + receivedAt := time.Now().UTC() + + if _, err := c.DBClient.CreateMetric(ctx, generatedType, generatedBy, receivedAt, string(jsonPayload)); err != nil { + logger.Error(err) + c.HandleDBErrors(gctx, err) + + return + } + + // if CreateMetric() returned nil, the metric was already there, we're good + // and don't split hairs about 201 vs 200/204 + + gctx.Status(http.StatusCreated) +} diff --git a/pkg/apiserver/controllers/v1/utils.go b/pkg/apiserver/controllers/v1/utils.go index 8edce589816..3cd53d217cc 100644 --- a/pkg/apiserver/controllers/v1/utils.go +++ b/pkg/apiserver/controllers/v1/utils.go @@ -1,35 +1,72 @@ package v1 import ( - "fmt" + "errors" + "net" "net/http" + "strings" - "github.com/crowdsecurity/crowdsec/pkg/database/ent" + jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" -) -var ( - bouncerContextKey = "bouncer_info" + middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" ) func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) { - bouncerInterface, exist := ctx.Get(bouncerContextKey) + bouncerInterface, exist := ctx.Get(middlewares.BouncerContextKey) if !exist { - return nil, fmt.Errorf("bouncer not found") + return nil, errors.New("bouncer not found") } bouncerInfo, ok := bouncerInterface.(*ent.Bouncer) if !ok { - return nil, fmt.Errorf("bouncer not found") + return nil, errors.New("bouncer not found") } return bouncerInfo, nil } +func isUnixSocket(c *gin.Context) bool { + if localAddr, ok := c.Request.Context().Value(http.LocalAddrContextKey).(net.Addr); ok { + return strings.HasPrefix(localAddr.Network(), "unix") + } + + return false +} + +func getMachineIDFromContext(ctx *gin.Context) (string, error) { + claims := jwt.ExtractClaims(ctx) + if claims == nil { + return "", errors.New("failed to extract claims") + } + + rawID, ok := claims[middlewares.MachineIDKey] + if !ok { + return "", errors.New("MachineID not found in claims") + } + + id, ok := rawID.(string) + if !ok { + // should never happen + return "", errors.New("failed to cast machineID to string") + } + + return id, nil + } + func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc { return func(gctx *gin.Context) { + if !option { + return + } + + if isUnixSocket(gctx) { + return + } + incomingIP := gctx.ClientIP() - if option && incomingIP != "127.0.0.1" && incomingIP != "::1" { + if incomingIP != "127.0.0.1" && incomingIP != "::1" { gctx.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) gctx.Abort() } } diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index 5f92b1f0897..a0af6956443 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ 
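// Illustrative sketch, not part of the patch: the rough shape of a log
// processor payload for POST /v1/usage-metrics, mirroring the fields the
// handler reads (LogProcessors[0].Metrics, .BaseMetrics, .HubItems,
// .Datasources). The JSON field names and values are assumptions derived
// from the swagger model names, not taken from the schema itself.
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	payload := []byte(`{
		"log_processors": [{
			"version": "vX.Y.Z",
			"os": {"name": "ubuntu", "version": "24.04"},
			"metrics": [],
			"hub_items": {},
			"datasources": {"file": 2, "syslog": 1}
		}],
		"remediation_components": []
	}`)

	var decoded map[string]any
	if err := json.Unmarshal(payload, &decoded); err != nil {
		panic(err)
	}

	fmt.Println("top-level keys:", len(decoded))
}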
-12,88 +13,90 @@ const ( ) func TestDeleteDecisionRange(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // delete by ip wrong - w := lapi.RecordResponse("DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by range - w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) // delete by range : ensure it was already deleted - w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) } func TestDeleteDecisionFilter(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // delete by ip wrong - w := lapi.RecordResponse("DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by ip good - w = lapi.RecordResponse("DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) // delete by scope/value - w = lapi.RecordResponse("DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) } func TestDeleteDecisionFilterByScenario(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // delete by wrong scenario - w := lapi.RecordResponse("DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by scenario good - w = lapi.RecordResponse("DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) } func TestGetDecisionFilters(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid 
Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // Get Decision - w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY) + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err := readDecisionsGetResp(w) - assert.Nil(t, err) + decisions, code := readDecisionsGetResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 2, len(decisions)) + assert.Len(t, decisions, 2) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, int64(1), decisions[0].ID) @@ -103,12 +106,11 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : type filter - w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?type=ban", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - assert.Nil(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 2, len(decisions)) + assert.Len(t, decisions, 2) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, int64(1), decisions[0].ID) @@ -121,12 +123,11 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : scope/value - w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - assert.Nil(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 1, len(decisions)) + assert.Len(t, decisions, 1) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, int64(1), decisions[0].ID) @@ -136,12 +137,11 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : ip filter - w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - assert.Nil(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 1, len(decisions)) + assert.Len(t, decisions, 1) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, int64(1), decisions[0].ID) @@ -150,30 +150,28 @@ func TestGetDecisionFilters(t *testing.T) { // assert.NotContains(t, w.Body.String(), `"id":2,"origin":"crowdsec","scenario":"crowdsecurity/ssh-bf","scope":"Ip","type":"ban","value":"91.121.79.178"`) // Get decision : by range - w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - assert.Nil(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 2, len(decisions)) + assert.Len(t, decisions, 2) assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.179") assert.Contains(t, 
[]string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.178") - } func TestGetDecision(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Get Decision - w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY) + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err := readDecisionsGetResp(w) - assert.Nil(t, err) + decisions, code := readDecisionsGetResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 3, len(decisions)) + assert.Len(t, decisions, 3) /*decisions get doesn't perform deduplication*/ assert.Equal(t, "crowdsecurity/test", *decisions[0].Scenario) assert.Equal(t, "127.0.0.1", *decisions[0].Value) @@ -188,146 +186,137 @@ func TestGetDecision(t *testing.T) { assert.Equal(t, int64(3), decisions[2].ID) // Get Decision with invalid filter. It should ignore this filter - w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?test=test", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - assert.Equal(t, 3, len(decisions)) + assert.Len(t, decisions, 3) } func TestDeleteDecisionByID(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") - //Have one alerts - w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err := readDecisionsStreamResp(w) - assert.Equal(t, err, nil) + // Have one alert + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code := readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) // Delete alert with Invalid ID - w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/test", emptyBody, PASSWORD) assert.Equal(t, 400, w.Code) - err_resp, _, err := readDecisionsErrorResp(w) - assert.NoError(t, err) - assert.Equal(t, "decision_id must be valid integer", err_resp["message"]) + errResp, _ := readDecisionsErrorResp(t, w) + assert.Equal(t, "decision_id must be valid integer", errResp["message"]) // Delete alert with ID that not exist - w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/100", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) - err_resp, _, err = readDecisionsErrorResp(w) - assert.NoError(t, err) - assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", err_resp["message"]) - - //Have one alerts - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - assert.Equal(t, err, nil) + errResp, _ = readDecisionsErrorResp(t, w) + assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"]) + + // Have one alert + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = 
readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) // Delete alert with valid ID - w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - resp, _, err := readDecisionsDeleteResp(w) - assert.NoError(t, err) - assert.Equal(t, resp.NbDeleted, "1") - - //Have one alert (because we delete an alert that has dup targets) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - assert.Equal(t, err, nil) + resp, _ := readDecisionsDeleteResp(t, w) + assert.Equal(t, "1", resp.NbDeleted) + + // Have one alert (because we delete an alert that has dup targets) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) } func TestDeleteDecision(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Delete alert with Invalid filter - w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) - err_resp, _, err := readDecisionsErrorResp(w) - assert.NoError(t, err) - assert.Equal(t, err_resp["message"], "'test' doesn't exist: invalid filter") + errResp, _ := readDecisionsErrorResp(t, w) + assert.Equal(t, "'test' doesn't exist: invalid filter", errResp["message"]) // Delete all alert - w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - resp, _, err := readDecisionsDeleteResp(w) - assert.NoError(t, err) - assert.Equal(t, resp.NbDeleted, "3") + resp, _ := readDecisionsDeleteResp(t, w) + assert.Equal(t, "3", resp.NbDeleted) } func TestStreamStartDecisionDedup(t *testing.T) { - //Ensure that at stream startup we only get the longest decision - lapi := SetupLAPITest(t) + ctx := context.Background() + // Ensure that at stream startup we only get the longest decision + lapi := SetupLAPITest(t, ctx) // Create Valid Alert : 3 decisions for 127.0.0.1, longest has id=3 - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Get Stream, we only get one decision (the longest one) - w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err := readDecisionsStreamResp(w) - assert.Equal(t, nil, err) + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code := readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) assert.Equal(t, int64(3), 
decisions["new"][0].ID) assert.Equal(t, "test", *decisions["new"][0].Origin) assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip - w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/3", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) // Get Stream, we only get one decision (the longest one, id=2) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - assert.Equal(t, nil, err) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) assert.Equal(t, int64(2), decisions["new"][0].ID) assert.Equal(t, "test", *decisions["new"][0].Origin) assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // We delete another decision, yet don't receive it in stream, since there's another decision on same IP - w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/2", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) // And get the remaining decision (1) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - assert.Equal(t, nil, err) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) assert.Equal(t, int64(1), decisions["new"][0].ID) assert.Equal(t, "test", *decisions["new"][0].Origin) assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // We delete the last decision, we receive the delete order - w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - //and now we only get a deleted decision - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - assert.Equal(t, nil, err) + // and now we only get a deleted decision + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) - assert.Equal(t, 1, len(decisions["deleted"])) + assert.Len(t, decisions["deleted"], 1) assert.Equal(t, int64(1), decisions["deleted"][0].ID) assert.Equal(t, "test", *decisions["deleted"][0].Origin) assert.Equal(t, "127.0.0.1", *decisions["deleted"][0].Value) - assert.Equal(t, 0, len(decisions["new"])) + assert.Empty(t, decisions["new"]) } type DecisionCheck struct { diff --git a/pkg/apiserver/heartbeat_test.go b/pkg/apiserver/heartbeat_test.go index 0082f23ece8..db051566f75 100644 --- a/pkg/apiserver/heartbeat_test.go +++ b/pkg/apiserver/heartbeat_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "net/http" "testing" @@ -8,11 +9,12 @@ import ( ) func TestHeartBeat(t *testing.T) { - lapi 
:= SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) - w := lapi.RecordResponse(http.MethodGet, "/v1/heartbeat", emptyBody, "password") + w := lapi.RecordResponse(t, ctx, http.MethodGet, "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 200, w.Code) - w = lapi.RecordResponse("POST", "/v1/heartbeat", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "POST", "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 405, w.Code) } diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index ebca9125260..f6f51763975 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -1,95 +1,86 @@ package apiserver import ( + "context" "net/http" "net/http/httptest" "strings" "testing" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) func TestLogin(t *testing.T) { - router, config, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + ctx := context.Background() + router, config := NewAPITest(t, ctx) - body, err := CreateTestMachine(router) - if err != nil { - log.Fatalln(err) - } + body := CreateTestMachine(t, ctx, router, "") // Login with machine not validated yet w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"machine test not validated\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"machine test not validated"}`, w.Body.String()) // Login with machine not exist w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test1\", \"password\": \"test1\"}")) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"ent: machine not found\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"ent: machine not found"}`, w.Body.String()) // Login with invalid body w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("test")) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader("test")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"missing: invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"missing: invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Login with invalid format w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test1\"}")) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"input format error\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String()) - //Validate 
machine - err = ValidateMachine("test", config.API.Server.DbConfig) - if err != nil { - log.Fatalln(err) - } + // Validate machine + ValidateMachine(t, ctx, "test", config.API.Server.DbConfig) // Login with invalid password w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test\", \"password\": \"test1\"}")) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"incorrect Username or Password\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"incorrect Username or Password"}`, w.Body.String()) // Login with valid machine w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "\"token\"") - assert.Contains(t, w.Body.String(), "\"expire\"") + assert.Contains(t, w.Body.String(), `"token"`) + assert.Contains(t, w.Body.String(), `"expire"`) // Login with valid machine + scenarios w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test\", \"password\": \"test\", \"scenarios\": [\"crowdsecurity/test\", \"crowdsecurity/test2\"]}")) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "\"token\"") - assert.Contains(t, w.Body.String(), "\"expire\"") - + assert.Contains(t, w.Body.String(), `"token"`) + assert.Contains(t, w.Body.String(), `"expire"`) } diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 25fd0eaf445..969f75707d6 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -1,169 +1,233 @@ package apiserver import ( + "context" "encoding/json" "net/http" "net/http/httptest" "strings" "testing" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/ptr" ) func TestCreateMachine(t *testing.T) { - router, _, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + ctx := context.Background() + router, _ := NewAPITest(t, ctx) // Create machine with invalid format w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader("test")) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader("test")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 400, w.Code) - assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String()) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create machine with invalid input w = httptest.NewRecorder() - req, _ = 
http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader("{\"test\": \"test\"}")) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"validation failure list:\\nmachine_id in body is required\\npassword in body is required\"}", w.Body.String()) + assert.Equal(t, http.StatusUnprocessableEntity, w.Code) + assert.Equal(t, `{"message":"validation failure list:\nmachine_id in body is required\npassword in body is required"}`, w.Body.String()) // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 201, w.Code) + assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, "", w.Body.String()) - } func TestCreateMachineWithForwardedFor(t *testing.T) { - router, config, err := NewAPITestForwardedFor(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + ctx := context.Background() + router, config := NewAPITestForwardedFor(t, ctx) router.TrustedPlatform = "X-Real-IP" + // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Real-Ip", "1.1.1.1") router.ServeHTTP(w, req) - assert.Equal(t, 201, w.Code) + assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, "", w.Body.String()) - ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig) - if err != nil { - log.Fatalf("Could not get machine IP : %s", err) - } + ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) + assert.Equal(t, "1.1.1.1", ip) } func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { - router, config, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + ctx := context.Background() + router, config := NewAPITest(t, ctx) // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Real-IP", "1.1.1.1") router.ServeHTTP(w, req) - assert.Equal(t, 201, w.Code) + assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, "", w.Body.String()) - ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig) - if err != nil { - log.Fatalf("Could not get machine IP : %s", err) - } - //For some reason, the IP is empty when running tests - //if no forwarded-for headers are present + ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) + + // For some 
reason, the IP is empty when running tests + // if no forwarded-for headers are present assert.Equal(t, "", ip) } func TestCreateMachineWithoutForwardedFor(t *testing.T) { - router, config, err := NewAPITestForwardedFor(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + ctx := context.Background() + router, config := NewAPITestForwardedFor(t, ctx) // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 201, w.Code) + assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, "", w.Body.String()) - ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig) - if err != nil { - log.Fatalf("Could not get machine IP : %s", err) - } - //For some reason, the IP is empty when running tests - //if no forwarded-for headers are present + ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) + + // For some reason, the IP is empty when running tests + // if no forwarded-for headers are present assert.Equal(t, "", ip) } func TestCreateMachineAlreadyExist(t *testing.T) { - router, _, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + ctx := context.Background() + router, _ := NewAPITest(t, ctx) - body, err := CreateTestMachine(router) - if err != nil { - log.Fatalln(err) - } + body := CreateTestMachine(t, ctx, router, "") w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req.Header.Add("User-Agent", UserAgent) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusForbidden, w.Code) + assert.Equal(t, `{"message":"user 'test': user already exist"}`, w.Body.String()) +} + +func TestAutoRegistration(t *testing.T) { + ctx := context.Background() + router, _ := NewAPITest(t, ctx) + + // Invalid registration token / valid source IP + regReq := MachineTest + regReq.RegistrationToken = invalidRegistrationToken + b, err := json.Marshal(regReq) + require.NoError(t, err) + + body := string(b) + + w := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "127.0.0.1:4242" router.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) - assert.Equal(t, "{\"message\":\"user 'test': user already exist\"}", w.Body.String()) + assert.Equal(t, http.StatusUnauthorized, w.Code) + + // Invalid registration token / invalid source IP + regReq = MachineTest + regReq.RegistrationToken = invalidRegistrationToken + b, err = json.Marshal(regReq) + require.NoError(t, err) + + body = string(b) + + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req.Header.Add("User-Agent", UserAgent) + 
req.RemoteAddr = "42.42.42.42:4242" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + + // valid registration token / invalid source IP + regReq = MachineTest + regReq.RegistrationToken = validRegistrationToken + b, err = json.Marshal(regReq) + require.NoError(t, err) + + body = string(b) + + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "42.42.42.42:4242" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + + // Valid registration token / valid source IP + regReq = MachineTest + regReq.RegistrationToken = validRegistrationToken + b, err = json.Marshal(regReq) + require.NoError(t, err) + + body = string(b) + + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "127.0.0.1:4242" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusAccepted, w.Code) + + // No token / valid source IP + regReq = MachineTest + regReq.MachineID = ptr.Of("test2") + b, err = json.Marshal(regReq) + require.NoError(t, err) + + body = string(b) + + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "127.0.0.1:4242" + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusCreated, w.Code) } diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index ce1bc8eeece..d438c9b15a4 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -8,18 +8,19 @@ import ( "net/http" "strings" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" ) const ( APIKeyHeader = "X-Api-Key" - bouncerContextKey = "bouncer_info" - // max allowed by bcrypt 72 = 54 bytes in base64 + BouncerContextKey = "bouncer_info" dummyAPIKeySize = 54 + // max allowed by bcrypt 72 = 54 bytes in base64 ) type APIKey struct { @@ -33,7 +34,11 @@ func GenerateAPIKey(n int) (string, error) { if _, err := rand.Read(bytes); err != nil { return "", err } - return base64.StdEncoding.EncodeToString(bytes), nil + + encoded := base64.StdEncoding.EncodeToString(bytes) + + // the '=' can cause issues on some bouncers + return strings.TrimRight(encoded, "="), nil } func NewAPIKey(dbClient *database.Client) *APIKey { @@ -53,176 +58,145 @@ func HashSHA512(str string) string { return hashStr } +func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { + if a.TlsAuth == nil { + logger.Warn("TLS Auth is not configured but client presented a certificate") + return nil + } + + ctx := c.Request.Context() + + extractedCN, err := a.TlsAuth.ValidateCert(c) + if err != nil { + logger.Warn(err) + return nil + } + + logger = logger.WithField("cn", extractedCN) + + bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) + bouncer, err := a.DbClient.SelectBouncerByName(ctx, bouncerName) + + // This is likely not the proper way, but isNotFound does not seem to work + if err != nil && strings.Contains(err.Error(), "bouncer not found") { + // Because we have a valid cert, automatically create 
the bouncer in the database if it does not exist + // Set a random API key, but it will never be used + apiKey, err := GenerateAPIKey(dummyAPIKeySize) + if err != nil { + logger.Errorf("error generating mock api key: %s", err) + return nil + } + + logger.Infof("Creating bouncer %s", bouncerName) + + bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) + if err != nil { + logger.Errorf("while creating bouncer db entry: %s", err) + return nil + } + } else if err != nil { + // error while selecting bouncer + logger.Errorf("while selecting bouncers: %s", err) + return nil + } else if bouncer.AuthType != types.TlsAuthType { + // bouncer was found in DB + logger.Errorf("bouncer isn't allowed to auth by TLS") + return nil + } + + return bouncer +} + +func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { + val, ok := c.Request.Header[APIKeyHeader] + if !ok { + logger.Errorf("API key not found") + return nil + } + + ctx := c.Request.Context() + + hashStr := HashSHA512(val[0]) + + bouncer, err := a.DbClient.SelectBouncer(ctx, hashStr) + if err != nil { + logger.Errorf("while fetching bouncer info: %s", err) + return nil + } + + if bouncer.AuthType != types.ApiKeyAuthType { + logger.Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType) + return nil + } + + return bouncer +} + func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { return func(c *gin.Context) { var bouncer *ent.Bouncer - var err error + + ctx := c.Request.Context() + + clientIP := c.ClientIP() + + logger := log.WithField("ip", clientIP) if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { - if a.TlsAuth == nil { - log.WithField("ip", c.ClientIP()).Error("TLS Auth is not configured but client presented a certificate") - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - validCert, extractedCN, err := a.TlsAuth.ValidateCert(c) - if !validCert { - log.WithField("ip", c.ClientIP()).Errorf("invalid client certificate: %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - if err != nil { - log.WithField("ip", c.ClientIP()).Error(err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) - bouncer, err = a.DbClient.SelectBouncerByName(bouncerName) - //This is likely not the proper way, but isNotFound does not seem to work - if err != nil && strings.Contains(err.Error(), "bouncer not found") { - //Because we have a valid cert, automatically create the bouncer in the database if it does not exist - //Set a random API key, but it will never be used - apiKey, err := GenerateAPIKey(dummyAPIKeySize) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Errorf("error generating mock api key: %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Infof("Creating bouncer %s", bouncerName) - bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Errorf("creating bouncer db entry : %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - } 
else if err != nil { - //error while selecting bouncer - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Errorf("while selecting bouncers: %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } else if bouncer.AuthType != types.TlsAuthType { - //bouncer was found in DB - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Errorf("bouncer isn't allowed to auth by TLS") - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } + bouncer = a.authTLS(c, logger) } else { - //API Key Authentication - val, ok := c.Request.Header[APIKeyHeader] - if !ok { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - }).Errorf("API key not found") - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - hashStr := HashSHA512(val[0]) - bouncer, err = a.DbClient.SelectBouncer(hashStr) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - }).Errorf("while fetching bouncer info: %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - if bouncer.AuthType != types.ApiKeyAuthType { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - }).Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } + bouncer = a.authPlain(c, logger) } if bouncer == nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - }).Errorf("bouncer not found") + // XXX: StatusUnauthorized? c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() + return } - //maybe we want to store the whole bouncer object in the context instead, this would avoid another db query - //in StreamDecision - c.Set("BOUNCER_NAME", bouncer.Name) - c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey) + logger = logger.WithField("name", bouncer.Name) if bouncer.IPAddress == "" { - err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "name": bouncer.Name, - }).Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) + if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { + logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() + return } } - if bouncer.IPAddress != c.ClientIP() && bouncer.IPAddress != "" { - log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress) - err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "name": bouncer.Name, - }).Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) + // Don't update IP on HEAD request, as it's used by the appsec to check the validity of the API key provided + if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead { + log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress) + + if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { + logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() + return } } useragent := 
strings.Split(c.Request.UserAgent(), "/") - if len(useragent) != 2 { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "name": bouncer.Name, - }).Warningf("bad user agent '%s'", c.Request.UserAgent()) + logger.Warningf("bad user agent '%s'", c.Request.UserAgent()) useragent = []string{c.Request.UserAgent(), "N/A"} } if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { - if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "name": bouncer.Name, - }).Errorf("failed to update bouncer version and type: %s", err) + if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil { + logger.Errorf("failed to update bouncer version and type: %s", err) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.Abort() + return } } - c.Set(bouncerContextKey, bouncer) - - c.Next() + c.Set(BouncerContextKey, bouncer) } } diff --git a/pkg/apiserver/middlewares/v1/cache.go b/pkg/apiserver/middlewares/v1/cache.go new file mode 100644 index 00000000000..b0037bc4fa4 --- /dev/null +++ b/pkg/apiserver/middlewares/v1/cache.go @@ -0,0 +1,99 @@ +package v1 + +import ( + "crypto/x509" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +type cacheEntry struct { + err error // if nil, the certificate is not revoked + timestamp time.Time +} + +type RevocationCache struct { + mu sync.RWMutex + cache map[string]cacheEntry + expiration time.Duration + lastPurge time.Time + logger *log.Entry +} + +func NewRevocationCache(expiration time.Duration, logger *log.Entry) *RevocationCache { + return &RevocationCache{ + cache: make(map[string]cacheEntry), + expiration: expiration, + lastPurge: time.Now(), + logger: logger, + } +} + +func (*RevocationCache) generateKey(cert *x509.Certificate) string { + return cert.SerialNumber.String() + "-" + cert.Issuer.String() +} + +// purgeExpired removes expired entries from the cache +func (rc *RevocationCache) purgeExpired() { + // we don't keep a separate interval for the full sweep, we'll just double the expiration + if time.Since(rc.lastPurge) < rc.expiration { + return + } + + rc.mu.Lock() + defer rc.mu.Unlock() + + for key, entry := range rc.cache { + if time.Since(entry.timestamp) > rc.expiration { + rc.logger.Debugf("purging expired entry for cert %s", key) + delete(rc.cache, key) + } + } +} + +func (rc *RevocationCache) Get(cert *x509.Certificate) (error, bool) { //nolint:revive + rc.purgeExpired() + key := rc.generateKey(cert) + rc.mu.RLock() + entry, exists := rc.cache[key] + rc.mu.RUnlock() + + if !exists { + rc.logger.Tracef("no cached value for cert %s", key) + return nil, false + } + + // Upgrade to write lock to potentially modify the cache + rc.mu.Lock() + defer rc.mu.Unlock() + + if entry.timestamp.Add(rc.expiration).Before(time.Now()) { + rc.logger.Debugf("cached value for %s expired, removing from cache", key) + delete(rc.cache, key) + + return nil, false + } + + rc.logger.Debugf("using cached value for cert %s: %v", key, entry.err) + + return entry.err, true +} + +func (rc *RevocationCache) Set(cert *x509.Certificate, err error) { + key := rc.generateKey(cert) + + rc.mu.Lock() + defer rc.mu.Unlock() + + rc.cache[key] = cacheEntry{ + err: err, + timestamp: time.Now(), + } +} + +func (rc *RevocationCache) Empty() { + rc.mu.Lock() + defer rc.mu.Unlock() + rc.cache = make(map[string]cacheEntry) +} diff --git a/pkg/apiserver/middlewares/v1/crl.go b/pkg/apiserver/middlewares/v1/crl.go new file mode
100644 index 00000000000..64d7d3f0d96 --- /dev/null +++ b/pkg/apiserver/middlewares/v1/crl.go @@ -0,0 +1,145 @@ +package v1 + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +type CRLChecker struct { + path string // path to the CRL file + fileInfo os.FileInfo // last stat of the CRL file + crls []*x509.RevocationList // parsed CRLs + logger *log.Entry + mu sync.RWMutex + lastLoad time.Time // time when the CRL file was last read successfully + onLoad func() // called when the CRL file changes (and is read successfully) +} + +func NewCRLChecker(crlPath string, onLoad func(), logger *log.Entry) (*CRLChecker, error) { + cc := &CRLChecker{ + path: crlPath, + logger: logger, + onLoad: onLoad, + } + + err := cc.refresh() + if err != nil { + return nil, err + } + + return cc, nil +} + +func (*CRLChecker) decodeCRLs(content []byte) ([]*x509.RevocationList, error) { + var crls []*x509.RevocationList + + for { + block, rest := pem.Decode(content) + if block == nil { + break // no more PEM blocks + } + + content = rest + + crl, err := x509.ParseRevocationList(block.Bytes) + if err != nil { + // invalidate the whole CRL file so we can still use the previous version + return nil, fmt.Errorf("could not parse file: %w", err) + } + + crls = append(crls, crl) + } + + return crls, nil +} + +// refresh() reads the CRL file if new or changed since the last time +func (cc *CRLChecker) refresh() error { + // noop if lastLoad is less than 5 seconds ago + if time.Since(cc.lastLoad) < 5*time.Second { + return nil + } + + cc.mu.Lock() + defer cc.mu.Unlock() + + cc.logger.Debugf("loading CRL file from %s", cc.path) + + fileInfo, err := os.Stat(cc.path) + if err != nil { + return fmt.Errorf("could not access CRL file: %w", err) + } + + // noop if the file didn't change + if cc.fileInfo != nil && fileInfo.ModTime().Equal(cc.fileInfo.ModTime()) && fileInfo.Size() == cc.fileInfo.Size() { + return nil + } + + // the encoding/pem package wants bytes, not io.Reader + crlContent, err := os.ReadFile(cc.path) + if err != nil { + return fmt.Errorf("could not read CRL file: %w", err) + } + + cc.crls, err = cc.decodeCRLs(crlContent) + if err != nil { + return err + } + + cc.fileInfo = fileInfo + cc.lastLoad = time.Now() + cc.onLoad() + + return nil +} + +// isRevokedBy checks if the client certificate is revoked by any of the CRL blocks +// It returns a boolean indicating if the certificate is revoked and a boolean indicating +// if the CRL check was successful and could be cached.
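The (revoked, couldCheck) convention documented above is shared with the OCSP checker, and the second value decides whether a verdict may be cached. A minimal sketch of a consumer, assuming in-package access to the CRLChecker and OCSPChecker types from this diff (combineVerdicts is an illustrative name, not part of the change):

	func combineVerdicts(cert, issuer *x509.Certificate, cc *CRLChecker, oc *OCSPChecker) (revoked, cacheable bool) {
		revokedByOCSP, okOCSP := oc.isRevokedBy(cert, issuer)
		revokedByCRL, okCRL := cc.isRevokedBy(cert, issuer)
		// a positive verdict counts only if the check actually completed
		revoked = (revokedByOCSP && okOCSP) || (revokedByCRL && okCRL)
		// never cache after a failed lookup, or a wrong answer could stick
		cacheable = okOCSP && okCRL
		return revoked, cacheable
	}

checkRevocationPath in tls_auth.go below applies the same rule while walking each verified chain from the root CA toward the leaf.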
+func (cc *CRLChecker) isRevokedBy(cert *x509.Certificate, issuer *x509.Certificate) (bool, bool) { + if cc == nil { + return false, true + } + + err := cc.refresh() + if err != nil { + // we can't quit obviously, so we just log the error and continue + // but we can assume we have loaded a CRL, or it would have quit the first time + cc.logger.Errorf("while refreshing CRL: %s - will keep using CRL file read at %s", err, + cc.lastLoad.Format(time.RFC3339)) + } + + now := time.Now().UTC() + + cc.mu.RLock() + defer cc.mu.RUnlock() + + for _, crl := range cc.crls { + if err := crl.CheckSignatureFrom(issuer); err != nil { + continue + } + + if now.After(crl.NextUpdate) { + cc.logger.Warn("CRL has expired, will still validate the cert against it.") + } + + if now.Before(crl.ThisUpdate) { + cc.logger.Warn("CRL is not yet valid, will still validate the cert against it.") + } + + for _, revoked := range crl.RevokedCertificateEntries { + if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 { + cc.logger.Warn("client certificate is revoked by CRL") + return true, true + } + } + } + + return false, true +} diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index bbd33c54420..9171e9fce06 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -2,26 +2,26 @@ package v1 import ( "crypto/rand" + "errors" "fmt" - "net/http" "os" "strings" "time" jwt "github.com/appleboy/gin-jwt/v2" + "github.com/gin-gonic/gin" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/bcrypt" + "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/gin-gonic/gin" - "github.com/go-openapi/strfmt" - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" - "golang.org/x/crypto/bcrypt" ) -var identityKey = "id" +const MachineIDKey = "id" type JWT struct { Middleware *jwt.GinJWTMiddleware @@ -32,175 +32,220 @@ type JWT struct { func PayloadFunc(data interface{}) jwt.MapClaims { if value, ok := data.(*models.WatcherAuthRequest); ok { return jwt.MapClaims{ - identityKey: &value.MachineID, + MachineIDKey: &value.MachineID, } } + return jwt.MapClaims{} } func IdentityHandler(c *gin.Context) interface{} { claims := jwt.ExtractClaims(c) - machineId := claims[identityKey].(string) + machineID := claims[MachineIDKey].(string) + return &models.WatcherAuthRequest{ - MachineID: &machineId, + MachineID: &machineID, } } -func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { - var loginInput models.WatcherAuthRequest - var scenarios string - var err error - var scenariosInput []string - var clientMachine *ent.Machine - var machineID string +type authInput struct { + machineID string + clientMachine *ent.Machine + scenariosInput []string +} - if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { - if j.TlsAuth == nil { - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return nil, errors.New("TLS auth is not configured") - } - validCert, extractedCN, err := j.TlsAuth.ValidateCert(c) +func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { + ctx := c.Request.Context() + ret := authInput{} + + if j.TlsAuth == nil { + err := errors.New("tls authentication required") + log.Warn(err) + + return nil, err + } + + extractedCN, err := j.TlsAuth.ValidateCert(c) + if err 
!= nil { + log.Warn(err) + return nil, err + } + + logger := log.WithField("ip", c.ClientIP()) + + ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) + + ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). + Where(machine.MachineId(ret.machineID)). + First(ctx) + if ent.IsNotFound(err) { + // Machine was not found, let's create it + logger.Infof("machine %s not found, create it", ret.machineID) + // let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli) + pwd, err := GenerateAPIKey(dummyAPIKeySize) if err != nil { - log.Error(err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return nil, errors.Wrap(err, "while trying to validate client cert") - } - if !validCert { - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return nil, fmt.Errorf("failed cert authentication") - } + logger.WithField("cn", extractedCN). + Errorf("error generating password: %s", err) - machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) - clientMachine, err = j.DbClient.Ent.Machine.Query(). - Where(machine.MachineId(machineID)). - First(j.DbClient.CTX) - if ent.IsNotFound(err) { - //Machine was not found, let's create it - log.Printf("machine %s not found, create it", machineID) - //let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli) - pwd, err := GenerateAPIKey(dummyAPIKeySize) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Errorf("error generating password: %s", err) - return nil, fmt.Errorf("error generating password") - } - password := strfmt.Password(pwd) - clientMachine, err = j.DbClient.CreateMachine(&machineID, &password, "", true, true, types.TlsAuthType) - if err != nil { - return "", errors.Wrapf(err, "while creating machine entry for %s", machineID) - } - } else if err != nil { - return "", errors.Wrapf(err, "while selecting machine entry for %s", machineID) - } else { - if clientMachine.AuthType != types.TlsAuthType { - return "", errors.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", machineID, clientMachine.AuthType) - } - machineID = clientMachine.MachineId - loginInput := struct { - Scenarios []string `json:"scenarios"` - }{ - Scenarios: []string{}, - } - err := c.ShouldBindJSON(&loginInput) - if err != nil { - return "", errors.Wrap(err, "missing scenarios list in login request for TLS auth") - } - scenariosInput = loginInput.Scenarios + return nil, errors.New("error generating password") } - } else { - //normal auth + password := strfmt.Password(pwd) - if err := c.ShouldBindJSON(&loginInput); err != nil { - return "", errors.Wrap(err, "missing") + ret.clientMachine, err = j.DbClient.CreateMachine(ctx, &ret.machineID, &password, "", true, true, types.TlsAuthType) + if err != nil { + return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err) } - if err := loginInput.Validate(strfmt.Default); err != nil { - return "", errors.New("input format error") + } else if err != nil { + return nil, fmt.Errorf("while selecting machine entry for %s: %w", ret.machineID, err) + } else { + if ret.clientMachine.AuthType != types.TlsAuthType { + return nil, fmt.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType) } - machineID = *loginInput.MachineID - password := *loginInput.Password - scenariosInput = loginInput.Scenarios - 
clientMachine, err = j.DbClient.Ent.Machine.Query(). - Where(machine.MachineId(machineID)). - First(j.DbClient.CTX) - if err != nil { - log.Printf("Error machine login for %s : %+v ", machineID, err) - return nil, err - } + ret.machineID = ret.clientMachine.MachineId + } - if clientMachine == nil { - log.Errorf("Nothing for '%s'", machineID) - return nil, jwt.ErrFailedAuthentication - } + loginInput := struct { + Scenarios []string `json:"scenarios"` + }{ + Scenarios: []string{}, + } - if clientMachine.AuthType != types.PasswordAuthType { - return nil, errors.Errorf("machine %s attempted to auth with password but it is configured to use %s", machineID, clientMachine.AuthType) - } + err = c.ShouldBindJSON(&loginInput) + if err != nil { + return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err) + } - if !clientMachine.IsValidated { - return nil, fmt.Errorf("machine %s not validated", machineID) - } + ret.scenariosInput = loginInput.Scenarios - if err = bcrypt.CompareHashAndPassword([]byte(clientMachine.Password), []byte(password)); err != nil { - return nil, jwt.ErrFailedAuthentication - } + return &ret, nil +} + +func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { + var ( + loginInput models.WatcherAuthRequest + err error + ) - //end of normal auth + ctx := c.Request.Context() + + ret := authInput{} + + if err = c.ShouldBindJSON(&loginInput); err != nil { + return nil, fmt.Errorf("missing: %w", err) } - if len(scenariosInput) > 0 { - for _, scenario := range scenariosInput { + if err = loginInput.Validate(strfmt.Default); err != nil { + return nil, err + } + + ret.machineID = *loginInput.MachineID + password := *loginInput.Password + ret.scenariosInput = loginInput.Scenarios + + ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). + Where(machine.MachineId(ret.machineID)). 
+ First(ctx) + if err != nil { + log.Infof("Error machine login for %s : %+v ", ret.machineID, err) + return nil, err + } + + if ret.clientMachine == nil { + log.Errorf("Nothing for '%s'", ret.machineID) + return nil, jwt.ErrFailedAuthentication + } + + if ret.clientMachine.AuthType != types.PasswordAuthType { + return nil, fmt.Errorf("machine %s attempted to auth with password but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType) + } + + if !ret.clientMachine.IsValidated { + return nil, fmt.Errorf("machine %s not validated", ret.machineID) + } + + if err := bcrypt.CompareHashAndPassword([]byte(ret.clientMachine.Password), []byte(password)); err != nil { + return nil, jwt.ErrFailedAuthentication + } + + return &ret, nil +} + +func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { + var ( + err error + auth *authInput + ) + + ctx := c.Request.Context() + + if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { + auth, err = j.authTLS(c) + if err != nil { + return nil, err + } + } else { + auth, err = j.authPlain(c) + if err != nil { + return nil, err + } + } + + var scenarios string + + if len(auth.scenariosInput) > 0 { + for _, scenario := range auth.scenariosInput { if scenarios == "" { scenarios = scenario } else { scenarios += "," + scenario } } - err = j.DbClient.UpdateMachineScenarios(scenarios, clientMachine.ID) + + err = j.DbClient.UpdateMachineScenarios(ctx, scenarios, auth.clientMachine.ID) if err != nil { - log.Errorf("Failed to update scenarios list for '%s': %s\n", machineID, err) + log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication } } - if clientMachine.IpAddress == "" { - err = j.DbClient.UpdateMachineIP(c.ClientIP(), clientMachine.ID) + clientIP := c.ClientIP() + + if auth.clientMachine.IpAddress == "" { + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { - log.Errorf("Failed to update ip address for '%s': %s\n", machineID, err) + log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication } } - if clientMachine.IpAddress != c.ClientIP() && clientMachine.IpAddress != "" { - log.Warningf("new IP address detected for machine '%s': %s (old: %s)", clientMachine.MachineId, c.ClientIP(), clientMachine.IpAddress) - err = j.DbClient.UpdateMachineIP(c.ClientIP(), clientMachine.ID) + if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" { + log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress) + + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { - log.Errorf("Failed to update ip address for '%s': %s\n", clientMachine.MachineId, err) + log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) return nil, jwt.ErrFailedAuthentication } } useragent := strings.Split(c.Request.UserAgent(), "/") if len(useragent) != 2 { - log.Warningf("bad user agent '%s' from '%s'", c.Request.UserAgent(), c.ClientIP()) + log.Warningf("bad user agent '%s' from '%s'", c.Request.UserAgent(), clientIP) return nil, jwt.ErrFailedAuthentication } - if err := j.DbClient.UpdateMachineVersion(useragent[1], clientMachine.ID); err != nil { - log.Errorf("unable to update machine '%s' version '%s': %s", clientMachine.MachineId, useragent[1], err) - log.Errorf("bad user agent from : %s", c.ClientIP()) + if err := 
j.DbClient.UpdateMachineVersion(ctx, useragent[1], auth.clientMachine.ID); err != nil { + log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err) + log.Errorf("bad user agent from : %s", clientIP) + return nil, jwt.ErrFailedAuthentication } + return &models.WatcherAuthRequest{ - MachineID: &machineID, + MachineID: &auth.machineID, }, nil - } func Authorizator(data interface{}, c *gin.Context) bool { @@ -262,7 +307,7 @@ func NewJWT(dbClient *database.Client) (*JWT, error) { Key: secret, Timeout: time.Hour, MaxRefresh: time.Hour, - IdentityKey: identityKey, + IdentityKey: MachineIDKey, PayloadFunc: PayloadFunc, IdentityHandler: IdentityHandler, Authenticator: jwtMiddleware.Authenticator, @@ -278,8 +323,9 @@ func NewJWT(dbClient *database.Client) (*JWT, error) { errInit := ret.MiddlewareInit() if errInit != nil { - return &JWT{}, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) + return &JWT{}, errors.New("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) } + jwtMiddleware.Middleware = ret return jwtMiddleware, nil diff --git a/pkg/apiserver/middlewares/v1/middlewares.go b/pkg/apiserver/middlewares/v1/middlewares.go index 26879bd8e7f..a5409ea5c9e 100644 --- a/pkg/apiserver/middlewares/v1/middlewares.go +++ b/pkg/apiserver/middlewares/v1/middlewares.go @@ -14,9 +14,10 @@ func NewMiddlewares(dbClient *database.Client) (*Middlewares, error) { ret.JWT, err = NewJWT(dbClient) if err != nil { - return &Middlewares{}, err + return nil, err } ret.APIKey = NewAPIKey(dbClient) + return ret, nil } diff --git a/pkg/apiserver/middlewares/v1/ocsp.go b/pkg/apiserver/middlewares/v1/ocsp.go new file mode 100644 index 00000000000..0b6406ad0e7 --- /dev/null +++ b/pkg/apiserver/middlewares/v1/ocsp.go @@ -0,0 +1,100 @@ +package v1 + +import ( + "bytes" + "crypto" + "crypto/x509" + "io" + "net/http" + "net/url" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ocsp" +) + +type OCSPChecker struct { + logger *log.Entry +} + +func NewOCSPChecker(logger *log.Entry) *OCSPChecker { + return &OCSPChecker{ + logger: logger, + } +} + +func (oc *OCSPChecker) query(server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) { + req, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA256}) + if err != nil { + oc.logger.Errorf("TLSAuth: error creating OCSP request: %s", err) + return nil, err + } + + httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req)) + if err != nil { + oc.logger.Error("TLSAuth: cannot create HTTP request for OCSP") + return nil, err + } + + ocspURL, err := url.Parse(server) + if err != nil { + oc.logger.Error("TLSAuth: cannot parse OCSP URL") + return nil, err + } + + httpRequest.Header.Add("Content-Type", "application/ocsp-request") + httpRequest.Header.Add("Accept", "application/ocsp-response") + httpRequest.Header.Add("Host", ocspURL.Host) + + httpClient := &http.Client{} + + // XXX: timeout, context? 
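// A minimal sketch of one way to settle the XXX above, assuming a
// 10-second budget (an arbitrary value, not part of this change),
// that ocsp.go imports context and time, and replacing the bare
// &http.Client{} constructed above:
//
//	// bound the whole OCSP round-trip: per-request context plus a client-level timeout
//	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//	defer cancel()
//	httpRequest = httpRequest.WithContext(ctx)
//	httpClient := &http.Client{Timeout: 10 * time.Second}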
+ httpResponse, err := httpClient.Do(httpRequest) + if err != nil { + oc.logger.Error("TLSAuth: cannot send HTTP request to OCSP") + return nil, err + } + defer httpResponse.Body.Close() + + output, err := io.ReadAll(httpResponse.Body) + if err != nil { + oc.logger.Error("TLSAuth: cannot read HTTP response from OCSP") + return nil, err + } + + ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer) + + return ocspResponse, err +} + +// isRevokedBy checks if the client certificate is revoked by the issuer via any of the OCSP servers present in the certificate. +// It returns a boolean indicating if the certificate is revoked and a boolean indicating +// if the OCSP check was successful and could be cached. +func (oc *OCSPChecker) isRevokedBy(cert *x509.Certificate, issuer *x509.Certificate) (bool, bool) { + if len(cert.OCSPServer) == 0 { + oc.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification") + return false, true + } + + for _, server := range cert.OCSPServer { + ocspResponse, err := oc.query(server, cert, issuer) + if err != nil { + oc.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err) + continue + } + + switch ocspResponse.Status { + case ocsp.Good: + return false, true + case ocsp.Revoked: + oc.logger.Errorf("TLSAuth: client certificate is revoked by server %s", server) + return true, true + case ocsp.Unknown: + log.Debugf("unknown OCSP status for server %s", server) + continue + } + } + + log.Infof("Could not get any valid OCSP response, assuming the cert is revoked") + + return true, false +} diff --git a/pkg/apiserver/middlewares/v1/tls_auth.go b/pkg/apiserver/middlewares/v1/tls_auth.go index 87ca896a8f4..673c8d0cdce 100644 --- a/pkg/apiserver/middlewares/v1/tls_auth.go +++ b/pkg/apiserver/middlewares/v1/tls_auth.go @@ -1,70 +1,24 @@ package v1 import ( - "bytes" - "crypto" "crypto/x509" + "errors" "fmt" - "io" - "net/http" - "net/url" - "os" + "slices" "time" "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" - "golang.org/x/crypto/ocsp" ) type TLSAuth struct { AllowedOUs []string - CrlPath string - revokationCache map[string]cacheEntry - cacheExpiration time.Duration + crlChecker *CRLChecker + ocspChecker *OCSPChecker + revocationCache *RevocationCache logger *log.Entry } -type cacheEntry struct { - revoked bool - err error - timestamp time.Time -} - -func (ta *TLSAuth) ocspQuery(server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) { - req, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA256}) - if err != nil { - ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err) - return nil, err - } - httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req)) - if err != nil { - ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP") - return nil, err - } - ocspURL, err := url.Parse(server) - if err != nil { - ta.logger.Error("TLSAuth: cannot parse OCSP URL") - return nil, err - } - httpRequest.Header.Add("Content-Type", "application/ocsp-request") - httpRequest.Header.Add("Accept", "application/ocsp-response") - httpRequest.Header.Add("host", ocspURL.Host) - httpClient := &http.Client{} - httpResponse, err := httpClient.Do(httpRequest) - if err != nil { - ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP") - return nil, err - } - defer httpResponse.Body.Close() - output, err := io.ReadAll(httpResponse.Body) - if err != nil { - ta.logger.Error("TLSAuth: cannot read HTTP response from 
OCSP") - return nil, err - } - ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer) - return ocspResponse, err -} - func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool { now := time.Now().UTC() @@ -72,185 +26,158 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool { ta.logger.Errorf("TLSAuth: client certificate is expired (NotAfter: %s)", cert.NotAfter.UTC()) return true } + if cert.NotBefore.UTC().After(now) { ta.logger.Errorf("TLSAuth: client certificate is not yet valid (NotBefore: %s)", cert.NotBefore.UTC()) return true } + return false } -func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { - if cert.OCSPServer == nil || (cert.OCSPServer != nil && len(cert.OCSPServer) == 0) { - ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification") - return false, nil - } - for _, server := range cert.OCSPServer { - ocspResponse, err := ta.ocspQuery(server, cert, issuer) - if err != nil { - ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err) - continue +// checkRevocationPath checks a single chain against OCSP and CRL +func (ta *TLSAuth) checkRevocationPath(chain []*x509.Certificate) (error, bool) { //nolint:revive + // if we ever fail to check OCSP or CRL, we should not cache the result + couldCheck := true + + // starting from the root CA and moving towards the leaf certificate, + // check for revocation of intermediates too + for i := len(chain) - 1; i > 0; i-- { + cert := chain[i-1] + issuer := chain[i] + + revokedByOCSP, checkedByOCSP := ta.ocspChecker.isRevokedBy(cert, issuer) + couldCheck = couldCheck && checkedByOCSP + + if revokedByOCSP && checkedByOCSP { + return errors.New("certificate revoked by OCSP"), couldCheck } - switch ocspResponse.Status { - case ocsp.Good: - return false, nil - case ocsp.Revoked: - return true, fmt.Errorf("client certificate is revoked by server %s", server) - case ocsp.Unknown: - log.Debugf("unknow OCSP status for server %s", server) - continue + + revokedByCRL, checkedByCRL := ta.crlChecker.isRevokedBy(cert, issuer) + couldCheck = couldCheck && checkedByCRL + + if revokedByCRL && checkedByCRL { + return errors.New("certificate revoked by CRL"), couldCheck } } - log.Infof("Could not get any valid OCSP response, assuming the cert is revoked") - return true, nil + + return nil, couldCheck } -func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) { - if ta.CrlPath == "" { - ta.logger.Warn("no crl_path, skipping CRL check") - return false, nil - } - crlContent, err := os.ReadFile(ta.CrlPath) - if err != nil { - ta.logger.Warnf("could not read CRL file, skipping check: %s", err) - return false, nil - } - crl, err := x509.ParseCRL(crlContent) - if err != nil { - ta.logger.Warnf("could not parse CRL file, skipping check: %s", err) - return false, nil - } - if crl.HasExpired(time.Now().UTC()) { - ta.logger.Warn("CRL has expired, will still validate the cert against it.") - } - for _, revoked := range crl.TBSCertList.RevokedCertificates { - if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 { - return true, fmt.Errorf("client certificate is revoked by CRL") +func (ta *TLSAuth) setAllowedOu(allowedOus []string) error { + uniqueOUs := make(map[string]struct{}) + + for _, ou := range allowedOus { + // disallow empty ou + if ou == "" { + return errors.New("allowed_ou configuration contains invalid empty string") + } + + if _, exists := uniqueOUs[ou]; exists { + ta.logger.Warningf("dropping duplicate ou %s", ou) + 
continue } + + uniqueOUs[ou] = struct{}{} + + ta.AllowedOUs = append(ta.AllowedOUs, ou) } - return false, nil + + return nil } -func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { - sn := cert.SerialNumber.String() - if cacheValue, ok := ta.revokationCache[sn]; ok { - if time.Now().UTC().Sub(cacheValue.timestamp) < ta.cacheExpiration { - ta.logger.Debugf("TLSAuth: using cached value for cert %s: %t | %s", sn, cacheValue.revoked, cacheValue.err) - return cacheValue.revoked, cacheValue.err - } else { - ta.logger.Debugf("TLSAuth: cached value expired, removing from cache") - delete(ta.revokationCache, sn) +func (ta *TLSAuth) checkAllowedOU(ous []string) error { + for _, ou := range ous { + if slices.Contains(ta.AllowedOUs, ou) { + return nil } - } else { - ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn) } - revoked, err := ta.isOCSPRevoked(cert, issuer) - if err != nil { - ta.revokationCache[sn] = cacheEntry{ - revoked: revoked, - err: err, - timestamp: time.Now().UTC(), - } - return true, err + + return fmt.Errorf("client certificate OU %v doesn't match expected OU %v", ous, ta.AllowedOUs) +} + +func (ta *TLSAuth) ValidateCert(c *gin.Context) (string, error) { + // Checks cert validity, Returns true + CN if client cert matches requested OU + var leaf *x509.Certificate + + if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 { + return "", errors.New("no certificate in request") } - if revoked { - ta.revokationCache[sn] = cacheEntry{ - revoked: revoked, - err: err, - timestamp: time.Now().UTC(), - } - return true, nil + + if len(c.Request.TLS.VerifiedChains) == 0 { + return "", errors.New("no verified cert in request") } - revoked, err = ta.isCRLRevoked(cert) - ta.revokationCache[sn] = cacheEntry{ - revoked: revoked, - err: err, - timestamp: time.Now().UTC(), + + // although there can be multiple chains, the leaf certificate is the same + // we take the first one + leaf = c.Request.TLS.VerifiedChains[0][0] + + if err := ta.checkAllowedOU(leaf.Subject.OrganizationalUnit); err != nil { + return "", err } - return revoked, err -} -func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { - if ta.isExpired(cert) { - return true, nil + if ta.isExpired(leaf) { + return "", errors.New("client certificate is expired") } - revoked, err := ta.isRevoked(cert, issuer) - if err != nil { - //Fail securely, if we can't check the revocation status, let's consider the cert invalid - //We may change this in the future based on users feedback, but this seems the most sensible thing to do - return true, fmt.Errorf("could not check for client certification revocation status: %w", err) + + if validErr, cached := ta.revocationCache.Get(leaf); cached { + if validErr != nil { + return "", fmt.Errorf("(cache) %w", validErr) + } + + return leaf.Subject.CommonName, nil } - return revoked, nil -} + okToCache := true -func (ta *TLSAuth) SetAllowedOu(allowedOus []string) error { - for _, ou := range allowedOus { - //disallow empty ou - if ou == "" { - return fmt.Errorf("empty ou isn't allowed") - } - //drop & warn on duplicate ou - ok := true - for _, validOu := range ta.AllowedOUs { - if validOu == ou { - ta.logger.Warningf("dropping duplicate ou %s", ou) - ok = false - } - } - if ok { - ta.AllowedOUs = append(ta.AllowedOUs, ou) + var validErr error + + var couldCheck bool + + for _, chain := range c.Request.TLS.VerifiedChains { + validErr, couldCheck = ta.checkRevocationPath(chain) + okToCache = okToCache && 
couldCheck + + if validErr != nil { + break } } - return nil -} -func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { - //Checks cert validity, Returns true + CN if client cert matches requested OU - var clientCert *x509.Certificate - if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 { - //do not error if it's not TLS or there are no peer certs - return false, "", nil + if okToCache { + ta.revocationCache.Set(leaf, validErr) } - if len(c.Request.TLS.VerifiedChains) > 0 { - validOU := false - clientCert = c.Request.TLS.VerifiedChains[0][0] - for _, ou := range clientCert.Subject.OrganizationalUnit { - for _, allowedOu := range ta.AllowedOUs { - if allowedOu == ou { - validOU = true - break - } - } - } - if !validOU { - return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)", - clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) - } - revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1]) - if err != nil { - ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err) - return false, "", fmt.Errorf("could not check for client certification revokation status: %w", err) - } - if revoked { - return false, "", fmt.Errorf("client certificate is revoked") - } - ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) - return true, clientCert.Subject.CommonName, nil + if validErr != nil { + return "", validErr } - return false, "", fmt.Errorf("no verified cert in request") + + return leaf.Subject.CommonName, nil } func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Duration, logger *log.Entry) (*TLSAuth, error) { + var err error + + cache := NewRevocationCache(cacheExpiration, logger) + ta := &TLSAuth{ - revokationCache: map[string]cacheEntry{}, - cacheExpiration: cacheExpiration, - CrlPath: crlPath, + revocationCache: cache, + ocspChecker: NewOCSPChecker(logger), logger: logger, } - err := ta.SetAllowedOu(allowedOus) - if err != nil { + + switch crlPath { + case "": + logger.Info("no crl_path, skipping CRL checks") + default: + ta.crlChecker, err = NewCRLChecker(crlPath, cache.Empty, logger) + if err != nil { + return nil, err + } + } + + if err := ta.setAllowedOu(allowedOus); err != nil { return nil, err } + return ta, nil } diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 96bb9251b2f..7dd6b346aa9 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -3,6 +3,7 @@ package apiserver import ( "context" "encoding/json" + "errors" "fmt" "net/http" "sync" @@ -11,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" + "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -21,21 +22,15 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var ( - SyncInterval = time.Second * 10 -) +var SyncInterval = time.Second * 10 -const ( - PapiPullKey = "papi:last_pull" -) +const PapiPullKey = "papi:last_pull" -var ( - operationMap = map[string]func(*Message, *Papi, bool) error{ - "decision": DecisionCmd, - "alert": AlertCmd, - "management": ManagementCmd, - } -) +var operationMap = map[string]func(*Message, *Papi, bool) error{ + "decision": DecisionCmd, + "alert": AlertCmd, + "management": ManagementCmd, +} type Header struct { OperationType string `json:"operation_type"` @@ -87,21 +82,21 @@ type PapiPermCheckSuccess struct { } 
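The operationMap above is what drives message handling: handleEvent, shown further down, resolves the handler from the header's operation type and invokes it. A rough sketch of the dispatch, with a hypothetical decision message and an existing *Papi value p (sync=false, as in real-time pulls):

	msg := &Message{Header: &Header{OperationType: "decision", OperationCmd: "delete"}}
	if fn, ok := operationMap[msg.Header.OperationType]; ok {
		if err := fn(msg, p, false); err != nil {
			log.Errorf("operation failed: %s", err)
		}
	}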
func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, logLevel log.Level) (*Papi, error) { - logger := log.New() if err := types.ConfigureLogger(logger); err != nil { - return &Papi{}, fmt.Errorf("creating papi logger: %s", err) + return &Papi{}, fmt.Errorf("creating papi logger: %w", err) } + logger.SetLevel(logLevel) papiUrl := *apic.apiClient.PapiURL papiUrl.Path = fmt.Sprintf("%s%s", types.PAPIVersion, types.PAPIPollUrl) + longPollClient, err := longpollclient.NewLongPollClient(longpollclient.LongPollClientConfig{ Url: papiUrl, Logger: logger, HttpClient: apic.apiClient.GetClient(), }) - if err != nil { return &Papi{}, fmt.Errorf("failed to create PAPI client: %w", err) } @@ -132,55 +127,69 @@ func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.Cons func (p *Papi) handleEvent(event longpollclient.Event, sync bool) error { logger := p.Logger.WithField("request-id", event.RequestId) logger.Debugf("message received: %+v", event.Data) + message := &Message{} if err := json.Unmarshal([]byte(event.Data), message); err != nil { - return fmt.Errorf("polling papi message format is not compatible: %+v: %s", event.Data, err) + return fmt.Errorf("polling papi message format is not compatible: %+v: %w", event.Data, err) } + if message.Header == nil { - return fmt.Errorf("no header in message, skipping") + return errors.New("no header in message, skipping") } + if message.Header.Source == nil { - return fmt.Errorf("no source user in header message, skipping") + return errors.New("no source user in header message, skipping") } - if operationFunc, ok := operationMap[message.Header.OperationType]; ok { - logger.Debugf("Calling operation '%s'", message.Header.OperationType) - err := operationFunc(message, p, sync) - if err != nil { - return fmt.Errorf("'%s %s failed: %s", message.Header.OperationType, message.Header.OperationCmd, err) - } - } else { + operationFunc, ok := operationMap[message.Header.OperationType] + if !ok { return fmt.Errorf("operation '%s' unknown, continue", message.Header.OperationType) } + + logger.Debugf("Calling operation '%s'", message.Header.OperationType) + + err := operationFunc(message, p, sync) + if err != nil { + return fmt.Errorf("'%s %s failed: %w", message.Header.OperationType, message.Header.OperationCmd, err) + } + return nil } -func (p *Papi) GetPermissions() (PapiPermCheckSuccess, error) { +func (p *Papi) GetPermissions(ctx context.Context) (PapiPermCheckSuccess, error) { httpClient := p.apiClient.GetClient() papiCheckUrl := fmt.Sprintf("%s%s%s", p.URL, types.PAPIVersion, types.PAPIPermissionsUrl) - req, err := http.NewRequest(http.MethodGet, papiCheckUrl, nil) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, papiCheckUrl, nil) if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to create request : %s", err) + return PapiPermCheckSuccess{}, fmt.Errorf("failed to create request: %w", err) } + resp, err := httpClient.Do(req) if err != nil { - log.Fatalf("failed to get response : %s", err) + return PapiPermCheckSuccess{}, fmt.Errorf("failed to get response: %w", err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { errResp := PapiPermCheckError{} + err = json.NewDecoder(resp.Body).Decode(&errResp) if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response : %s", err) + return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response: %w", err) } + return PapiPermCheckSuccess{}, fmt.Errorf("unable to query PAPI : %s (%d)", 
errResp.Error, resp.StatusCode) } + respBody := PapiPermCheckSuccess{} + err = json.NewDecoder(resp.Body).Decode(&respBody) if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response : %s", err) + return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response: %w", err) } + return respBody, nil } @@ -202,56 +211,64 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error { return err } - reversedEvents := reverse(events) //PAPI sends events in the reverse order, which is not an issue when pulling them in real time, but here we need the correct order + reversedEvents := reverse(events) // PAPI sends events in the reverse order, which is not an issue when pulling them in real time, but here we need the correct order eventsCount := len(events) p.Logger.Infof("received %d events", eventsCount) + for i, event := range reversedEvents { if err := p.handleEvent(event, sync); err != nil { p.Logger.WithField("request-id", event.RequestId).Errorf("failed to handle event: %s", err) } + p.Logger.Debugf("handled event %d/%d", i, eventsCount) } + p.Logger.Debugf("finished handling events") - //Don't update the timestamp in DB, as a "real" LAPI might be running - //Worst case, crowdsec will receive a few duplicated events and will discard them + // Don't update the timestamp in DB, as a "real" LAPI might be running + // Worst case, crowdsec will receive a few duplicated events and will discard them return nil } // PullPAPI is the long polling client for real-time decisions from PAPI -func (p *Papi) Pull() error { +func (p *Papi) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/PullPAPI") p.Logger.Infof("Starting Polling API Pull") lastTimestamp := time.Time{} - lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey) + + lastTimestampStr, err := p.DBClient.GetConfigItem(ctx, PapiPullKey) if err != nil { p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err) } - //value doesn't exist, it's first time we're pulling + + // value doesn't exist, it's first time we're pulling if lastTimestampStr == nil { binTime, err := lastTimestamp.MarshalText() if err != nil { - return fmt.Errorf("failed to marshal last timestamp: %w", err) + return fmt.Errorf("failed to serialize last timestamp: %w", err) } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { p.Logger.Errorf("error setting papi pull last key: %s", err) } else { p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime)) } } else { if err := lastTimestamp.UnmarshalText([]byte(*lastTimestampStr)); err != nil { - return fmt.Errorf("failed to unmarshal last timestamp: %w", err) + return fmt.Errorf("failed to parse last timestamp: %w", err) } } p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp) + for event := range p.Client.Start(lastTimestamp) { logger := p.Logger.WithField("request-id", event.RequestId) - //update last timestamp in database + // update last timestamp in database newTime := time.Now().UTC() + binTime, err := newTime.MarshalText() if err != nil { - return fmt.Errorf("failed to marshal last timestamp: %w", err) + return fmt.Errorf("failed to serialize last timestamp: %w", err) } err = p.handleEvent(event, false) @@ -260,13 +277,13 @@ func (p *Papi) Pull() error { continue } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, 
string(binTime)); err != nil { return fmt.Errorf("failed to update last timestamp: %w", err) - } else { - logger.Debugf("set last timestamp to %s", newTime) } + logger.Debugf("set last timestamp to %s", newTime) } + return nil } @@ -274,6 +291,7 @@ func (p *Papi) SyncDecisions() error { defer trace.CatchPanic("lapi/syncDecisionsToCAPI") var cache models.DecisionsDeleteRequest + ticker := time.NewTicker(p.SyncInterval) p.Logger.Infof("Start decisions sync to CrowdSec Central API (interval: %s)", p.SyncInterval) @@ -281,10 +299,13 @@ func (p *Papi) SyncDecisions() error { select { case <-p.syncTomb.Dying(): // if one apic routine is dying, do we kill the others? p.Logger.Infof("sync decisions tomb is dying, sending cache (%d elements) before exiting", len(cache)) + if len(cache) == 0 { return nil } + go p.SendDeletedDecisions(&cache) + return nil case <-ticker.C: if len(cache) > 0 { @@ -293,15 +314,19 @@ func (p *Papi) SyncDecisions() error { cache = make([]models.DecisionsDeleteRequestItem, 0) p.mu.Unlock() p.Logger.Infof("sync decisions: %d deleted decisions to push", len(cacheCopy)) + go p.SendDeletedDecisions(&cacheCopy) } case deletedDecisions := <-p.Channels.DeleteDecisionChannel: if (p.consoleConfig.ShareManualDecisions != nil && *p.consoleConfig.ShareManualDecisions) || (p.consoleConfig.ConsoleManagement != nil && *p.consoleConfig.ConsoleManagement) { var tmpDecisions []models.DecisionsDeleteRequestItem + p.Logger.Debugf("%d decisions deletion to add in cache", len(deletedDecisions)) + for _, decision := range deletedDecisions { tmpDecisions = append(tmpDecisions, models.DecisionsDeleteRequestItem(decision.UUID)) } + p.mu.Lock() cache = append(cache, tmpDecisions...) p.mu.Unlock() @@ -311,33 +336,42 @@ func (p *Papi) SyncDecisions() error { } func (p *Papi) SendDeletedDecisions(cacheOrig *models.DecisionsDeleteRequest) { - - var cache []models.DecisionsDeleteRequestItem = *cacheOrig - var send models.DecisionsDeleteRequest + var ( + cache []models.DecisionsDeleteRequestItem = *cacheOrig + send models.DecisionsDeleteRequest + ) bulkSize := 50 pageStart := 0 pageEnd := bulkSize + for { if pageEnd >= len(cache) { send = cache[pageStart:] ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, _, err := p.apiClient.DecisionDelete.Add(ctx, &send) if err != nil { p.Logger.Errorf("sending deleted decisions to central API: %s", err) return } + break } + send = cache[pageStart:pageEnd] ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, _, err := p.apiClient.DecisionDelete.Add(ctx, &send) if err != nil { - //we log it here as well, because the return value of func might be discarded + // we log it here as well, because the return value of func might be discarded p.Logger.Errorf("sending deleted decisions to central API: %s", err) } + pageStart += bulkSize pageEnd += bulkSize } diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index 4cb9603b7ab..78f5dc9b0fe 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -1,16 +1,18 @@ package apiserver import ( + "context" "encoding/json" "fmt" "time" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/modelscapi" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -19,20 +21,44 @@ type deleteDecisions struct { Decisions []string 
`json:"decisions"` } +type blocklistLink struct { + // blocklist name + Name string `json:"name"` + // blocklist url + Url string `json:"url"` + // blocklist remediation + Remediation string `json:"remediation"` + // blocklist scope + Scope string `json:"scope,omitempty"` + // blocklist duration + Duration string `json:"duration,omitempty"` +} + +type forcePull struct { + Blocklist *blocklistLink `json:"blocklist,omitempty"` +} + +type listUnsubscribe struct { + Name string `json:"name"` +} + func DecisionCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + switch message.Header.OperationCmd { case "delete": - data, err := json.Marshal(message.Data) if err != nil { return err } + UUIDs := make([]string, 0) deleteDecisionMsg := deleteDecisions{ Decisions: make([]string, 0), } + if err := json.Unmarshal(data, &deleteDecisionMsg); err != nil { - return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err) + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) } UUIDs = append(UUIDs, deleteDecisionMsg.Decisions...) @@ -40,11 +66,14 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { filter := make(map[string][]string) filter["uuid"] = UUIDs - _, deletedDecisions, err := p.DBClient.SoftDeleteDecisionsWithFilter(filter) + + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { - return fmt.Errorf("unable to delete decisions %+v : %s", UUIDs, err) + return fmt.Errorf("unable to expire decisions %+v: %w", UUIDs, err) } + decisions := make([]*models.Decision, 0) + for _, deletedDecision := range deletedDecisions { log.Infof("Decision from '%s' for '%s' (%s) has been deleted", deletedDecision.Origin, deletedDecision.Value, deletedDecision.Type) dec := &models.Decision{ @@ -68,12 +97,15 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { } func AlertCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + switch message.Header.OperationCmd { case "add": data, err := json.Marshal(message.Data) if err != nil { return err } + alert := &models.Alert{} if err := json.Unmarshal(data, alert); err != nil { @@ -87,10 +119,12 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { log.Warnf("Alert %d has no StartAt, setting it to now", alert.ID) alert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) } + if alert.StopAt == nil || *alert.StopAt == "" { log.Warnf("Alert %d has no StopAt, setting it to now", alert.ID) alert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) } + alert.EventsCount = ptr.Of(int32(0)) alert.Capacity = ptr.Of(int32(0)) alert.Leakspeed = ptr.Of("") @@ -101,26 +135,29 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { alert.Scenario = ptr.Of("") alert.Source = &models.Source{} - //if we're setting Source.Scope to types.ConsoleOrigin, it messes up the alert's value + // if we're setting Source.Scope to types.ConsoleOrigin, it messes up the alert's value if len(alert.Decisions) >= 1 { alert.Source.Scope = alert.Decisions[0].Scope alert.Source.Value = alert.Decisions[0].Value } else { log.Warningf("No decision found in alert for Polling API (%s : %s)", message.Header.Source.User, message.Header.Message) + alert.Source.Scope = ptr.Of(types.ConsoleOrigin) alert.Source.Value = &message.Header.Source.User } + alert.Scenario = &message.Header.Message for _, decision := range alert.Decisions { if *decision.Scenario == "" { decision.Scenario = &message.Header.Message } + 
log.Infof("Adding decision for '%s' with UUID: %s", *decision.Value, decision.UUID) } - //use a different method : alert and/or decision might already be partially present in the database - _, err = p.DBClient.CreateOrUpdateAlert("", alert) + // use a different method: alert and/or decision might already be partially present in the database + _, err = p.DBClient.CreateOrUpdateAlert(ctx, "", alert) if err != nil { log.Errorf("Failed to create alerts in DB: %s", err) } else { @@ -135,22 +172,82 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { } func ManagementCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + if sync { - log.Infof("Ignoring management command from PAPI in sync mode") + p.Logger.Infof("Ignoring management command from PAPI in sync mode") return nil } + switch message.Header.OperationCmd { + case "blocklist_unsubscribe": + data, err := json.Marshal(message.Data) + if err != nil { + return err + } + + unsubscribeMsg := listUnsubscribe{} + if err := json.Unmarshal(data, &unsubscribeMsg); err != nil { + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) + } + + if unsubscribeMsg.Name == "" { + return fmt.Errorf("message for '%s' contains bad data format: missing blocklist name", message.Header.OperationType) + } + + p.Logger.Infof("Received blocklist_unsubscribe command from PAPI, unsubscribing from blocklist %s", unsubscribeMsg.Name) + + filter := make(map[string][]string) + filter["origin"] = []string{types.ListOrigin} + filter["scenario"] = []string{unsubscribeMsg.Name} + + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) + if err != nil { + return fmt.Errorf("unable to expire decisions for list %s : %w", unsubscribeMsg.Name, err) + } + + p.Logger.Infof("deleted %d decisions for list %s", len(deletedDecisions), unsubscribeMsg.Name) case "reauth": - log.Infof("Received reauth command from PAPI, resetting token") + p.Logger.Infof("Received reauth command from PAPI, resetting token") p.apiClient.GetClient().Transport.(*apiclient.JWTTransport).ResetToken() case "force_pull": - log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") - err := p.apic.PullTop(true) + data, err := json.Marshal(message.Data) if err != nil { - return fmt.Errorf("failed to force pull operation: %s", err) + return err + } + + forcePullMsg := forcePull{} + + if err := json.Unmarshal(data, &forcePullMsg); err != nil { + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) + } + + ctx := context.TODO() + + if forcePullMsg.Blocklist == nil { + p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") + + err = p.apic.PullTop(ctx, true) + if err != nil { + return fmt.Errorf("failed to force pull operation: %w", err) + } + } else { + p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name) + + err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{ + Name: &forcePullMsg.Blocklist.Name, + URL: &forcePullMsg.Blocklist.Url, + Remediation: &forcePullMsg.Blocklist.Remediation, + Scope: &forcePullMsg.Blocklist.Scope, + Duration: &forcePullMsg.Blocklist.Duration, + }, true) + if err != nil { + return fmt.Errorf("failed to force pull operation: %w", err) + } } default: return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType) } + return nil } diff --git 
a/pkg/apiserver/usage_metrics_test.go b/pkg/apiserver/usage_metrics_test.go new file mode 100644 index 00000000000..32aeb7d9a5a --- /dev/null +++ b/pkg/apiserver/usage_metrics_test.go @@ -0,0 +1,388 @@ +package apiserver + +import ( + "context" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" +) + +func TestLPMetrics(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + body string + expectedStatusCode int + expectedResponse string + expectedMetricsCount int + expectedOSName string + expectedOSVersion string + expectedFeatureFlags string + authType string + }{ + { + name: "empty metrics for LP", + body: `{ + }`, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: "Missing log processor data", + authType: PASSWORD, + }, + { + name: "basic metrics with empty dynamic metrics for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedResponse: "", + expectedOSName: "foo", + expectedOSVersion: "42", + expectedFeatureFlags: "a,b,c", + authType: PASSWORD, + }, + { + name: "basic metrics with dynamic metrics for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [{"meta":{"utc_now_timestamp":42, "window_size_seconds": 42}, "items": [{"name": "foo", "value": 42, "unit": "bla"}] }, {"meta":{"utc_now_timestamp":43, "window_size_seconds": 42}, "items": [{"name": "foo", "value": 42, "unit": "bla"}] }], + "feature_flags": ["a", "b", "c"], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedResponse: "", + expectedOSName: "foo", + expectedOSVersion: "42", + expectedFeatureFlags: "a,b,c", + authType: PASSWORD, + }, + { + name: "wrong auth type for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: "Missing remediation component data", + authType: APIKEY, + }, + { + name: "missing OS field for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedResponse: "", + expectedMetricsCount: 1, + expectedFeatureFlags: "a,b,c", + authType: PASSWORD, + }, + { + name: "missing datasources for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"], + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusUnprocessableEntity, + expectedResponse: "log_processors.0.datasources in body is required", + authType: PASSWORD, + }, + { + name: "missing feature flags for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + 
"utc_startup_timestamp": 42, + "metrics": [], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedOSName: "foo", + expectedOSVersion: "42", + authType: PASSWORD, + }, + { + name: "missing OS name", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusUnprocessableEntity, + expectedResponse: "log_processors.0.os.name in body is required", + authType: PASSWORD, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lapi := SetupLAPITest(t, ctx) + + dbClient, err := database.NewClient(ctx, lapi.DBConfig) + if err != nil { + t.Fatalf("unable to create database client: %s", err) + } + + w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/usage-metrics", strings.NewReader(tt.body), tt.authType) + + assert.Equal(t, tt.expectedStatusCode, w.Code) + assert.Contains(t, w.Body.String(), tt.expectedResponse) + + machine, _ := dbClient.QueryMachineByID(ctx, "test") + metrics, _ := dbClient.GetLPUsageMetricsByMachineID(ctx, "test") + + assert.Len(t, metrics, tt.expectedMetricsCount) + assert.Equal(t, tt.expectedOSName, machine.Osname) + assert.Equal(t, tt.expectedOSVersion, machine.Osversion) + assert.Equal(t, tt.expectedFeatureFlags, machine.Featureflags) + + if len(metrics) > 0 { + assert.Equal(t, "test", metrics[0].GeneratedBy) + assert.Equal(t, metric.GeneratedType("LP"), metrics[0].GeneratedType) + } + }) + } +} + +func TestRCMetrics(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + body string + expectedStatusCode int + expectedResponse string + expectedMetricsCount int + expectedOSName string + expectedOSVersion string + expectedFeatureFlags string + authType string + }{ + { + name: "empty metrics for RC", + body: `{ + }`, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: "Missing remediation component data", + authType: APIKEY, + }, + { + name: "basic metrics with empty dynamic metrics for RC", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"] + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedResponse: "", + expectedOSName: "foo", + expectedOSVersion: "42", + expectedFeatureFlags: "a,b,c", + authType: APIKEY, + }, + { + name: "basic metrics with dynamic metrics for RC", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [{"meta":{"utc_now_timestamp":42, "window_size_seconds": 42}, "items": [{"name": "foo", "value": 42, "unit": "bla"}] }, {"meta":{"utc_now_timestamp":43, "window_size_seconds": 42}, "items": [{"name": "foo", "value": 42, "unit": "bla"}] }], + "feature_flags": ["a", "b", "c"] + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedResponse: "", + expectedOSName: "foo", + expectedOSVersion: "42", + expectedFeatureFlags: "a,b,c", + authType: APIKEY, + }, + { + name: "wrong auth type for RC", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"] + } + ] +}`, + 
expectedStatusCode: http.StatusBadRequest, + expectedResponse: "Missing log processor data", + authType: PASSWORD, + }, + { + name: "missing OS field for RC", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"] + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedResponse: "", + expectedMetricsCount: 1, + expectedFeatureFlags: "a,b,c", + authType: APIKEY, + }, + { + name: "missing feature flags for RC", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [] + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedOSName: "foo", + expectedOSVersion: "42", + authType: APIKEY, + }, + { + name: "missing OS name", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "os": {"version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"] + } + ] +}`, + expectedStatusCode: http.StatusUnprocessableEntity, + expectedResponse: "remediation_components.0.os.name in body is required", + authType: APIKEY, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lapi := SetupLAPITest(t, ctx) + + dbClient, err := database.NewClient(ctx, lapi.DBConfig) + if err != nil { + t.Fatalf("unable to create database client: %s", err) + } + + w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/usage-metrics", strings.NewReader(tt.body), tt.authType) + + assert.Equal(t, tt.expectedStatusCode, w.Code) + assert.Contains(t, w.Body.String(), tt.expectedResponse) + + bouncer, _ := dbClient.SelectBouncerByName(ctx, "test") + metrics, _ := dbClient.GetBouncerUsageMetricsByName(ctx, "test") + + assert.Len(t, metrics, tt.expectedMetricsCount) + assert.Equal(t, tt.expectedOSName, bouncer.Osname) + assert.Equal(t, tt.expectedOSVersion, bouncer.Osversion) + assert.Equal(t, tt.expectedFeatureFlags, bouncer.Featureflags) + + if len(metrics) > 0 { + assert.Equal(t, "test", metrics[0].GeneratedBy) + assert.Equal(t, metric.GeneratedType("RC"), metrics[0].GeneratedType) + } + }) + } +} diff --git a/pkg/apiserver/utils.go b/pkg/apiserver/utils.go deleted file mode 100644 index 409d79b011a..00000000000 --- a/pkg/apiserver/utils.go +++ /dev/null @@ -1,27 +0,0 @@ -package apiserver - -import ( - "crypto/tls" - "fmt" - - log "github.com/sirupsen/logrus" -) - -func getTLSAuthType(authType string) (tls.ClientAuthType, error) { - switch authType { - case "NoClientCert": - return tls.NoClientCert, nil - case "RequestClientCert": - log.Warn("RequestClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead") - return tls.RequestClientCert, nil - case "RequireAnyClientCert": - log.Warn("RequireAnyClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead") - return tls.RequireAnyClientCert, nil - case "VerifyClientCertIfGiven": - return tls.VerifyClientCertIfGiven, nil - case "RequireAndVerifyClientCert": - return tls.RequireAndVerifyClientCert, nil - default: - return 0, fmt.Errorf("unknown TLS client_verification value: %s", authType) - } -} diff --git a/pkg/appsec/appsec.go b/pkg/appsec/appsec.go new file mode 100644 index 00000000000..30784b23db0 --- /dev/null +++ b/pkg/appsec/appsec.go @@ -0,0 +1,634 @@ +package appsec + +import ( + "errors" + "fmt" + "net/http" + "os" + "regexp" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" + 
log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v2" + + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type Hook struct { + Filter string `yaml:"filter"` + FilterExpr *vm.Program `yaml:"-"` + + OnSuccess string `yaml:"on_success"` + Apply []string `yaml:"apply"` + ApplyExpr []*vm.Program `yaml:"-"` +} + +const ( + hookOnLoad = iota + hookPreEval + hookPostEval + hookOnMatch +) + +const ( + BanRemediation = "ban" + CaptchaRemediation = "captcha" + AllowRemediation = "allow" +) + +func (h *Hook) Build(hookStage int) error { + ctx := map[string]interface{}{} + switch hookStage { + case hookOnLoad: + ctx = GetOnLoadEnv(&AppsecRuntimeConfig{}) + case hookPreEval: + ctx = GetPreEvalEnv(&AppsecRuntimeConfig{}, &ParsedRequest{}) + case hookPostEval: + ctx = GetPostEvalEnv(&AppsecRuntimeConfig{}, &ParsedRequest{}) + case hookOnMatch: + ctx = GetOnMatchEnv(&AppsecRuntimeConfig{}, &ParsedRequest{}, types.Event{}) + } + opts := exprhelpers.GetExprOptions(ctx) + if h.Filter != "" { + program, err := expr.Compile(h.Filter, opts...) // FIXME: opts + if err != nil { + return fmt.Errorf("unable to compile filter %s : %w", h.Filter, err) + } + h.FilterExpr = program + } + for _, apply := range h.Apply { + program, err := expr.Compile(apply, opts...) + if err != nil { + return fmt.Errorf("unable to compile apply %s : %w", apply, err) + } + h.ApplyExpr = append(h.ApplyExpr, program) + } + return nil +} + +type AppsecTempResponse struct { + InBandInterrupt bool + OutOfBandInterrupt bool + Action string // allow, deny, captcha, log + UserHTTPResponseCode int // The response code to send to the user + BouncerHTTPResponseCode int // The response code to send to the remediation component + SendEvent bool // do we send an internal event on rule match + SendAlert bool // do we send an alert on rule match +} + +type AppsecSubEngineOpts struct { + DisableBodyInspection bool `yaml:"disable_body_inspection"` + RequestBodyInMemoryLimit *int `yaml:"request_body_in_memory_limit"` +} + +// runtime version of AppsecConfig +type AppsecRuntimeConfig struct { + Name string + OutOfBandRules []AppsecCollection + + InBandRules []AppsecCollection + + DefaultRemediation string + RemediationByTag map[string]string // Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME + RemediationById map[int]string + CompiledOnLoad []Hook + CompiledPreEval []Hook + CompiledPostEval []Hook + CompiledOnMatch []Hook + CompiledVariablesTracking []*regexp.Regexp + Config *AppsecConfig + // CorazaLogger debuglog.Logger + + // those are ephemeral, created/destroyed with every req + OutOfBandTx ExtendedTransaction // is it a good idea ? + InBandTx ExtendedTransaction // is it a good idea ? + Response AppsecTempResponse + // should we store matched rules here ? 
+ + Logger *log.Entry + + // Set by on_load to ignore some rules on loading + DisabledInBandRuleIds []int + DisabledInBandRulesTags []string // Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME + + DisabledOutOfBandRuleIds []int + DisabledOutOfBandRulesTags []string // Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME +} + +type AppsecConfig struct { + Name string `yaml:"name"` + OutOfBandRules []string `yaml:"outofband_rules"` + InBandRules []string `yaml:"inband_rules"` + DefaultRemediation string `yaml:"default_remediation"` + DefaultPassAction string `yaml:"default_pass_action"` + BouncerBlockedHTTPCode int `yaml:"blocked_http_code"` // returned to the bouncer + BouncerPassedHTTPCode int `yaml:"passed_http_code"` // returned to the bouncer + UserBlockedHTTPCode int `yaml:"user_blocked_http_code"` // returned to the user + UserPassedHTTPCode int `yaml:"user_passed_http_code"` // returned to the user + + OnLoad []Hook `yaml:"on_load"` + PreEval []Hook `yaml:"pre_eval"` + PostEval []Hook `yaml:"post_eval"` + OnMatch []Hook `yaml:"on_match"` + VariablesTracking []string `yaml:"variables_tracking"` + InbandOptions AppsecSubEngineOpts `yaml:"inband_options"` + OutOfBandOptions AppsecSubEngineOpts `yaml:"outofband_options"` + + LogLevel *log.Level `yaml:"log_level"` + Logger *log.Entry `yaml:"-"` +} + +func (w *AppsecRuntimeConfig) ClearResponse() { + w.Response = AppsecTempResponse{} + w.Response.Action = w.Config.DefaultPassAction + w.Response.BouncerHTTPResponseCode = w.Config.BouncerPassedHTTPCode + w.Response.UserHTTPResponseCode = w.Config.UserPassedHTTPCode + w.Response.SendEvent = true + w.Response.SendAlert = true +} + +func (wc *AppsecConfig) LoadByPath(file string) error { + wc.Logger.Debugf("loading config %s", file) + + yamlFile, err := os.ReadFile(file) + if err != nil { + return fmt.Errorf("unable to read file %s : %s", file, err) + } + err = yaml.UnmarshalStrict(yamlFile, wc) + if err != nil { + return fmt.Errorf("unable to parse yaml file %s : %s", file, err) + } + + if wc.Name == "" { + return errors.New("name cannot be empty") + } + if wc.LogLevel == nil { + lvl := wc.Logger.Logger.GetLevel() + wc.LogLevel = &lvl + } + wc.Logger = wc.Logger.Dup().WithField("name", wc.Name) + wc.Logger.Logger.SetLevel(*wc.LogLevel) + return nil +} + +func (wc *AppsecConfig) Load(configName string) error { + item := hub.GetItem(cwhub.APPSEC_CONFIGS, configName) + + if item != nil && item.State.Installed { + wc.Logger.Infof("loading %s", item.State.LocalPath) + err := wc.LoadByPath(item.State.LocalPath) + if err != nil { + return fmt.Errorf("unable to load appsec-config %s : %s", item.State.LocalPath, err) + } + return nil + } + + return fmt.Errorf("no appsec-config found for %s", configName) +} + +func (wc *AppsecConfig) GetDataDir() string { + return hub.GetDataDir() +} + +func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { + ret := &AppsecRuntimeConfig{Logger: wc.Logger.WithField("component", "appsec_runtime_config")} + + if wc.BouncerBlockedHTTPCode == 0 { + wc.BouncerBlockedHTTPCode = http.StatusForbidden + } + if wc.BouncerPassedHTTPCode == 0 { + wc.BouncerPassedHTTPCode = http.StatusOK + } + + if wc.UserBlockedHTTPCode == 0 { + wc.UserBlockedHTTPCode = http.StatusForbidden + } + if wc.UserPassedHTTPCode == 0 { + wc.UserPassedHTTPCode = http.StatusOK + } + if wc.DefaultPassAction == "" { + wc.DefaultPassAction = AllowRemediation + } + if wc.DefaultRemediation == "" { + wc.DefaultRemediation = BanRemediation + } + + // set 
the defaults + switch wc.DefaultRemediation { + case BanRemediation, CaptchaRemediation, AllowRemediation: + // those are the officially supported remediation(s) + default: + wc.Logger.Warningf("default '%s' remediation of %s is none of [%s,%s,%s], ensure bouncer compatibility!", wc.DefaultRemediation, wc.Name, BanRemediation, CaptchaRemediation, AllowRemediation) + } + + ret.Name = wc.Name + ret.Config = wc + ret.DefaultRemediation = wc.DefaultRemediation + + wc.Logger.Tracef("Loading config %+v", wc) + // load rules + for _, rule := range wc.OutOfBandRules { + wc.Logger.Infof("loading outofband rule %s", rule) + collections, err := LoadCollection(rule, wc.Logger.WithField("component", "appsec_collection_loader")) + if err != nil { + return nil, fmt.Errorf("unable to load outofband rule %s : %s", rule, err) + } + ret.OutOfBandRules = append(ret.OutOfBandRules, collections...) + } + + wc.Logger.Infof("Loaded %d outofband rules", len(ret.OutOfBandRules)) + for _, rule := range wc.InBandRules { + wc.Logger.Infof("loading inband rule %s", rule) + collections, err := LoadCollection(rule, wc.Logger.WithField("component", "appsec_collection_loader")) + if err != nil { + return nil, fmt.Errorf("unable to load inband rule %s : %s", rule, err) + } + ret.InBandRules = append(ret.InBandRules, collections...) + } + + wc.Logger.Infof("Loaded %d inband rules", len(ret.InBandRules)) + + // load hooks + for _, hook := range wc.OnLoad { + if hook.OnSuccess != "" && hook.OnSuccess != "continue" && hook.OnSuccess != "break" { + return nil, fmt.Errorf("invalid 'on_success' for on_load hook : %s", hook.OnSuccess) + } + err := hook.Build(hookOnLoad) + if err != nil { + return nil, fmt.Errorf("unable to build on_load hook : %s", err) + } + ret.CompiledOnLoad = append(ret.CompiledOnLoad, hook) + } + + for _, hook := range wc.PreEval { + if hook.OnSuccess != "" && hook.OnSuccess != "continue" && hook.OnSuccess != "break" { + return nil, fmt.Errorf("invalid 'on_success' for pre_eval hook : %s", hook.OnSuccess) + } + err := hook.Build(hookPreEval) + if err != nil { + return nil, fmt.Errorf("unable to build pre_eval hook : %s", err) + } + ret.CompiledPreEval = append(ret.CompiledPreEval, hook) + } + + for _, hook := range wc.PostEval { + if hook.OnSuccess != "" && hook.OnSuccess != "continue" && hook.OnSuccess != "break" { + return nil, fmt.Errorf("invalid 'on_success' for post_eval hook : %s", hook.OnSuccess) + } + err := hook.Build(hookPostEval) + if err != nil { + return nil, fmt.Errorf("unable to build post_eval hook : %s", err) + } + ret.CompiledPostEval = append(ret.CompiledPostEval, hook) + } + + for _, hook := range wc.OnMatch { + if hook.OnSuccess != "" && hook.OnSuccess != "continue" && hook.OnSuccess != "break" { + return nil, fmt.Errorf("invalid 'on_success' for on_match hook : %s", hook.OnSuccess) + } + err := hook.Build(hookOnMatch) + if err != nil { + return nil, fmt.Errorf("unable to build on_match hook : %s", err) + } + ret.CompiledOnMatch = append(ret.CompiledOnMatch, hook) + } + + // variable tracking + for _, variable := range wc.VariablesTracking { + compiledVariableRule, err := regexp.Compile(variable) + if err != nil { + return nil, fmt.Errorf("cannot compile variable regexp %s: %w", variable, err) + } + ret.CompiledVariablesTracking = append(ret.CompiledVariablesTracking, compiledVariableRule) + } + return ret, nil +} + +func (w *AppsecRuntimeConfig) ProcessOnLoadRules() error { + has_match := false + for _, rule := range w.CompiledOnLoad { + if rule.FilterExpr != nil { + output, err :=
exprhelpers.Run(rule.FilterExpr, GetOnLoadEnv(w), w.Logger, w.Logger.Level >= log.DebugLevel) + if err != nil { + return fmt.Errorf("unable to run appsec on_load filter %s : %w", rule.Filter, err) + } + switch t := output.(type) { + case bool: + if !t { + w.Logger.Debugf("filter didn't match") + continue + } + default: + w.Logger.Errorf("Filter must return a boolean, can't filter") + continue + } + has_match = true + } + for _, applyExpr := range rule.ApplyExpr { + o, err := exprhelpers.Run(applyExpr, GetOnLoadEnv(w), w.Logger, w.Logger.Level >= log.DebugLevel) + if err != nil { + w.Logger.Errorf("unable to apply appsec on_load expr: %s", err) + continue + } + switch t := o.(type) { + case error: + w.Logger.Errorf("unable to apply appsec on_load expr: %s", t) + continue + default: + } + } + if has_match && rule.OnSuccess == "break" { + break + } + } + return nil +} + +func (w *AppsecRuntimeConfig) ProcessOnMatchRules(request *ParsedRequest, evt types.Event) error { + has_match := false + for _, rule := range w.CompiledOnMatch { + if rule.FilterExpr != nil { + output, err := exprhelpers.Run(rule.FilterExpr, GetOnMatchEnv(w, request, evt), w.Logger, w.Logger.Level >= log.DebugLevel) + if err != nil { + return fmt.Errorf("unable to run appsec on_match filter %s : %w", rule.Filter, err) + } + switch t := output.(type) { + case bool: + if !t { + w.Logger.Debugf("filter didn't match") + continue + } + default: + w.Logger.Errorf("Filter must return a boolean, can't filter") + continue + } + has_match = true + } + for _, applyExpr := range rule.ApplyExpr { + o, err := exprhelpers.Run(applyExpr, GetOnMatchEnv(w, request, evt), w.Logger, w.Logger.Level >= log.DebugLevel) + if err != nil { + w.Logger.Errorf("unable to apply appsec on_match expr: %s", err) + continue + } + switch t := o.(type) { + case error: + w.Logger.Errorf("unable to apply appsec on_match expr: %s", t) + continue + default: + } + } + if has_match && rule.OnSuccess == "break" { + break + } + } + return nil +} + +func (w *AppsecRuntimeConfig) ProcessPreEvalRules(request *ParsedRequest) error { + has_match := false + for _, rule := range w.CompiledPreEval { + if rule.FilterExpr != nil { + output, err := exprhelpers.Run(rule.FilterExpr, GetPreEvalEnv(w, request), w.Logger, w.Logger.Level >= log.DebugLevel) + if err != nil { + return fmt.Errorf("unable to run appsec pre_eval filter %s : %w", rule.Filter, err) + } + switch t := output.(type) { + case bool: + if !t { + w.Logger.Debugf("filter didn't match") + continue + } + default: + w.Logger.Errorf("Filter must return a boolean, can't filter") + continue + } + has_match = true + } + // here means there is no filter or the filter matched + for _, applyExpr := range rule.ApplyExpr { + o, err := exprhelpers.Run(applyExpr, GetPreEvalEnv(w, request), w.Logger, w.Logger.Level >= log.DebugLevel) + if err != nil { + w.Logger.Errorf("unable to apply appsec pre_eval expr: %s", err) + continue + } + switch t := o.(type) { + case error: + w.Logger.Errorf("unable to apply appsec pre_eval expr: %s", t) + continue + default: + } + } + if has_match && rule.OnSuccess == "break" { + break + } + } + + return nil +} + +func (w *AppsecRuntimeConfig) ProcessPostEvalRules(request *ParsedRequest) error { + has_match := false + for _, rule := range w.CompiledPostEval { + if rule.FilterExpr != nil { + output, err := exprhelpers.Run(rule.FilterExpr, GetPostEvalEnv(w, request), w.Logger, w.Logger.Level >= log.DebugLevel) + if err != nil { + return fmt.Errorf("unable to run appsec post_eval filter %s : %w",
rule.Filter, err) + } + switch t := output.(type) { + case bool: + if !t { + w.Logger.Debugf("filter didn't match") + continue + } + default: + w.Logger.Errorf("Filter must return a boolean, can't filter") + continue + } + has_match = true + } + // here means there is no filter or the filter matched + for _, applyExpr := range rule.ApplyExpr { + o, err := exprhelpers.Run(applyExpr, GetPostEvalEnv(w, request), w.Logger, w.Logger.Level >= log.DebugLevel) + if err != nil { + w.Logger.Errorf("unable to apply appsec post_eval expr: %s", err) + continue + } + + switch t := o.(type) { + case error: + w.Logger.Errorf("unable to apply appsec post_eval expr: %s", t) + continue + default: + } + } + if has_match && rule.OnSuccess == "break" { + break + } + } + + return nil +} + +func (w *AppsecRuntimeConfig) RemoveInbandRuleByID(id int) error { + w.Logger.Debugf("removing inband rule %d", id) + return w.InBandTx.RemoveRuleByIDWithError(id) +} + +func (w *AppsecRuntimeConfig) RemoveOutbandRuleByID(id int) error { + w.Logger.Debugf("removing outband rule %d", id) + return w.OutOfBandTx.RemoveRuleByIDWithError(id) +} + +func (w *AppsecRuntimeConfig) RemoveInbandRuleByTag(tag string) error { + w.Logger.Debugf("removing inband rule with tag %s", tag) + return w.InBandTx.RemoveRuleByTagWithError(tag) +} + +func (w *AppsecRuntimeConfig) RemoveOutbandRuleByTag(tag string) error { + w.Logger.Debugf("removing outband rule with tag %s", tag) + return w.OutOfBandTx.RemoveRuleByTagWithError(tag) +} + +func (w *AppsecRuntimeConfig) RemoveInbandRuleByName(name string) error { + tag := fmt.Sprintf("crowdsec-%s", name) + w.Logger.Debugf("removing inband rule %s", tag) + return w.InBandTx.RemoveRuleByTagWithError(tag) +} + +func (w *AppsecRuntimeConfig) RemoveOutbandRuleByName(name string) error { + tag := fmt.Sprintf("crowdsec-%s", name) + w.Logger.Debugf("removing outband rule %s", tag) + return w.OutOfBandTx.RemoveRuleByTagWithError(tag) +} + +func (w *AppsecRuntimeConfig) CancelEvent() error { + w.Logger.Debugf("canceling event") + w.Response.SendEvent = false + return nil +} + +// Disable a rule at load time, meaning it will not run for any request +func (w *AppsecRuntimeConfig) DisableInBandRuleByID(id int) error { + w.DisabledInBandRuleIds = append(w.DisabledInBandRuleIds, id) + return nil +} + +// Disable a rule at load time, meaning it will not run for any request +func (w *AppsecRuntimeConfig) DisableInBandRuleByName(name string) error { + tagValue := fmt.Sprintf("crowdsec-%s", name) + w.DisabledInBandRulesTags = append(w.DisabledInBandRulesTags, tagValue) + return nil +} + +// Disable a rule at load time, meaning it will not run for any request +func (w *AppsecRuntimeConfig) DisableInBandRuleByTag(tag string) error { + w.DisabledInBandRulesTags = append(w.DisabledInBandRulesTags, tag) + return nil +} + +// Disable a rule at load time, meaning it will not run for any request +func (w *AppsecRuntimeConfig) DisableOutBandRuleByID(id int) error { + w.DisabledOutOfBandRuleIds = append(w.DisabledOutOfBandRuleIds, id) + return nil +} + +// Disable a rule at load time, meaning it will not run for any request +func (w *AppsecRuntimeConfig) DisableOutBandRuleByName(name string) error { + tagValue := fmt.Sprintf("crowdsec-%s", name) + w.DisabledOutOfBandRulesTags = append(w.DisabledOutOfBandRulesTags, tagValue) + return nil +} + +// Disable a rule at load time, meaning it will not run for any request +func (w *AppsecRuntimeConfig) DisableOutBandRuleByTag(tag string) error { + w.DisabledOutOfBandRulesTags =
append(w.DisabledOutOfBandRulesTags, tag) + return nil +} + +func (w *AppsecRuntimeConfig) SendEvent() error { + w.Logger.Debugf("sending event") + w.Response.SendEvent = true + return nil +} + +func (w *AppsecRuntimeConfig) SendAlert() error { + w.Logger.Debugf("sending alert") + w.Response.SendAlert = true + return nil +} + +func (w *AppsecRuntimeConfig) CancelAlert() error { + w.Logger.Debugf("canceling alert") + w.Response.SendAlert = false + return nil +} + +func (w *AppsecRuntimeConfig) SetActionByTag(tag string, action string) error { + if w.RemediationByTag == nil { + w.RemediationByTag = make(map[string]string) + } + w.Logger.Debugf("setting action of %s to %s", tag, action) + w.RemediationByTag[tag] = action + return nil +} + +func (w *AppsecRuntimeConfig) SetActionByID(id int, action string) error { + if w.RemediationById == nil { + w.RemediationById = make(map[int]string) + } + w.Logger.Debugf("setting action of %d to %s", id, action) + w.RemediationById[id] = action + return nil +} + +func (w *AppsecRuntimeConfig) SetActionByName(name string, action string) error { + if w.RemediationByTag == nil { + w.RemediationByTag = make(map[string]string) + } + tag := fmt.Sprintf("crowdsec-%s", name) + w.Logger.Debugf("setting action of %s to %s", tag, action) + w.RemediationByTag[tag] = action + return nil +} + +func (w *AppsecRuntimeConfig) SetAction(action string) error { + // log.Infof("setting to %s", action) + w.Logger.Debugf("setting action to %s", action) + w.Response.Action = action + return nil +} + +func (w *AppsecRuntimeConfig) SetHTTPCode(code int) error { + w.Logger.Debugf("setting http code to %d", code) + w.Response.UserHTTPResponseCode = code + return nil +} + +type BodyResponse struct { + Action string `json:"action"` + HTTPStatus int `json:"http_status"` +} + +func (w *AppsecRuntimeConfig) GenerateResponse(response AppsecTempResponse, logger *log.Entry) (int, BodyResponse) { + var bouncerStatusCode int + + resp := BodyResponse{Action: response.Action} + if response.Action == AllowRemediation { + resp.HTTPStatus = w.Config.UserPassedHTTPCode + bouncerStatusCode = w.Config.BouncerPassedHTTPCode + } else { // ban, captcha and anything else + resp.HTTPStatus = response.UserHTTPResponseCode + if resp.HTTPStatus == 0 { + resp.HTTPStatus = w.Config.UserBlockedHTTPCode + } + bouncerStatusCode = response.BouncerHTTPResponseCode + if bouncerStatusCode == 0 { + bouncerStatusCode = w.Config.BouncerBlockedHTTPCode + } + } + + return bouncerStatusCode, resp +} diff --git a/pkg/appsec/appsec_rule/appsec_rule.go b/pkg/appsec/appsec_rule/appsec_rule.go new file mode 100644 index 00000000000..136d8b11cb7 --- /dev/null +++ b/pkg/appsec/appsec_rule/appsec_rule.go @@ -0,0 +1,70 @@ +package appsec_rule + +import ( + "errors" + "fmt" +) + +/* +rules: + - name: "test" + and: + - zones: + - BODY_ARGS + variables: + - foo + - bar + transform: + - lowercase|uppercase|b64decode|... 
+ match: + type: regex + value: "[^a-zA-Z]" + - zones: + - ARGS + variables: + - bla + +*/ + +type Match struct { + Type string `yaml:"type"` + Value string `yaml:"value"` + Not bool `yaml:"not,omitempty"` +} + +type CustomRule struct { + Name string `yaml:"name"` + + Zones []string `yaml:"zones"` + Variables []string `yaml:"variables"` + + Match Match `yaml:"match"` + Transform []string `yaml:"transform"` //t:lowercase, t:uppercase, etc + And []CustomRule `yaml:"and,omitempty"` + Or []CustomRule `yaml:"or,omitempty"` + + BodyType string `yaml:"body_type,omitempty"` +} + +func (v *CustomRule) Convert(ruleType string, appsecRuleName string) (string, []uint32, error) { + + if v.Zones == nil && v.And == nil && v.Or == nil { + return "", nil, errors.New("no zones defined") + } + + if v.Match.Type == "" && v.And == nil && v.Or == nil { + return "", nil, errors.New("no match type defined") + } + + if v.Match.Value == "" && v.And == nil && v.Or == nil { + return "", nil, errors.New("no match value defined") + } + + switch ruleType { + case ModsecurityRuleType: + r := ModsecurityRule{} + return r.Build(v, appsecRuleName) + default: + return "", nil, fmt.Errorf("unknown rule format '%s'", ruleType) + } +} diff --git a/pkg/appsec/appsec_rule/modsec_rule_test.go b/pkg/appsec/appsec_rule/modsec_rule_test.go new file mode 100644 index 00000000000..ffb8a15ff1f --- /dev/null +++ b/pkg/appsec/appsec_rule/modsec_rule_test.go @@ -0,0 +1,173 @@ +package appsec_rule + +import "testing" + +func TestVPatchRuleString(t *testing.T) { + tests := []struct { + name string + rule CustomRule + expected string + }{ + { + name: "Collection count", + rule: CustomRule{ + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: Match{Type: "eq", Value: "1"}, + Transform: []string{"count"}, + }, + expected: `SecRule &ARGS_GET:foo "@eq 1" "id:853070236,phase:2,deny,log,msg:'Collection count',tag:'crowdsec-Collection count'"`, + }, + { + name: "Base Rule", + rule: CustomRule{ + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + expected: `SecRule ARGS_GET:foo "@rx [^a-zA-Z]" "id:2203944045,phase:2,deny,log,msg:'Base Rule',tag:'crowdsec-Base Rule',t:lowercase"`, + }, + { + name: "One zone, multi var", + rule: CustomRule{ + Zones: []string{"ARGS"}, + Variables: []string{"foo", "bar"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + expected: `SecRule ARGS_GET:foo|ARGS_GET:bar "@rx [^a-zA-Z]" "id:385719930,phase:2,deny,log,msg:'One zone, multi var',tag:'crowdsec-One zone, multi var',t:lowercase"`, + }, + { + name: "Base Rule #2", + rule: CustomRule{ + Zones: []string{"METHOD"}, + Match: Match{Type: "startsWith", Value: "toto"}, + }, + expected: `SecRule REQUEST_METHOD "@beginsWith toto" "id:2759779019,phase:2,deny,log,msg:'Base Rule #2',tag:'crowdsec-Base Rule #2'"`, + }, + { + name: "Base Negative Rule", + rule: CustomRule{ + Zones: []string{"METHOD"}, + Match: Match{Type: "startsWith", Value: "toto", Not: true}, + }, + expected: `SecRule REQUEST_METHOD "!@beginsWith toto" "id:3966251995,phase:2,deny,log,msg:'Base Negative Rule',tag:'crowdsec-Base Negative Rule'"`, + }, + { + name: "Multiple Zones", + rule: CustomRule{ + Zones: []string{"ARGS", "BODY_ARGS"}, + Variables: []string{"foo"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + expected: `SecRule ARGS_GET:foo|ARGS_POST:foo "@rx [^a-zA-Z]" 
"id:3387135861,phase:2,deny,log,msg:'Multiple Zones',tag:'crowdsec-Multiple Zones',t:lowercase"`, + }, + { + name: "Multiple Zones Multi Var", + rule: CustomRule{ + Zones: []string{"ARGS", "BODY_ARGS"}, + Variables: []string{"foo", "bar"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + expected: `SecRule ARGS_GET:foo|ARGS_GET:bar|ARGS_POST:foo|ARGS_POST:bar "@rx [^a-zA-Z]" "id:1119773585,phase:2,deny,log,msg:'Multiple Zones Multi Var',tag:'crowdsec-Multiple Zones Multi Var',t:lowercase"`, + }, + { + name: "Multiple Zones No Vars", + rule: CustomRule{ + Zones: []string{"ARGS", "BODY_ARGS"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + expected: `SecRule ARGS_GET|ARGS_POST "@rx [^a-zA-Z]" "id:2020110336,phase:2,deny,log,msg:'Multiple Zones No Vars',tag:'crowdsec-Multiple Zones No Vars',t:lowercase"`, + }, + { + name: "Basic AND", + rule: CustomRule{ + And: []CustomRule{ + { + + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + { + Zones: []string{"ARGS"}, + Variables: []string{"bar"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + }, + }, + expected: `SecRule ARGS_GET:foo "@rx [^a-zA-Z]" "id:4145519614,phase:2,deny,log,msg:'Basic AND',tag:'crowdsec-Basic AND',t:lowercase,chain" +SecRule ARGS_GET:bar "@rx [^a-zA-Z]" "id:1865217529,phase:2,deny,log,msg:'Basic AND',tag:'crowdsec-Basic AND',t:lowercase"`, + }, + { + name: "Basic OR", + rule: CustomRule{ + Or: []CustomRule{ + { + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + { + Zones: []string{"ARGS"}, + Variables: []string{"bar"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + }, + }, + expected: `SecRule ARGS_GET:foo "@rx [^a-zA-Z]" "id:651140804,phase:2,deny,log,msg:'Basic OR',tag:'crowdsec-Basic OR',t:lowercase,skip:1" +SecRule ARGS_GET:bar "@rx [^a-zA-Z]" "id:271441587,phase:2,deny,log,msg:'Basic OR',tag:'crowdsec-Basic OR',t:lowercase"`, + }, + { + name: "OR AND mix", + rule: CustomRule{ + And: []CustomRule{ + { + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + Or: []CustomRule{ + { + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + { + Zones: []string{"ARGS"}, + Variables: []string{"bar"}, + Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + }, + }, + }, + }, + expected: `SecRule ARGS_GET:foo "@rx [^a-zA-Z]" "id:1714963250,phase:2,deny,log,msg:'OR AND mix',tag:'crowdsec-OR AND mix',t:lowercase,skip:1" +SecRule ARGS_GET:bar "@rx [^a-zA-Z]" "id:1519945803,phase:2,deny,log,msg:'OR AND mix',tag:'crowdsec-OR AND mix',t:lowercase" +SecRule ARGS_GET:foo "@rx [^a-zA-Z]" "id:1519945803,phase:2,deny,log,msg:'OR AND mix',tag:'crowdsec-OR AND mix',t:lowercase"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual, _, err := tt.rule.Convert(ModsecurityRuleType, tt.name) + + if err != nil { + t.Errorf("Error converting rule: %s", err) + } + if actual != tt.expected { + t.Errorf("Expected:\n%s\nGot:\n%s", tt.expected, actual) + } + }) + } +} diff --git a/pkg/appsec/appsec_rule/modsecurity.go 
b/pkg/appsec/appsec_rule/modsecurity.go new file mode 100644 index 00000000000..135ba525e8e --- /dev/null +++ b/pkg/appsec/appsec_rule/modsecurity.go @@ -0,0 +1,216 @@ +package appsec_rule + +import ( + "errors" + "fmt" + "hash/fnv" + "strings" +) + +type ModsecurityRule struct { + ids []uint32 +} + +var zonesMap = map[string]string{ + "ARGS": "ARGS_GET", + "ARGS_NAMES": "ARGS_GET_NAMES", + "BODY_ARGS": "ARGS_POST", + "BODY_ARGS_NAMES": "ARGS_POST_NAMES", + "COOKIES": "REQUEST_COOKIES", + "COOKIES_NAMES": "REQUEST_COOKIES_NAMES", + "FILES": "FILES", + "FILES_NAMES": "FILES_NAMES", + "FILES_TOTAL_SIZE": "FILES_COMBINED_SIZE", + "HEADERS_NAMES": "REQUEST_HEADERS_NAMES", + "HEADERS": "REQUEST_HEADERS", + "METHOD": "REQUEST_METHOD", + "PROTOCOL": "REQUEST_PROTOCOL", + "URI": "REQUEST_FILENAME", + "URI_FULL": "REQUEST_URI", + "RAW_BODY": "REQUEST_BODY", + "FILENAMES": "FILES", +} + +var transformMap = map[string]string{ + "lowercase": "t:lowercase", + "uppercase": "t:uppercase", + "b64decode": "t:base64Decode", + //"hexdecode": "t:hexDecode", -> not supported by coraza + "length": "t:length", + "urldecode": "t:urlDecode", + "trim": "t:trim", + "normalize_path": "t:normalizePath", + "normalizepath": "t:normalizePath", + "htmlentitydecode": "t:htmlEntityDecode", + "html_entity_decode": "t:htmlEntityDecode", +} + +var matchMap = map[string]string{ + "regex": "@rx", + "equals": "@streq", + "startsWith": "@beginsWith", + "endsWith": "@endsWith", + "contains": "@contains", + "libinjectionSQL": "@detectSQLi", + "libinjectionXSS": "@detectXSS", + "gt": "@gt", + "lt": "@lt", + "gte": "@ge", + "lte": "@le", + "eq": "@eq", +} + +var bodyTypeMatch = map[string]string{ + "json": "JSON", + "xml": "XML", + "multipart": "MULTIPART", + "urlencoded": "URLENCODED", +} + +func (m *ModsecurityRule) Build(rule *CustomRule, appsecRuleName string) (string, []uint32, error) { + rules, err := m.buildRules(rule, appsecRuleName, false, 0, 0) + if err != nil { + return "", nil, err + } + + // We return all generated rule ids; the first one is the interesting one in case of chain or skip + return strings.Join(rules, "\n"), m.ids, nil +} + +func (m *ModsecurityRule) generateRuleID(rule *CustomRule, appsecRuleName string, depth int) uint32 { + h := fnv.New32a() + h.Write([]byte(appsecRuleName)) + h.Write([]byte(rule.Match.Type)) + h.Write([]byte(rule.Match.Value)) + h.Write([]byte(fmt.Sprintf("%d", depth))) + for _, zone := range rule.Zones { + h.Write([]byte(zone)) + } + for _, transform := range rule.Transform { + h.Write([]byte(transform)) + } + id := h.Sum32() + m.ids = append(m.ids, id) + return id +} + +func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, and bool, toSkip int, depth int) ([]string, error) { + ret := make([]string, 0) + + if len(rule.And) != 0 && len(rule.Or) != 0 { + return nil, errors.New("cannot have both 'and' and 'or' in the same rule") + } + + if rule.And != nil { + for c, andRule := range rule.And { + depth++ + lastRule := c == len(rule.And)-1 // || len(rule.Or) == 0 + rules, err := m.buildRules(&andRule, appsecRuleName, !lastRule, 0, depth) + if err != nil { + return nil, err + } + ret = append(ret, rules...) + } + } + + if rule.Or != nil { + for c, orRule := range rule.Or { + depth++ + skip := len(rule.Or) - c - 1 + rules, err := m.buildRules(&orRule, appsecRuleName, false, skip, depth) + if err != nil { + return nil, err + } + ret = append(ret, rules...)
+ } + } + + r := strings.Builder{} + + r.WriteString("SecRule ") + + if rule.Zones == nil { + return ret, nil + } + + zone_prefix := "" + variable_prefix := "" + if rule.Transform != nil { + for tidx, transform := range rule.Transform { + if transform == "count" { + zone_prefix = "&" + rule.Transform[tidx] = "" + } + } + } + for idx, zone := range rule.Zones { + if idx > 0 { + r.WriteByte('|') + } + mappedZone, ok := zonesMap[zone] + if !ok { + return nil, fmt.Errorf("unknown zone '%s'", zone) + } + if len(rule.Variables) == 0 { + r.WriteString(mappedZone) + } else { + for j, variable := range rule.Variables { + if j > 0 { + r.WriteByte('|') + } + r.WriteString(fmt.Sprintf("%s%s:%s%s", zone_prefix, mappedZone, variable_prefix, variable)) + } + } + } + r.WriteByte(' ') + + if rule.Match.Type != "" { + match, ok := matchMap[rule.Match.Type] + if !ok { + return nil, fmt.Errorf("unknown match type '%s'", rule.Match.Type) + } + prefix := "" + if rule.Match.Not { + prefix = "!" + } + r.WriteString(fmt.Sprintf(`"%s%s %s"`, prefix, match, rule.Match.Value)) + } + + //Should phase:2 be configurable? + r.WriteString(fmt.Sprintf(` "id:%d,phase:2,deny,log,msg:'%s',tag:'crowdsec-%s'`, m.generateRuleID(rule, appsecRuleName, depth), appsecRuleName, appsecRuleName)) + + if rule.Transform != nil { + for _, transform := range rule.Transform { + if transform == "" { + continue + } + r.WriteByte(',') + mappedTransform, ok := transformMap[transform] + if !ok { + return nil, fmt.Errorf("unknown transform '%s'", transform) + } + r.WriteString(mappedTransform) + } + } + + if rule.BodyType != "" { + mappedBodyType, ok := bodyTypeMatch[rule.BodyType] + if !ok { + return nil, fmt.Errorf("unknown body type '%s'", rule.BodyType) + } + r.WriteString(fmt.Sprintf(",ctl:requestBodyProcessor=%s", mappedBodyType)) + } + + if and { + r.WriteString(",chain") + } + + if toSkip > 0 { + r.WriteString(fmt.Sprintf(",skip:%d", toSkip)) + } + + r.WriteByte('"') + + ret = append(ret, r.String()) + return ret, nil +} diff --git a/pkg/appsec/appsec_rule/types.go b/pkg/appsec/appsec_rule/types.go new file mode 100644 index 00000000000..13716975a05 --- /dev/null +++ b/pkg/appsec/appsec_rule/types.go @@ -0,0 +1,9 @@ +package appsec_rule + +const ( + ModsecurityRuleType = "modsecurity" +) + +func SupportedTypes() []string { + return []string{ModsecurityRuleType} +} diff --git a/pkg/appsec/appsec_rules_collection.go b/pkg/appsec/appsec_rules_collection.go new file mode 100644 index 00000000000..d283f95cb19 --- /dev/null +++ b/pkg/appsec/appsec_rules_collection.go @@ -0,0 +1,142 @@ +package appsec + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" +) + +type AppsecCollection struct { + collectionName string + Rules []string +} + +var APPSEC_RULE = "appsec-rule" + +// to be filled w/ seb update +type AppsecCollectionConfig struct { + Type string `yaml:"type"` + Name string `yaml:"name"` + Debug bool `yaml:"debug"` + Description string `yaml:"description"` + SecLangFilesRules []string `yaml:"seclang_files_rules"` + SecLangRules []string `yaml:"seclang_rules"` + Rules []appsec_rule.CustomRule `yaml:"rules"` + + Labels map[string]interface{} `yaml:"labels"` // Labels is a K:V list aiming at providing context to the overflow + + Data interface{} `yaml:"data"` // Ignore it + hash string + version string +} + +type RulesDetails struct { + LogLevel log.Level + Hash string + Version string + Name
string +} + +// FIXME: this shouldn't be a global +// Is using the id a good idea? It might be too specific to coraza and not easily reusable +var AppsecRulesDetails = make(map[int]RulesDetails) + +func LoadCollection(pattern string, logger *log.Entry) ([]AppsecCollection, error) { + ret := make([]AppsecCollection, 0) + + for _, appsecRule := range appsecRules { + tmpMatch, err := exprhelpers.Match(pattern, appsecRule.Name) + if err != nil { + logger.Errorf("unable to match %s with %s : %s", appsecRule.Name, pattern, err) + continue + } + + matched, ok := tmpMatch.(bool) + + if !ok { + logger.Errorf("unable to match %s with %s : unexpected non-boolean result", appsecRule.Name, pattern) + continue + } + + if !matched { + continue + } + + appsecCol := AppsecCollection{ + collectionName: appsecRule.Name, + } + + if appsecRule.SecLangFilesRules != nil { + for _, rulesFile := range appsecRule.SecLangFilesRules { + logger.Debugf("Adding rules from %s", rulesFile) + fullPath := filepath.Join(hub.GetDataDir(), rulesFile) + c, err := os.ReadFile(fullPath) + if err != nil { + logger.Errorf("unable to read file %s : %s", rulesFile, err) + continue + } + for _, line := range strings.Split(string(c), "\n") { + if strings.HasPrefix(line, "#") { + continue + } + if strings.TrimSpace(line) == "" { + continue + } + appsecCol.Rules = append(appsecCol.Rules, line) + } + } + } + + if appsecRule.SecLangRules != nil { + logger.Tracef("Adding inline rules %+v", appsecRule.SecLangRules) + appsecCol.Rules = append(appsecCol.Rules, appsecRule.SecLangRules...) + } + + if appsecRule.Rules != nil { + for _, rule := range appsecRule.Rules { + strRule, rulesId, err := rule.Convert(appsec_rule.ModsecurityRuleType, appsecRule.Name) + if err != nil { + logger.Errorf("unable to convert rule %s : %s", appsecRule.Name, err) + return nil, err + } + logger.Debugf("Adding rule %s", strRule) + appsecCol.Rules = append(appsecCol.Rules, strRule) + + // We only take the first id, as it's the id of the "main" rule + if _, ok := AppsecRulesDetails[int(rulesId[0])]; !ok { + AppsecRulesDetails[int(rulesId[0])] = RulesDetails{ + LogLevel: log.InfoLevel, + Hash: appsecRule.hash, + Version: appsecRule.version, + Name: appsecRule.Name, + } + } else { + logger.Warnf("conflicting id %d for rule %s!", rulesId[0], rule.Name) + } + + for _, id := range rulesId { + SetRuleDebug(int(id), appsecRule.Debug) + } + } + } + ret = append(ret, appsecCol) + } + if len(ret) == 0 { + return nil, fmt.Errorf("no appsec-rules found for pattern %s", pattern) + } + return ret, nil +} + +func (w AppsecCollection) String() string { + ret := "" + for _, rule := range w.Rules { + ret += rule + "\n" + } + return ret +} diff --git a/pkg/appsec/coraza_logger.go b/pkg/appsec/coraza_logger.go new file mode 100644 index 00000000000..d2c1612cbd7 --- /dev/null +++ b/pkg/appsec/coraza_logger.go @@ -0,0 +1,217 @@ +package appsec + +import ( + "fmt" + "io" + + log "github.com/sirupsen/logrus" + + dbg "github.com/crowdsecurity/coraza/v3/debuglog" +) + +var DebugRules = map[int]bool{} + +func SetRuleDebug(id int, debug bool) { + DebugRules[id] = debug +} + +func GetRuleDebug(id int) bool { + if val, ok := DebugRules[id]; ok { + return val + } + + return false +} + +// type ContextField func(Event) Event + +type crzLogEvent struct { + fields log.Fields + logger *log.Entry + muted bool + level log.Level +} + +func (e *crzLogEvent) Msg(msg string) { + if e.muted { + return + } + + /*this is a hack.
Coraza/modsec does not support per-rule log levels, so if a rule ID is flagged to be in debug mode, the + .Int("rule_id", ) call will set the log level of the event to debug. However, because the logger is global to the appsec-runner, + we have to switch the logger's level back and forth around each message*/ + oldLvl := e.logger.Logger.GetLevel() + + if e.level != oldLvl { + e.logger.Logger.SetLevel(e.level) + } + + if len(e.fields) == 0 { + e.logger.Log(e.level, msg) + } else { + e.logger.WithFields(e.fields).Log(e.level, msg) + } + + if e.level != oldLvl { + e.logger.Logger.SetLevel(oldLvl) + e.level = oldLvl + } +} + +func (e *crzLogEvent) Str(key, val string) dbg.Event { + if e.muted { + return e + } + + e.fields[key] = val + + return e +} + +func (e *crzLogEvent) Err(err error) dbg.Event { + if e.muted { + return e + } + + e.fields["error"] = err + + return e +} + +func (e *crzLogEvent) Bool(key string, b bool) dbg.Event { + if e.muted { + return e + } + + e.fields[key] = b + + return e +} + +func (e *crzLogEvent) Int(key string, i int) dbg.Event { + if e.muted { + if key != "rule_id" || !GetRuleDebug(i) { + return e + } + // this allows us to have per-rule debug logging + e.muted = false + e.fields = map[string]interface{}{} + e.level = log.DebugLevel + } + + e.fields[key] = i + + return e +} + +func (e *crzLogEvent) Uint(key string, i uint) dbg.Event { + if e.muted { + return e + } + + e.fields[key] = i + + return e +} + +func (e *crzLogEvent) Stringer(key string, val fmt.Stringer) dbg.Event { + if e.muted { + return e + } + + e.fields[key] = val + + return e +} + +func (e crzLogEvent) IsEnabled() bool { + return !e.muted +} + +type crzLogger struct { + logger *log.Entry + defaultFields log.Fields + logLevel log.Level +} + +func NewCrzLogger(logger *log.Entry) *crzLogger { + return &crzLogger{logger: logger, logLevel: logger.Logger.GetLevel()} +} + +func (c *crzLogger) NewMutedEvt(lvl log.Level) dbg.Event { + return &crzLogEvent{muted: true, logger: c.logger, level: lvl} +} + +func (c *crzLogger) NewEvt(lvl log.Level) dbg.Event { + evt := &crzLogEvent{fields: map[string]interface{}{}, logger: c.logger, level: lvl} + + if c.defaultFields != nil { + for k, v := range c.defaultFields { + evt.fields[k] = v + } + } + + return evt +} + +func (c *crzLogger) WithOutput(w io.Writer) dbg.Logger { + return c +} + +func (c *crzLogger) WithLevel(lvl dbg.Level) dbg.Logger { + c.logLevel = log.Level(lvl) + c.logger.Logger.SetLevel(c.logLevel) + + return c +} + +func (c *crzLogger) With(fs ...dbg.ContextField) dbg.Logger { + e := c.NewEvt(c.logLevel) + for _, f := range fs { + e = f(e) + } + + c.defaultFields = e.(*crzLogEvent).fields + + return c +} + +func (c *crzLogger) Trace() dbg.Event { + if c.logLevel < log.TraceLevel { + return c.NewMutedEvt(log.TraceLevel) + } + + return c.NewEvt(log.TraceLevel) +} + +func (c *crzLogger) Debug() dbg.Event { + if c.logLevel < log.DebugLevel { + return c.NewMutedEvt(log.DebugLevel) + } + + return c.NewEvt(log.DebugLevel) +} + +func (c *crzLogger) Info() dbg.Event { + if c.logLevel < log.InfoLevel { + return c.NewMutedEvt(log.InfoLevel) + } + + return c.NewEvt(log.InfoLevel) +} + +func (c *crzLogger) Warn() dbg.Event { + if c.logLevel < log.WarnLevel { + return c.NewMutedEvt(log.WarnLevel) + } + + return c.NewEvt(log.WarnLevel) +} + +func (c *crzLogger) Error() dbg.Event { + if c.logLevel < log.ErrorLevel { + return c.NewMutedEvt(log.ErrorLevel) + } + + return c.NewEvt(log.ErrorLevel) +} diff --git a/pkg/appsec/loader.go b/pkg/appsec/loader.go new
file mode 100644 index 00000000000..c724010cec2 --- /dev/null +++ b/pkg/appsec/loader.go @@ -0,0 +1,47 @@ +package appsec + +import ( + "os" + + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v2" + + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +var appsecRules = make(map[string]AppsecCollectionConfig) // FIXME: would probably be better to have a struct for this + +var hub *cwhub.Hub // FIXME: this is a temporary hack to make the hub available in the package + +func LoadAppsecRules(hubInstance *cwhub.Hub) error { + hub = hubInstance + appsecRules = make(map[string]AppsecCollectionConfig) + + for _, hubAppsecRuleItem := range hub.GetInstalledByType(cwhub.APPSEC_RULES, false) { + content, err := os.ReadFile(hubAppsecRuleItem.State.LocalPath) + if err != nil { + log.Warnf("unable to read file %s : %s", hubAppsecRuleItem.State.LocalPath, err) + continue + } + + var rule AppsecCollectionConfig + + err = yaml.UnmarshalStrict(content, &rule) + if err != nil { + log.Warnf("unable to parse file %s : %s", hubAppsecRuleItem.State.LocalPath, err) + continue + } + + rule.hash = hubAppsecRuleItem.State.LocalHash + rule.version = hubAppsecRuleItem.Version + + log.Infof("Adding %s to appsec rules", rule.Name) + + appsecRules[rule.Name] = rule + } + + if len(appsecRules) == 0 { + log.Debugf("No appsec rules found") + } + return nil +} diff --git a/pkg/appsec/query_utils.go b/pkg/appsec/query_utils.go new file mode 100644 index 00000000000..0c886e0ea51 --- /dev/null +++ b/pkg/appsec/query_utils.go @@ -0,0 +1,78 @@ +package appsec + +// This file is mostly stolen from the net/url package, but with some modifications to allow less strict parsing of query strings + +import ( + "net/url" + "strings" +) + +// ParseQuery and parseQuery are copied from the net/url package, but allow semicolons in values +func ParseQuery(query string) url.Values { + m := make(url.Values) + parseQuery(m, query) + return m +} + +func parseQuery(m url.Values, query string) { + for query != "" { + var key string + key, query, _ = strings.Cut(query, "&") + + if key == "" { + continue + } + key, value, _ := strings.Cut(key, "=") + //for now we'll just ignore the errors, but ideally we want to fire some "internal" rules when we see invalid query strings + key = unescape(key) + value = unescape(value) + m[key] = append(m[key], value) + } +} + +func hexDigitToByte(digit byte) (byte, bool) { + switch { + case digit >= '0' && digit <= '9': + return digit - '0', true + case digit >= 'a' && digit <= 'f': + return digit - 'a' + 10, true + case digit >= 'A' && digit <= 'F': + return digit - 'A' + 10, true + default: + return 0, false + } +} + +func unescape(input string) string { + ilen := len(input) + res := strings.Builder{} + res.Grow(ilen) + for i := 0; i < ilen; i++ { + ci := input[i] + if ci == '+' { + res.WriteByte(' ') + continue + } + if ci == '%' { + if i+2 >= ilen { + res.WriteByte(ci) + continue + } + hi, ok := hexDigitToByte(input[i+1]) + if !ok { + res.WriteByte(ci) + continue + } + lo, ok := hexDigitToByte(input[i+2]) + if !ok { + res.WriteByte(ci) + continue + } + res.WriteByte(hi<<4 | lo) + i += 2 + continue + } + res.WriteByte(ci) + } + return res.String() +} diff --git a/pkg/appsec/query_utils_test.go b/pkg/appsec/query_utils_test.go new file mode 100644 index 00000000000..2ad7927968d --- /dev/null +++ b/pkg/appsec/query_utils_test.go @@ -0,0 +1,207 @@ +package appsec + +import ( + "net/url" + "reflect" + "testing" +) + +func TestParseQuery(t *testing.T) { + tests := []struct { + name string + query string + expected url.Values + }{
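+		// The cases below exercise the lenient parser: unlike the stricter net/url version, it keeps semicolons in values, decodes '+' as a space and leaves invalid percent-escapes as-is instead of returning an error.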
+ { + name: "Simple query", + query: "foo=bar", + expected: url.Values{ + "foo": []string{"bar"}, + }, + }, + { + name: "Multiple values", + query: "foo=bar&foo=baz", + expected: url.Values{ + "foo": []string{"bar", "baz"}, + }, + }, + { + name: "Empty value", + query: "foo=", + expected: url.Values{ + "foo": []string{""}, + }, + }, + { + name: "Empty key", + query: "=bar", + expected: url.Values{ + "": []string{"bar"}, + }, + }, + { + name: "Empty query", + query: "", + expected: url.Values{}, + }, + { + name: "Multiple keys", + query: "foo=bar&baz=qux", + expected: url.Values{ + "foo": []string{"bar"}, + "baz": []string{"qux"}, + }, + }, + { + name: "Multiple keys with empty value", + query: "foo=bar&baz=qux&quux=", + expected: url.Values{ + "foo": []string{"bar"}, + "baz": []string{"qux"}, + "quux": []string{""}, + }, + }, + { + name: "Multiple keys with empty value and empty key", + query: "foo=bar&baz=qux&quux=&=quuz", + expected: url.Values{ + "foo": []string{"bar"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz", + expected: url.Values{ + "foo": []string{"bar", "baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values and escaped characters", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz&foo=bar%20baz", + expected: url.Values{ + "foo": []string{"bar", "baz", "bar baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values and escaped characters and semicolon", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz&foo=bar%20baz&foo=bar%3Bbaz", + expected: url.Values{ + "foo": []string{"bar", "baz", "bar baz", "bar;baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values and escaped characters and semicolon and ampersand", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz&foo=bar%20baz&foo=bar%3Bbaz&foo=bar%26baz", + expected: url.Values{ + "foo": []string{"bar", "baz", "bar baz", "bar;baz", "bar&baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values and escaped characters and semicolon and ampersand and equals", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz&foo=bar%20baz&foo=bar%3Bbaz&foo=bar%26baz&foo=bar%3Dbaz", + expected: url.Values{ + "foo": []string{"bar", "baz", "bar baz", "bar;baz", "bar&baz", "bar=baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values and escaped characters and semicolon and ampersand and equals and question mark", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz&foo=bar%20baz&foo=bar%3Bbaz&foo=bar%26baz&foo=bar%3Dbaz&foo=bar%3Fbaz", + expected: url.Values{ + "foo": []string{"bar", "baz", "bar baz", "bar;baz", "bar&baz", "bar=baz", "bar?baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "keys with escaped characters", + query: "foo=ba;r&baz=qu;;x&quux=x\\&ww&xx=qu?uz&", + expected: url.Values{ + "foo": []string{"ba;r"}, + "baz": []string{"qu;;x"}, + "quux": []string{"x\\"}, + "ww": []string{""}, + "xx": 
[]string{"qu?uz"}, + }, + }, + { + name: "hexadecimal characters", + query: "foo=bar%20baz", + expected: url.Values{ + "foo": []string{"bar baz"}, + }, + }, + { + name: "hexadecimal characters upper and lower case", + query: "foo=Ba%42%42&bar=w%2f%2F", + expected: url.Values{ + "foo": []string{"BaBB"}, + "bar": []string{"w//"}, + }, + }, + { + name: "hexadecimal characters with invalid characters", + query: "foo=bar%20baz%2", + expected: url.Values{ + "foo": []string{"bar baz%2"}, + }, + }, + { + name: "hexadecimal characters with invalid hex characters", + query: "foo=bar%xx", + expected: url.Values{ + "foo": []string{"bar%xx"}, + }, + }, + { + name: "hexadecimal characters with invalid 2nd hex character", + query: "foo=bar%2x", + expected: url.Values{ + "foo": []string{"bar%2x"}, + }, + }, + { + name: "url +", + query: "foo=bar+x", + expected: url.Values{ + "foo": []string{"bar x"}, + }, + }, + { + name: "url &&", + query: "foo=bar&&lol=bur", + expected: url.Values{ + "foo": []string{"bar"}, + "lol": []string{"bur"}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + res := ParseQuery(test.query) + if !reflect.DeepEqual(res, test.expected) { + t.Fatalf("unexpected result: %v", res) + } + }) + } +} diff --git a/pkg/appsec/request.go b/pkg/appsec/request.go new file mode 100644 index 00000000000..ccd7a9f9cc8 --- /dev/null +++ b/pkg/appsec/request.go @@ -0,0 +1,381 @@ +package appsec + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "regexp" + + "github.com/google/uuid" + log "github.com/sirupsen/logrus" +) + +const ( + URIHeaderName = "X-Crowdsec-Appsec-Uri" + VerbHeaderName = "X-Crowdsec-Appsec-Verb" + HostHeaderName = "X-Crowdsec-Appsec-Host" + IPHeaderName = "X-Crowdsec-Appsec-Ip" + APIKeyHeaderName = "X-Crowdsec-Appsec-Api-Key" + UserAgentHeaderName = "X-Crowdsec-Appsec-User-Agent" +) + +type ParsedRequest struct { + RemoteAddr string `json:"remote_addr,omitempty"` + Host string `json:"host,omitempty"` + ClientIP string `json:"client_ip,omitempty"` + URI string `json:"uri,omitempty"` + Args url.Values `json:"args,omitempty"` + ClientHost string `json:"client_host,omitempty"` + Headers http.Header `json:"headers,omitempty"` + URL *url.URL `json:"url,omitempty"` + Method string `json:"method,omitempty"` + Proto string `json:"proto,omitempty"` + Body []byte `json:"body,omitempty"` + TransferEncoding []string `json:"transfer_encoding,omitempty"` + UUID string `json:"uuid,omitempty"` + Tx ExtendedTransaction `json:"-"` + ResponseChannel chan AppsecTempResponse `json:"-"` + IsInBand bool `json:"-"` + IsOutBand bool `json:"-"` + AppsecEngine string `json:"appsec_engine,omitempty"` + RemoteAddrNormalized string `json:"normalized_remote_addr,omitempty"` + HTTPRequest *http.Request `json:"-"` +} + +type ReqDumpFilter struct { + req *ParsedRequest + HeadersContentFilters []string + HeadersNameFilters []string + HeadersDrop bool + + BodyDrop bool + //BodyContentFilters []string TBD + + ArgsContentFilters []string + ArgsNameFilters []string + ArgsDrop bool +} + +func (r *ParsedRequest) DumpRequest(params ...any) *ReqDumpFilter { + filter := ReqDumpFilter{} + filter.BodyDrop = true + filter.HeadersNameFilters = []string{"cookie", "authorization"} + filter.req = r + return &filter +} + +// clear filters +func (r *ReqDumpFilter) NoFilters() *ReqDumpFilter { + r2 := ReqDumpFilter{} + r2.req = r.req + return &r2 +} + +func (r *ReqDumpFilter) WithEmptyHeadersFilters() *ReqDumpFilter { + r.HeadersContentFilters = 
[]string{} + return r +} + +func (r *ReqDumpFilter) WithHeadersContentFilter(filter string) *ReqDumpFilter { + r.HeadersContentFilters = append(r.HeadersContentFilters, filter) + return r +} + +func (r *ReqDumpFilter) WithHeadersNameFilter(filter string) *ReqDumpFilter { + r.HeadersNameFilters = append(r.HeadersNameFilters, filter) + return r +} + +func (r *ReqDumpFilter) WithNoHeaders() *ReqDumpFilter { + r.HeadersDrop = true + return r +} + +func (r *ReqDumpFilter) WithHeaders() *ReqDumpFilter { + r.HeadersDrop = false + r.HeadersNameFilters = []string{} + return r +} + +func (r *ReqDumpFilter) WithBody() *ReqDumpFilter { + r.BodyDrop = false + return r +} + +func (r *ReqDumpFilter) WithNoBody() *ReqDumpFilter { + r.BodyDrop = true + return r +} + +func (r *ReqDumpFilter) WithEmptyArgsFilters() *ReqDumpFilter { + r.ArgsContentFilters = []string{} + return r +} + +func (r *ReqDumpFilter) WithArgsContentFilter(filter string) *ReqDumpFilter { + r.ArgsContentFilters = append(r.ArgsContentFilters, filter) + return r +} + +func (r *ReqDumpFilter) WithArgsNameFilter(filter string) *ReqDumpFilter { + r.ArgsNameFilters = append(r.ArgsNameFilters, filter) + return r +} + +func (r *ReqDumpFilter) FilterBody(out *ParsedRequest) error { + if r.BodyDrop { + return nil + } + out.Body = r.req.Body + return nil +} + +func (r *ReqDumpFilter) FilterArgs(out *ParsedRequest) error { + if r.ArgsDrop { + return nil + } + if len(r.ArgsContentFilters) == 0 && len(r.ArgsNameFilters) == 0 { + out.Args = r.req.Args + return nil + } + out.Args = make(url.Values) + for k, vals := range r.req.Args { + reject := false + //exclude by match on name + for _, filter := range r.ArgsNameFilters { + ok, err := regexp.MatchString("(?i)"+filter, k) + if err != nil { + log.Debugf("error while matching string '%s' with '%s': %s", filter, k, err) + continue + } + if ok { + reject = true + break + } + } + + for _, v := range vals { + //exclude by content + for _, filter := range r.ArgsContentFilters { + ok, err := regexp.MatchString("(?i)"+filter, v) + if err != nil { + log.Debugf("error while matching string '%s' with '%s': %s", filter, v, err) + continue + } + if ok { + reject = true + break + } + + } + } + //if it was not rejected, let's add it + if !reject { + out.Args[k] = vals + } + } + return nil +} + +func (r *ReqDumpFilter) FilterHeaders(out *ParsedRequest) error { + if r.HeadersDrop { + return nil + } + + if len(r.HeadersContentFilters) == 0 && len(r.HeadersNameFilters) == 0 { + out.Headers = r.req.Headers + return nil + } + + out.Headers = make(http.Header) + for k, vals := range r.req.Headers { + reject := false + //exclude by match on name + for _, filter := range r.HeadersNameFilters { + ok, err := regexp.MatchString("(?i)"+filter, k) + if err != nil { + log.Debugf("error while matching string '%s' with '%s': %s", filter, k, err) + continue + } + if ok { + reject = true + break + } + } + + for _, v := range vals { + //exclude by content + for _, filter := range r.HeadersContentFilters { + ok, err := regexp.MatchString("(?i)"+filter, v) + if err != nil { + log.Debugf("error while matching string '%s' with '%s': %s", filter, v, err) + continue + } + if ok { + reject = true + break + } + + } + } + //if it was not rejected, let's add it + if !reject { + out.Headers[k] = vals + } + } + return nil +} + +func (r *ReqDumpFilter) GetFilteredRequest() *ParsedRequest { + //if there are no filters, we return the original request + if len(r.HeadersContentFilters) == 0 && + len(r.HeadersNameFilters) == 0 && + 
len(r.ArgsContentFilters) == 0 && + len(r.ArgsNameFilters) == 0 && + !r.BodyDrop && !r.HeadersDrop && !r.ArgsDrop { + log.Warningf("no filters, returning original request") + return r.req + } + + r2 := ParsedRequest{} + r.FilterHeaders(&r2) + r.FilterBody(&r2) + r.FilterArgs(&r2) + return &r2 +} + +func (r *ReqDumpFilter) ToJSON() error { + fd, err := os.CreateTemp("", "crowdsec_req_dump_*.json") + if err != nil { + return fmt.Errorf("while creating temp file: %w", err) + } + defer fd.Close() + enc := json.NewEncoder(fd) + enc.SetIndent("", " ") + + req := r.GetFilteredRequest() + + log.Tracef("dumping : %+v", req) + + if err := enc.Encode(req); err != nil { + //Don't clobber the temp directory with empty files + err2 := os.Remove(fd.Name()) + if err2 != nil { + log.Errorf("while removing temp file %s: %s", fd.Name(), err) + } + return fmt.Errorf("while encoding request: %w", err) + } + log.Infof("request dumped to %s", fd.Name()) + return nil +} + +// Generate a ParsedRequest from an http.Request. ParsedRequest can be consumed by the AppSec engine +func NewParsedRequestFromRequest(r *http.Request, logger *log.Entry) (ParsedRequest, error) { + var err error + contentLength := r.ContentLength + if contentLength < 0 { + contentLength = 0 + } + body := make([]byte, contentLength) + if r.Body != nil { + _, err = io.ReadFull(r.Body, body) + if err != nil { + return ParsedRequest{}, fmt.Errorf("unable to read body: %s", err) + } + // reset the original body back as it's been read, I'm not sure it's needed + r.Body = io.NopCloser(bytes.NewBuffer(body)) + + } + clientIP := r.Header.Get(IPHeaderName) + if clientIP == "" { + return ParsedRequest{}, fmt.Errorf("missing '%s' header", IPHeaderName) + } + + clientURI := r.Header.Get(URIHeaderName) + if clientURI == "" { + return ParsedRequest{}, fmt.Errorf("missing '%s' header", URIHeaderName) + } + + clientMethod := r.Header.Get(VerbHeaderName) + if clientMethod == "" { + return ParsedRequest{}, fmt.Errorf("missing '%s' header", VerbHeaderName) + } + + clientHost := r.Header.Get(HostHeaderName) + if clientHost == "" { //this might be empty + logger.Debugf("missing '%s' header", HostHeaderName) + } + + userAgent := r.Header.Get(UserAgentHeaderName) //This one is optional + + // delete those headers before coraza processes the request + delete(r.Header, IPHeaderName) + delete(r.Header, HostHeaderName) + delete(r.Header, URIHeaderName) + delete(r.Header, VerbHeaderName) + delete(r.Header, UserAgentHeaderName) + delete(r.Header, APIKeyHeaderName) + + originalHTTPRequest := r.Clone(r.Context()) + originalHTTPRequest.Body = io.NopCloser(bytes.NewBuffer(body)) + originalHTTPRequest.RemoteAddr = clientIP + originalHTTPRequest.RequestURI = clientURI + originalHTTPRequest.Method = clientMethod + originalHTTPRequest.Host = clientHost + if userAgent != "" { + originalHTTPRequest.Header.Set("User-Agent", userAgent) + r.Header.Set("User-Agent", userAgent) //Override the UA in the original request, as this is what will be used by the waf engine + } else { + //If we don't have a forwarded UA, delete the one that was set by the remediation component in both the original and the incoming request + originalHTTPRequest.Header.Del("User-Agent") + r.Header.Del("User-Agent") + } + + parsedURL, err := url.Parse(clientURI) + if err != nil { + return ParsedRequest{}, fmt.Errorf("unable to parse url '%s': %s", clientURI, err) + } + + var remoteAddrNormalized string + if r.RemoteAddr == "@" { + r.RemoteAddr = "127.0.0.1:65535" + } + // TODO we need to implement forwarded headers + host, _, err :=
net.SplitHostPort(r.RemoteAddr) + if err != nil { + log.Errorf("Invalid appsec remote IP source %v: %s", r.RemoteAddr, err.Error()) + remoteAddrNormalized = r.RemoteAddr + } else { + ip := net.ParseIP(host) + if ip == nil { + log.Errorf("Invalid appsec remote IP address source %v", r.RemoteAddr) + remoteAddrNormalized = r.RemoteAddr + } else { + remoteAddrNormalized = ip.String() + } + } + + return ParsedRequest{ + RemoteAddr: r.RemoteAddr, + UUID: uuid.New().String(), + ClientHost: clientHost, + ClientIP: clientIP, + URI: clientURI, + Method: clientMethod, + Host: clientHost, + Headers: r.Header, + URL: parsedURL, + Proto: r.Proto, + Body: body, + Args: ParseQuery(parsedURL.RawQuery), + TransferEncoding: r.TransferEncoding, + ResponseChannel: make(chan AppsecTempResponse), + RemoteAddrNormalized: remoteAddrNormalized, + HTTPRequest: originalHTTPRequest, + }, nil +} diff --git a/pkg/appsec/request_test.go b/pkg/appsec/request_test.go new file mode 100644 index 00000000000..f8333e4e5f9 --- /dev/null +++ b/pkg/appsec/request_test.go @@ -0,0 +1,181 @@ +package appsec + +import "testing" + +func TestBodyDumper(t *testing.T) { + + tests := []struct { + name string + req *ParsedRequest + expect *ParsedRequest + filter func(r *ReqDumpFilter) *ReqDumpFilter + }{ + { + name: "default filter (cookie+authorization stripped + no body)", + req: &ParsedRequest{ + Body: []byte("yo some body"), + Headers: map[string][]string{"cookie": {"toto"}, "authorization": {"tata"}, "foo": {"bar", "baz"}}, + }, + expect: &ParsedRequest{ + Body: []byte{}, + Headers: map[string][]string{"foo": {"bar", "baz"}}, + }, + filter: func(r *ReqDumpFilter) *ReqDumpFilter { + return r + }, + }, + { + name: "explicit empty filter", + req: &ParsedRequest{ + Body: []byte("yo some body"), + Headers: map[string][]string{"cookie": {"toto"}, "authorization": {"tata"}, "foo": {"bar", "baz"}}, + }, + expect: &ParsedRequest{ + Body: []byte("yo some body"), + Headers: map[string][]string{"cookie": {"toto"}, "authorization": {"tata"}, "foo": {"bar", "baz"}}, + }, + filter: func(r *ReqDumpFilter) *ReqDumpFilter { + return r.NoFilters() + }, + }, + { + name: "filter header", + req: &ParsedRequest{ + Body: []byte{}, + Headers: map[string][]string{"test1": {"toto"}, "test2": {"tata"}}, + }, + expect: &ParsedRequest{ + Body: []byte{}, + Headers: map[string][]string{"test1": {"toto"}}, + }, + filter: func(r *ReqDumpFilter) *ReqDumpFilter { + return r.WithNoBody().WithHeadersNameFilter("test2") + }, + }, + { + name: "filter header content", + req: &ParsedRequest{ + Body: []byte{}, + Headers: map[string][]string{"test1": {"toto"}, "test2": {"tata"}}, + }, + expect: &ParsedRequest{ + Body: []byte{}, + Headers: map[string][]string{"test1": {"toto"}}, + }, + filter: func(r *ReqDumpFilter) *ReqDumpFilter { + return r.WithHeadersContentFilter("tata") + }, + }, + { + name: "with headers", + req: &ParsedRequest{ + Body: []byte{}, + Headers: map[string][]string{"cookie1": {"lol"}}, + }, + expect: &ParsedRequest{ + Body: []byte{}, + Headers: map[string][]string{"cookie1": {"lol"}}, + }, + filter: func(r *ReqDumpFilter) *ReqDumpFilter { + return r.WithHeaders() + }, + }, + { + name: "drop headers", + req: &ParsedRequest{ + Body: []byte{}, + Headers: map[string][]string{"toto": {"lol"}}, + }, + expect: &ParsedRequest{ + Body: []byte{}, + Headers: map[string][]string{}, + }, + filter: func(r *ReqDumpFilter) *ReqDumpFilter { + return r.WithNoHeaders() + }, + }, + { + name: "with body", + req: &ParsedRequest{ + Body: []byte("toto"), + Headers: 
map[string][]string{"toto": {"lol"}}, + }, + expect: &ParsedRequest{ + Body: []byte("toto"), + Headers: map[string][]string{"toto": {"lol"}}, + }, + filter: func(r *ReqDumpFilter) *ReqDumpFilter { + return r.WithBody() + }, + }, + { + name: "with empty args filter", + req: &ParsedRequest{ + Args: map[string][]string{"toto": {"lol"}}, + }, + expect: &ParsedRequest{ + Args: map[string][]string{"toto": {"lol"}}, + }, + filter: func(r *ReqDumpFilter) *ReqDumpFilter { + return r.WithEmptyArgsFilters() + }, + }, + { + name: "with args name filter", + req: &ParsedRequest{ + Args: map[string][]string{"toto": {"lol"}, "totolol": {"lol"}}, + }, + expect: &ParsedRequest{ + Args: map[string][]string{"totolol": {"lol"}}, + }, + filter: func(r *ReqDumpFilter) *ReqDumpFilter { + return r.WithArgsNameFilter("toto") + }, + }, + { + name: "WithEmptyHeadersFilters", + req: &ParsedRequest{ + Args: map[string][]string{"cookie": {"lol"}, "totolol": {"lol"}}, + }, + expect: &ParsedRequest{ + Args: map[string][]string{"cookie": {"lol"}, "totolol": {"lol"}}, + }, + filter: func(r *ReqDumpFilter) *ReqDumpFilter { + return r.WithEmptyHeadersFilters() + }, + }, + { + name: "WithArgsContentFilters", + req: &ParsedRequest{ + Args: map[string][]string{"test": {"lol"}, "test2": {"toto"}}, + }, + expect: &ParsedRequest{ + Args: map[string][]string{"test": {"lol"}}, + }, + filter: func(r *ReqDumpFilter) *ReqDumpFilter { + return r.WithArgsContentFilter("toto") + }, + }, + } + + for idx, test := range tests { + + t.Run(test.name, func(t *testing.T) { + orig_dr := test.req.DumpRequest() + result := test.filter(orig_dr).GetFilteredRequest() + + if len(result.Body) != len(test.expect.Body) { + t.Fatalf("test %d (%s) failed, got %d, expected %d", idx, test.name, len(result.Body), len(test.expect.Body)) + } + if len(result.Headers) != len(test.expect.Headers) { + t.Fatalf("test %d (%s) failed, got %d, expected %d", idx, test.name, len(result.Headers), len(test.expect.Headers)) + } + for k, v := range result.Headers { + if len(v) != len(test.expect.Headers[k]) { + t.Fatalf("test %d (%s) failed, got %d, expected %d", idx, test.name, len(v), len(test.expect.Headers[k])) + } + } + }) + } + +} diff --git a/pkg/appsec/tx.go b/pkg/appsec/tx.go new file mode 100644 index 00000000000..47da19d1556 --- /dev/null +++ b/pkg/appsec/tx.go @@ -0,0 +1,93 @@ +package appsec + +import ( + "github.com/crowdsecurity/coraza/v3" + "github.com/crowdsecurity/coraza/v3/experimental" + "github.com/crowdsecurity/coraza/v3/experimental/plugins/plugintypes" + "github.com/crowdsecurity/coraza/v3/types" +) + +type ExtendedTransaction struct { + Tx experimental.FullTransaction +} + +func NewExtendedTransaction(engine coraza.WAF, uuid string) ExtendedTransaction { + inBoundTx := engine.NewTransactionWithID(uuid) + expTx := inBoundTx.(experimental.FullTransaction) + tx := NewTransaction(expTx) + return tx +} + +func NewTransaction(tx experimental.FullTransaction) ExtendedTransaction { + return ExtendedTransaction{Tx: tx} +} + +func (t *ExtendedTransaction) RemoveRuleByIDWithError(id int) error { + t.Tx.RemoveRuleByID(id) + return nil +} + +func (t *ExtendedTransaction) RemoveRuleByTagWithError(tag string) error { + t.Tx.RemoveRuleByTag(tag) + return nil +} + +func (t *ExtendedTransaction) IsRuleEngineOff() bool { + return t.Tx.IsRuleEngineOff() +} + +func (t *ExtendedTransaction) ProcessLogging() { + t.Tx.ProcessLogging() +} + +func (t *ExtendedTransaction) ProcessConnection(client string, cPort int, server string, sPort int) { +
t.Tx.ProcessConnection(client, cPort, server, sPort) +} + +func (t *ExtendedTransaction) AddGetRequestArgument(name string, value string) { + t.Tx.AddGetRequestArgument(name, value) +} + +func (t *ExtendedTransaction) ProcessURI(uri string, method string, httpVersion string) { + t.Tx.ProcessURI(uri, method, httpVersion) +} + +func (t *ExtendedTransaction) AddRequestHeader(name string, value string) { + t.Tx.AddRequestHeader(name, value) +} + +func (t *ExtendedTransaction) SetServerName(name string) { + t.Tx.SetServerName(name) +} + +func (t *ExtendedTransaction) ProcessRequestHeaders() *types.Interruption { + return t.Tx.ProcessRequestHeaders() +} + +func (t *ExtendedTransaction) ProcessRequestBody() (*types.Interruption, error) { + return t.Tx.ProcessRequestBody() +} + +func (t *ExtendedTransaction) WriteRequestBody(body []byte) (*types.Interruption, int, error) { + return t.Tx.WriteRequestBody(body) +} + +func (t *ExtendedTransaction) Interruption() *types.Interruption { + return t.Tx.Interruption() +} + +func (t *ExtendedTransaction) IsInterrupted() bool { + return t.Tx.IsInterrupted() +} + +func (t *ExtendedTransaction) Variables() plugintypes.TransactionVariables { + return t.Tx.Variables() +} + +func (t *ExtendedTransaction) MatchedRules() []types.MatchedRule { + return t.Tx.MatchedRules() +} + +func (t *ExtendedTransaction) ID() string { + return t.Tx.ID() +} diff --git a/pkg/appsec/waf_helpers.go b/pkg/appsec/waf_helpers.go new file mode 100644 index 00000000000..3d9a96a0d4f --- /dev/null +++ b/pkg/appsec/waf_helpers.go @@ -0,0 +1,61 @@ +package appsec + +import ( + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func GetOnLoadEnv(w *AppsecRuntimeConfig) map[string]interface{} { + return map[string]interface{}{ + "RemoveInBandRuleByID": w.DisableInBandRuleByID, + "RemoveInBandRuleByTag": w.DisableInBandRuleByTag, + "RemoveInBandRuleByName": w.DisableInBandRuleByName, + "RemoveOutBandRuleByID": w.DisableOutBandRuleByID, + "RemoveOutBandRuleByTag": w.DisableOutBandRuleByTag, + "RemoveOutBandRuleByName": w.DisableOutBandRuleByName, + "SetRemediationByTag": w.SetActionByTag, + "SetRemediationByID": w.SetActionByID, + "SetRemediationByName": w.SetActionByName, + } +} + +func GetPreEvalEnv(w *AppsecRuntimeConfig, request *ParsedRequest) map[string]interface{} { + return map[string]interface{}{ + "IsInBand": request.IsInBand, + "IsOutBand": request.IsOutBand, + "req": request.HTTPRequest, + "RemoveInBandRuleByID": w.RemoveInbandRuleByID, + "RemoveInBandRuleByName": w.RemoveInbandRuleByName, + "RemoveInBandRuleByTag": w.RemoveInbandRuleByTag, + "RemoveOutBandRuleByID": w.RemoveOutbandRuleByID, + "RemoveOutBandRuleByTag": w.RemoveOutbandRuleByTag, + "RemoveOutBandRuleByName": w.RemoveOutbandRuleByName, + "SetRemediationByTag": w.SetActionByTag, + "SetRemediationByID": w.SetActionByID, + "SetRemediationByName": w.SetActionByName, + } +} + +func GetPostEvalEnv(w *AppsecRuntimeConfig, request *ParsedRequest) map[string]interface{} { + return map[string]interface{}{ + "IsInBand": request.IsInBand, + "IsOutBand": request.IsOutBand, + "DumpRequest": request.DumpRequest, + "req": request.HTTPRequest, + } +} + +func GetOnMatchEnv(w *AppsecRuntimeConfig, request *ParsedRequest, evt types.Event) map[string]interface{} { + return map[string]interface{}{ + "evt": evt, + "req": request.HTTPRequest, + "IsInBand": request.IsInBand, + "IsOutBand": request.IsOutBand, + "SetRemediation": w.SetAction, + "SetReturnCode": w.SetHTTPCode, + "CancelEvent": w.CancelEvent, + "SendEvent": w.SendEvent, + 
"CancelAlert": w.CancelAlert, + "SendAlert": w.SendAlert, + "DumpRequest": request.DumpRequest, + } +} diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 0f3b9c4a0d5..8a696caf1f4 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -1,6 +1,8 @@ package cache import ( + "errors" + "fmt" "time" "github.com/bluele/gcache" @@ -10,9 +12,11 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var Caches []gcache.Cache -var CacheNames []string -var CacheConfig []CacheCfg +var ( + Caches []gcache.Cache + CacheNames []string + CacheConfig []CacheCfg +) /*prometheus*/ var CacheMetrics = prometheus.NewGaugeVec( @@ -26,6 +30,7 @@ var CacheMetrics = prometheus.NewGaugeVec( // UpdateCacheMetrics is called directly by the prom handler func UpdateCacheMetrics() { CacheMetrics.Reset() + for i, name := range CacheNames { CacheMetrics.With(prometheus.Labels{"name": name, "type": CacheConfig[i].Strategy}).Set(float64(Caches[i].Len(false))) } @@ -41,27 +46,28 @@ type CacheCfg struct { } func CacheInit(cfg CacheCfg) error { - for _, name := range CacheNames { if name == cfg.Name { log.Infof("Cache %s already exists", cfg.Name) } } - //get a default logger + // get a default logger if cfg.LogLevel == nil { cfg.LogLevel = new(log.Level) *cfg.LogLevel = log.InfoLevel } - var clog = log.New() + + clog := log.New() + if err := types.ConfigureLogger(clog); err != nil { - log.Fatalf("While creating cache logger : %s", err) + return fmt.Errorf("while creating cache logger: %w", err) } + clog.SetLevel(*cfg.LogLevel) - cfg.Logger = clog.WithFields(log.Fields{ - "cache": cfg.Name, - }) + cfg.Logger = clog.WithField("cache", cfg.Name) tmpCache := gcache.New(cfg.Size) + switch cfg.Strategy { case "LRU": tmpCache = tmpCache.LRU() @@ -72,7 +78,6 @@ func CacheInit(cfg CacheCfg) error { default: cfg.Strategy = "LRU" tmpCache = tmpCache.LRU() - } CTICache := tmpCache.Build() @@ -84,36 +89,42 @@ func CacheInit(cfg CacheCfg) error { } func SetKey(cacheName string, key string, value string, expiration *time.Duration) error { - for i, name := range CacheNames { if name == cacheName { if expiration == nil { expiration = &CacheConfig[i].TTL } + CacheConfig[i].Logger.Debugf("Setting key %s to %s with expiration %v", key, value, *expiration) + if err := Caches[i].SetWithExpire(key, value, *expiration); err != nil { CacheConfig[i].Logger.Warningf("While setting key %s in cache %s: %s", key, cacheName, err) } } } + return nil } func GetKey(cacheName string, key string) (string, error) { for i, name := range CacheNames { if name == cacheName { - if value, err := Caches[i].Get(key); err != nil { - //do not warn or log if key not found - if err == gcache.KeyNotFoundError { + value, err := Caches[i].Get(key) + if err != nil { + // do not warn or log if key not found + if errors.Is(err, gcache.KeyNotFoundError) { return "", nil } CacheConfig[i].Logger.Warningf("While getting key %s in cache %s: %s", key, cacheName, err) + return "", err - } else { - return value.(string), nil } + + return value.(string), nil } } + log.Warningf("Cache %s not found", cacheName) + return "", nil } diff --git a/pkg/csconfig/api.go b/pkg/csconfig/api.go index 06d6a9712e9..3014b729a9e 100644 --- a/pkg/csconfig/api.go +++ b/pkg/csconfig/api.go @@ -1,19 +1,23 @@ package csconfig import ( + "bytes" "crypto/tls" "crypto/x509" + "errors" "fmt" + "io" "net" "os" "strings" "time" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" - 
"github.com/crowdsecurity/go-cs-lib/pkg/yamlpatch" + "github.com/crowdsecurity/go-cs-lib/csstring" + "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/go-cs-lib/yamlpatch" "github.com/crowdsecurity/crowdsec/pkg/apiclient" ) @@ -56,40 +60,58 @@ type CTICfg struct { } func (a *CTICfg) Load() error { - if a.Key == nil { - *a.Enabled = false + a.Enabled = ptr.Of(false) } + if a.Key != nil && *a.Key == "" { - return fmt.Errorf("empty cti key") + return errors.New("empty cti key") } + if a.Enabled == nil { - a.Enabled = new(bool) - *a.Enabled = true + a.Enabled = ptr.Of(true) } + if a.CacheTimeout == nil { a.CacheTimeout = new(time.Duration) *a.CacheTimeout = 10 * time.Minute } + if a.CacheSize == nil { a.CacheSize = new(int) *a.CacheSize = 100 } + return nil } func (o *OnlineApiClientCfg) Load() error { o.Credentials = new(ApiCredentialsCfg) + fcontent, err := os.ReadFile(o.CredentialsFilePath) if err != nil { - return fmt.Errorf("failed to read api server credentials configuration file '%s': %w", o.CredentialsFilePath, err) + return err } - err = yaml.UnmarshalStrict(fcontent, o.Credentials) + + dec := yaml.NewDecoder(bytes.NewReader(fcontent)) + dec.KnownFields(true) + + err = dec.Decode(o.Credentials) if err != nil { - return fmt.Errorf("failed unmarshaling api server credentials configuration file '%s': %w", o.CredentialsFilePath, err) + if !errors.Is(err, io.EOF) { + return fmt.Errorf("failed to parse api server credentials configuration file '%s': %w", o.CredentialsFilePath, err) + } } - if o.Credentials.Login == "" || o.Credentials.Password == "" || o.Credentials.URL == "" { - log.Warningf("can't load CAPI credentials from '%s' (missing field)", o.CredentialsFilePath) + + switch { + case o.Credentials.Login == "": + log.Warningf("can't load CAPI credentials from '%s' (missing login field)", o.CredentialsFilePath) + o.Credentials = nil + case o.Credentials.Password == "": + log.Warningf("can't load CAPI credentials from '%s' (missing password field)", o.CredentialsFilePath) + o.Credentials = nil + case o.Credentials.URL == "": + log.Warningf("can't load CAPI credentials from '%s' (missing url field)", o.CredentialsFilePath) o.Credentials = nil } @@ -98,26 +120,49 @@ func (o *OnlineApiClientCfg) Load() error { func (l *LocalApiClientCfg) Load() error { patcher := yamlpatch.NewPatcher(l.CredentialsFilePath, ".local") + fcontent, err := patcher.MergedPatchContent() if err != nil { return err } - err = yaml.UnmarshalStrict(fcontent, &l.Credentials) + + configData := csstring.StrictExpand(string(fcontent), os.LookupEnv) + + dec := yaml.NewDecoder(strings.NewReader(configData)) + dec.KnownFields(true) + + err = dec.Decode(&l.Credentials) if err != nil { - return fmt.Errorf("failed unmarshaling api client credential configuration file '%s': %w", l.CredentialsFilePath, err) + if !errors.Is(err, io.EOF) { + return fmt.Errorf("failed to parse api client credential configuration file '%s': %w", l.CredentialsFilePath, err) + } } + if l.Credentials == nil || l.Credentials.URL == "" { return fmt.Errorf("no credentials or URL found in api client configuration '%s'", l.CredentialsFilePath) } if l.Credentials != nil && l.Credentials.URL != "" { - if !strings.HasSuffix(l.Credentials.URL, "/") { + // don't append a trailing slash if the URL is a unix socket + if strings.HasPrefix(l.Credentials.URL, "http") && !strings.HasSuffix(l.Credentials.URL, "/") { l.Credentials.URL += "/" } } - if l.Credentials.Login != "" && (l.Credentials.CertPath != "" || l.Credentials.KeyPath != "") { - return 
fmt.Errorf("user/password authentication and TLS authentication are mutually exclusive") + // is the configuration asking for client authentication via TLS? + credTLSClientAuth := l.Credentials.CertPath != "" || l.Credentials.KeyPath != "" + + // is the configuration asking for TLS encryption and server authentication? + credTLS := credTLSClientAuth || l.Credentials.CACertPath != "" + + credSocket := strings.HasPrefix(l.Credentials.URL, "/") + + if credTLS && credSocket { + return errors.New("cannot use TLS with a unix socket") + } + + if credTLSClientAuth && l.Credentials.Login != "" { + return errors.New("user/password authentication and TLS authentication are mutually exclusive") } if l.InsecureSkipVerify == nil { @@ -136,9 +181,11 @@ func (l *LocalApiClientCfg) Load() error { if err != nil { log.Warningf("Error loading system CA certificates: %s", err) } + if caCertPool == nil { caCertPool = x509.NewCertPool() } + caCertPool.AppendCertsFromPEM(caCert) apiclient.CaCertPool = caCertPool } @@ -155,16 +202,20 @@ func (l *LocalApiClientCfg) Load() error { return nil } -func (lapiCfg *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) { +func (c *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) { trustedIPs := make([]net.IPNet, 0) - for _, ip := range lapiCfg.TrustedIPs { + + for _, ip := range c.TrustedIPs { cidr := toValidCIDR(ip) + _, ipNet, err := net.ParseCIDR(cidr) if err != nil { return nil, err } + trustedIPs = append(trustedIPs, *ipNet) } + return trustedIPs, nil } @@ -176,6 +227,7 @@ func toValidCIDR(ip string) string { if strings.Contains(ip, ":") { return ip + "/128" } + return ip + "/32" } @@ -184,57 +236,91 @@ type CapiWhitelist struct { Cidrs []*net.IPNet `yaml:"cidrs,omitempty"` } +type LocalAPIAutoRegisterCfg struct { + Enable *bool `yaml:"enabled"` + Token string `yaml:"token"` + AllowedRanges []string `yaml:"allowed_ranges,omitempty"` + AllowedRangesParsed []*net.IPNet `yaml:"-"` +} + /*local api service configuration*/ type LocalApiServerCfg struct { - Enable *bool `yaml:"enable"` - ListenURI string `yaml:"listen_uri,omitempty"` // 127.0.0.1:8080 - TLS *TLSCfg `yaml:"tls"` - DbConfig *DatabaseCfg `yaml:"-"` - LogDir string `yaml:"-"` - LogMedia string `yaml:"-"` - OnlineClient *OnlineApiClientCfg `yaml:"online_client"` - ProfilesPath string `yaml:"profiles_path,omitempty"` - ConsoleConfigPath string `yaml:"console_path,omitempty"` - ConsoleConfig *ConsoleConfig `yaml:"-"` - Profiles []*ProfileCfg `yaml:"-"` - LogLevel *log.Level `yaml:"log_level"` - UseForwardedForHeaders bool `yaml:"use_forwarded_for_headers,omitempty"` - TrustedProxies *[]string `yaml:"trusted_proxies,omitempty"` - CompressLogs *bool `yaml:"-"` - LogMaxSize int `yaml:"-"` - LogMaxAge int `yaml:"-"` - LogMaxFiles int `yaml:"-"` - TrustedIPs []string `yaml:"trusted_ips,omitempty"` - PapiLogLevel *log.Level `yaml:"papi_log_level"` - DisableRemoteLapiRegistration bool `yaml:"disable_remote_lapi_registration,omitempty"` - CapiWhitelistsPath string `yaml:"capi_whitelists_path,omitempty"` - CapiWhitelists *CapiWhitelist `yaml:"-"` + Enable *bool `yaml:"enable"` + ListenURI string `yaml:"listen_uri,omitempty"` // 127.0.0.1:8080 + ListenSocket string `yaml:"listen_socket,omitempty"` + TLS *TLSCfg `yaml:"tls"` + DbConfig *DatabaseCfg `yaml:"-"` + LogDir string `yaml:"-"` + LogMedia string `yaml:"-"` + OnlineClient *OnlineApiClientCfg `yaml:"online_client"` + ProfilesPath string `yaml:"profiles_path,omitempty"` + ConsoleConfigPath string `yaml:"console_path,omitempty"` + ConsoleConfig *ConsoleConfig 
`yaml:"-"` + Profiles []*ProfileCfg `yaml:"-"` + LogLevel *log.Level `yaml:"log_level"` + UseForwardedForHeaders bool `yaml:"use_forwarded_for_headers,omitempty"` + TrustedProxies *[]string `yaml:"trusted_proxies,omitempty"` + CompressLogs *bool `yaml:"-"` + LogMaxSize int `yaml:"-"` + LogMaxAge int `yaml:"-"` + LogMaxFiles int `yaml:"-"` + TrustedIPs []string `yaml:"trusted_ips,omitempty"` + PapiLogLevel *log.Level `yaml:"papi_log_level"` + DisableRemoteLapiRegistration bool `yaml:"disable_remote_lapi_registration,omitempty"` + CapiWhitelistsPath string `yaml:"capi_whitelists_path,omitempty"` + CapiWhitelists *CapiWhitelist `yaml:"-"` + AutoRegister *LocalAPIAutoRegisterCfg `yaml:"auto_registration,omitempty"` } -type TLSCfg struct { - CertFilePath string `yaml:"cert_file"` - KeyFilePath string `yaml:"key_file"` - ClientVerification string `yaml:"client_verification,omitempty"` - ServerName string `yaml:"server_name"` - CACertPath string `yaml:"ca_cert_path"` - AllowedAgentsOU []string `yaml:"agents_allowed_ou"` - AllowedBouncersOU []string `yaml:"bouncers_allowed_ou"` - CRLPath string `yaml:"crl_path"` - CacheExpiration *time.Duration `yaml:"cache_expiration,omitempty"` +func (c *LocalApiServerCfg) ClientURL() string { + if c == nil { + return "" + } + + if c.ListenSocket != "" { + return c.ListenSocket + } + + if c.ListenURI != "" { + return "http://" + c.ListenURI + } + + return "" } -func (c *Config) LoadAPIServer() error { +func (c *Config) LoadAPIServer(inCli bool) error { if c.DisableAPI { log.Warning("crowdsec local API is disabled from flag") } if c.API.Server == nil { log.Warning("crowdsec local API is disabled") + c.DisableAPI = true + return nil } - //inherit log level from common, then api->server + if c.API.Server.Enable == nil { + // if the option is not present, it is enabled by default + c.API.Server.Enable = ptr.Of(true) + } + + if !*c.API.Server.Enable { + log.Warning("crowdsec local API is disabled because 'enable' is set to false") + + c.DisableAPI = true + } + + if c.DisableAPI { + return nil + } + + if c.API.Server.ListenURI == "" && c.API.Server.ListenSocket == "" { + return errors.New("no listen_uri or listen_socket specified") + } + + // inherit log level from common, then api->server var logLevel log.Level if c.API.Server.LogLevel != nil { logLevel = *c.API.Server.LogLevel @@ -253,10 +339,12 @@ func (c *Config) LoadAPIServer() error { return fmt.Errorf("loading online client credentials: %w", err) } } - if c.API.Server.OnlineClient == nil || c.API.Server.OnlineClient.Credentials == nil { + + if (c.API.Server.OnlineClient == nil || c.API.Server.OnlineClient.Credentials == nil) && !inCli { log.Printf("push and pull to Central API disabled") } - if err := c.LoadDBConfig(); err != nil { + + if err := c.LoadDBConfig(inCli); err != nil { return err } @@ -264,59 +352,45 @@ func (c *Config) LoadAPIServer() error { return err } - if c.API.Server.CapiWhitelistsPath != "" { + if c.API.Server.CapiWhitelistsPath != "" && !inCli { log.Infof("loaded capi whitelist from %s: %d IPs, %d CIDRs", c.API.Server.CapiWhitelistsPath, len(c.API.Server.CapiWhitelists.Ips), len(c.API.Server.CapiWhitelists.Cidrs)) } - if c.API.Server.Enable == nil { - // if the option is not present, it is enabled by default - c.API.Server.Enable = ptr.Of(true) - } - - if !*c.API.Server.Enable { - log.Warning("crowdsec local API is disabled because 'enable' is set to false") - c.DisableAPI = true - return nil + if err := c.API.Server.LoadAutoRegister(); err != nil { + return err } - if c.DisableAPI { - 
return nil + if c.API.Server.AutoRegister != nil && c.API.Server.AutoRegister.Enable != nil && *c.API.Server.AutoRegister.Enable && !inCli { + log.Infof("auto LAPI registration enabled for ranges %+v", c.API.Server.AutoRegister.AllowedRanges) } - if err := c.LoadCommon(); err != nil { - return fmt.Errorf("loading common configuration: %s", err) - } c.API.Server.LogDir = c.Common.LogDir c.API.Server.LogMedia = c.Common.LogMedia c.API.Server.CompressLogs = c.Common.CompressLogs c.API.Server.LogMaxSize = c.Common.LogMaxSize c.API.Server.LogMaxAge = c.Common.LogMaxAge c.API.Server.LogMaxFiles = c.Common.LogMaxFiles + if c.API.Server.UseForwardedForHeaders && c.API.Server.TrustedProxies == nil { c.API.Server.TrustedProxies = &[]string{"0.0.0.0/0"} } + if c.API.Server.TrustedProxies != nil { c.API.Server.UseForwardedForHeaders = true } + if err := c.API.Server.LoadProfiles(); err != nil { return fmt.Errorf("while loading profiles for LAPI: %w", err) } + if c.API.Server.ConsoleConfigPath == "" { c.API.Server.ConsoleConfigPath = DefaultConsoleConfigFilePath } + if err := c.API.Server.LoadConsoleConfig(); err != nil { return fmt.Errorf("while loading console options: %w", err) } - if c.API.Server.OnlineClient != nil && c.API.Server.OnlineClient.CredentialsFilePath != "" { - if err := c.API.Server.OnlineClient.Load(); err != nil { - return fmt.Errorf("loading online client credentials: %w", err) - } - } - if c.API.Server.OnlineClient == nil || c.API.Server.OnlineClient.Credentials == nil { - log.Printf("push and pull to Central API disabled") - } - if c.API.CTI != nil { if err := c.API.CTI.Load(); err != nil { return fmt.Errorf("loading CTI configuration: %w", err) @@ -332,50 +406,111 @@ type capiWhitelists struct { Cidrs []string `yaml:"cidrs"` } -func (s *LocalApiServerCfg) LoadCapiWhitelists() error { - if s.CapiWhitelistsPath == "" { - return nil - } - if _, err := os.Stat(s.CapiWhitelistsPath); os.IsNotExist(err) { - return fmt.Errorf("capi whitelist file '%s' does not exist", s.CapiWhitelistsPath) - } - fd, err := os.Open(s.CapiWhitelistsPath) - if err != nil { - return fmt.Errorf("unable to open capi whitelist file '%s': %s", s.CapiWhitelistsPath, err) - } - - var fromCfg capiWhitelists - s.CapiWhitelists = &CapiWhitelist{} +func parseCapiWhitelists(fd io.Reader) (*CapiWhitelist, error) { + fromCfg := capiWhitelists{} - defer fd.Close() decoder := yaml.NewDecoder(fd) if err := decoder.Decode(&fromCfg); err != nil { - return fmt.Errorf("while parsing capi whitelist file '%s': %s", s.CapiWhitelistsPath, err) + if errors.Is(err, io.EOF) { + return nil, errors.New("empty file") + } + + return nil, err } - for _, v := range fromCfg.Ips { + + ret := &CapiWhitelist{ + Ips: make([]net.IP, len(fromCfg.Ips)), + Cidrs: make([]*net.IPNet, len(fromCfg.Cidrs)), + } + + for idx, v := range fromCfg.Ips { ip := net.ParseIP(v) if ip == nil { - return fmt.Errorf("unable to parse ip whitelist '%s'", v) + return nil, fmt.Errorf("invalid IP address: %s", v) } - s.CapiWhitelists.Ips = append(s.CapiWhitelists.Ips, ip) + + ret.Ips[idx] = ip } - for _, v := range fromCfg.Cidrs { + + for idx, v := range fromCfg.Cidrs { _, tnet, err := net.ParseCIDR(v) if err != nil { - return fmt.Errorf("unable to parse cidr whitelist '%s' : %v", v, err) + return nil, err } - s.CapiWhitelists.Cidrs = append(s.CapiWhitelists.Cidrs, tnet) + + ret.Cidrs[idx] = tnet + } + + return ret, nil +} + +func (c *LocalApiServerCfg) LoadCapiWhitelists() error { + if c.CapiWhitelistsPath == "" { + return nil + } + + fd, err := 
os.Open(c.CapiWhitelistsPath) + if err != nil { + return fmt.Errorf("while opening capi whitelist file: %w", err) + } + + defer fd.Close() + + c.CapiWhitelists, err = parseCapiWhitelists(fd) + if err != nil { + return fmt.Errorf("while parsing capi whitelist file '%s': %w", c.CapiWhitelistsPath, err) } + return nil } func (c *Config) LoadAPIClient() error { if c.API == nil || c.API.Client == nil || c.API.Client.CredentialsFilePath == "" || c.DisableAgent { - return fmt.Errorf("no API client section in configuration") + return errors.New("no API client section in configuration") } - if err := c.API.Client.Load(); err != nil { - return err + return c.API.Client.Load() +} + +func (c *LocalApiServerCfg) LoadAutoRegister() error { + if c.AutoRegister == nil { + c.AutoRegister = &LocalAPIAutoRegisterCfg{ + Enable: ptr.Of(false), + } + + return nil + } + + // Disable by default + if c.AutoRegister.Enable == nil { + c.AutoRegister.Enable = ptr.Of(false) + } + + if !*c.AutoRegister.Enable { + return nil + } + + if c.AutoRegister.Token == "" { + return errors.New("missing token value for api.server.auto_register") + } + + if len(c.AutoRegister.Token) < 32 { + return errors.New("token value for api.server.auto_register is too short (min 32 characters)") + } + + if c.AutoRegister.AllowedRanges == nil { + return errors.New("missing allowed_ranges value for api.server.auto_register") + } + + c.AutoRegister.AllowedRangesParsed = make([]*net.IPNet, 0, len(c.AutoRegister.AllowedRanges)) + + for _, ipRange := range c.AutoRegister.AllowedRanges { + _, ipNet, err := net.ParseCIDR(ipRange) + if err != nil { + return fmt.Errorf("auto_register: failed to parse allowed range '%s': %w", ipRange, err) + } + + c.AutoRegister.AllowedRangesParsed = append(c.AutoRegister.AllowedRangesParsed, ipNet) } return nil diff --git a/pkg/csconfig/api_test.go b/pkg/csconfig/api_test.go index 7450800e97a..dff3c3afc8c 100644 --- a/pkg/csconfig/api_test.go +++ b/pkg/csconfig/api_test.go @@ -1,18 +1,18 @@ package csconfig import ( - "fmt" + "net" "os" - "path/filepath" "strings" "testing" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v2" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" ) func TestLoadLocalApiClientCfg(t *testing.T) { @@ -25,7 +25,7 @@ func TestLoadLocalApiClientCfg(t *testing.T) { { name: "basic valid configuration", input: &LocalApiClientCfg{ - CredentialsFilePath: "./tests/lapi-secrets.yaml", + CredentialsFilePath: "./testdata/lapi-secrets.yaml", }, expected: &ApiCredentialsCfg{ URL: "http://localhost:8080/", @@ -36,7 +36,7 @@ func TestLoadLocalApiClientCfg(t *testing.T) { { name: "invalid configuration", input: &LocalApiClientCfg{ - CredentialsFilePath: "./tests/bad_lapi-secrets.yaml", + CredentialsFilePath: "./testdata/bad_lapi-secrets.yaml", }, expected: &ApiCredentialsCfg{}, expectedErr: "field unknown_key not found in type csconfig.ApiCredentialsCfg", @@ -44,15 +44,15 @@ func TestLoadLocalApiClientCfg(t *testing.T) { { name: "invalid configuration filepath", input: &LocalApiClientCfg{ - CredentialsFilePath: "./tests/nonexist_lapi-secrets.yaml", + CredentialsFilePath: "./testdata/nonexist_lapi-secrets.yaml", }, expected: nil, - expectedErr: "open ./tests/nonexist_lapi-secrets.yaml: " + cstest.FileNotFoundMessage, + expectedErr: "open ./testdata/nonexist_lapi-secrets.yaml: " + 
cstest.FileNotFoundMessage, }, { name: "valid configuration with insecure skip verify", input: &LocalApiClientCfg{ - CredentialsFilePath: "./tests/lapi-secrets.yaml", + CredentialsFilePath: "./testdata/lapi-secrets.yaml", InsecureSkipVerify: ptr.Of(false), }, expected: &ApiCredentialsCfg{ @@ -64,10 +64,10 @@ func TestLoadLocalApiClientCfg(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.input.Load() cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { return } @@ -87,7 +87,7 @@ func TestLoadOnlineApiClientCfg(t *testing.T) { { name: "basic valid configuration", input: &OnlineApiClientCfg{ - CredentialsFilePath: "./tests/online-api-secrets.yaml", + CredentialsFilePath: "./testdata/online-api-secrets.yaml", }, expected: &ApiCredentialsCfg{ URL: "http://crowdsec.api", @@ -98,33 +98,33 @@ func TestLoadOnlineApiClientCfg(t *testing.T) { { name: "invalid configuration", input: &OnlineApiClientCfg{ - CredentialsFilePath: "./tests/bad_lapi-secrets.yaml", + CredentialsFilePath: "./testdata/bad_lapi-secrets.yaml", }, expected: &ApiCredentialsCfg{}, - expectedErr: "failed unmarshaling api server credentials", + expectedErr: "failed to parse api server credentials", }, { name: "missing field configuration", input: &OnlineApiClientCfg{ - CredentialsFilePath: "./tests/bad_online-api-secrets.yaml", + CredentialsFilePath: "./testdata/bad_online-api-secrets.yaml", }, expected: nil, }, { name: "invalid configuration filepath", input: &OnlineApiClientCfg{ - CredentialsFilePath: "./tests/nonexist_online-api-secrets.yaml", + CredentialsFilePath: "./testdata/nonexist_online-api-secrets.yaml", }, expected: &ApiCredentialsCfg{}, - expectedErr: "failed to read api server credentials", + expectedErr: "open ./testdata/nonexist_online-api-secrets.yaml: " + cstest.FileNotFoundMessage, }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.input.Load() cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { return } @@ -136,27 +136,24 @@ func TestLoadOnlineApiClientCfg(t *testing.T) { func TestLoadAPIServer(t *testing.T) { tmpLAPI := &LocalApiServerCfg{ - ProfilesPath: "./tests/profiles.yaml", - } - if err := tmpLAPI.LoadProfiles(); err != nil { - t.Fatalf("loading tmp profiles: %+v", err) + ProfilesPath: "./testdata/profiles.yaml", } + err := tmpLAPI.LoadProfiles() + require.NoError(t, err) - LogDirFullPath, err := filepath.Abs("./tests") - if err != nil { - t.Fatal(err) - } logLevel := log.InfoLevel config := &Config{} - fcontent, err := os.ReadFile("./tests/config.yaml") - if err != nil { - t.Fatal(err) - } + fcontent, err := os.ReadFile("./testdata/config.yaml") + require.NoError(t, err) + configData := os.ExpandEnv(string(fcontent)) - err = yaml.UnmarshalStrict([]byte(configData), &config) - if err != nil { - t.Fatal(err) - } + + dec := yaml.NewDecoder(strings.NewReader(configData)) + dec.KnownFields(true) + + err = dec.Decode(&config) + require.NoError(t, err) + tests := []struct { name string input *Config @@ -171,18 +168,18 @@ func TestLoadAPIServer(t *testing.T) { Server: &LocalApiServerCfg{ ListenURI: "http://crowdsec.api", OnlineClient: &OnlineApiClientCfg{ - CredentialsFilePath: "./tests/online-api-secrets.yaml", + CredentialsFilePath: "./testdata/online-api-secrets.yaml", }, - ProfilesPath: "./tests/profiles.yaml", + ProfilesPath: "./testdata/profiles.yaml", PapiLogLevel: &logLevel, }, }, DbConfig: &DatabaseCfg{ Type: "sqlite", - DbPath: "./tests/test.db", + DbPath: 
"./testdata/test.db", }, Common: &CommonCfg{ - LogDir: "./tests/", + LogDir: "./testdata", LogMedia: "stdout", }, DisableAPI: false, @@ -192,9 +189,11 @@ func TestLoadAPIServer(t *testing.T) { ListenURI: "http://crowdsec.api", TLS: nil, DbConfig: &DatabaseCfg{ - DbPath: "./tests/test.db", - Type: "sqlite", - MaxOpenConns: ptr.Of(DEFAULT_MAX_OPEN_CONNS), + DbPath: "./testdata/test.db", + Type: "sqlite", + MaxOpenConns: DEFAULT_MAX_OPEN_CONNS, + UseWal: ptr.Of(true), // autodetected + DecisionBulkSize: defaultDecisionBulkSize, }, ConsoleConfigPath: DefaultConfigPath("console.yaml"), ConsoleConfig: &ConsoleConfig{ @@ -204,10 +203,10 @@ func TestLoadAPIServer(t *testing.T) { ShareContext: ptr.Of(false), ConsoleManagement: ptr.Of(false), }, - LogDir: LogDirFullPath, + LogDir: "./testdata", LogMedia: "stdout", OnlineClient: &OnlineApiClientCfg{ - CredentialsFilePath: "./tests/online-api-secrets.yaml", + CredentialsFilePath: "./testdata/online-api-secrets.yaml", Credentials: &ApiCredentialsCfg{ URL: "http://crowdsec.api", Login: "test", @@ -215,9 +214,15 @@ func TestLoadAPIServer(t *testing.T) { }, }, Profiles: tmpLAPI.Profiles, - ProfilesPath: "./tests/profiles.yaml", + ProfilesPath: "./testdata/profiles.yaml", UseForwardedForHeaders: false, PapiLogLevel: &logLevel, + AutoRegister: &LocalAPIAutoRegisterCfg{ + Enable: ptr.Of(false), + Token: "", + AllowedRanges: nil, + AllowedRangesParsed: nil, + }, }, }, { @@ -225,36 +230,97 @@ func TestLoadAPIServer(t *testing.T) { input: &Config{ Self: []byte(configData), API: &APICfg{ - Server: &LocalApiServerCfg{}, + Server: &LocalApiServerCfg{ + ListenURI: "http://crowdsec.api", + }, }, Common: &CommonCfg{ - LogDir: "./tests/", + LogDir: "./testdata/", LogMedia: "stdout", }, DisableAPI: false, }, expected: &LocalApiServerCfg{ + Enable: ptr.Of(true), PapiLogLevel: &logLevel, }, expectedErr: "no database configuration provided", }, } - for idx, test := range tests { - err := test.input.LoadAPIServer() - if err == nil && test.expectedErr != "" { - fmt.Printf("TEST '%s': NOK\n", test.name) - t.Fatalf("Test number %d/%d expected error, didn't get it", idx+1, len(tests)) - } else if test.expectedErr != "" { - fmt.Printf("ERR: %+v\n", err) - if !strings.HasPrefix(fmt.Sprintf("%s", err), test.expectedErr) { - fmt.Printf("TEST '%s': NOK\n", test.name) - t.Fatalf("%d/%d expected '%s' got '%s'", idx, len(tests), - test.expectedErr, - fmt.Sprintf("%s", err)) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.input.LoadAPIServer(false) + cstest.RequireErrorContains(t, err, tc.expectedErr) + + if tc.expectedErr != "" { + return + } + + assert.Equal(t, tc.expected, tc.input.API.Server) + }) + } +} + +func mustParseCIDRNet(t *testing.T, s string) *net.IPNet { + _, ipNet, err := net.ParseCIDR(s) + require.NoError(t, err) + + return ipNet +} + +func TestParseCapiWhitelists(t *testing.T) { + tests := []struct { + name string + input string + expected *CapiWhitelist + expectedErr string + }{ + { + name: "empty file", + input: "", + expected: &CapiWhitelist{ + Ips: []net.IP{}, + Cidrs: []*net.IPNet{}, + }, + expectedErr: "empty file", + }, + { + name: "empty ip and cidr", + input: `{"ips": [], "cidrs": []}`, + expected: &CapiWhitelist{ + Ips: []net.IP{}, + Cidrs: []*net.IPNet{}, + }, + }, + { + name: "some ip", + input: `{"ips": ["1.2.3.4"]}`, + expected: &CapiWhitelist{ + Ips: []net.IP{net.IPv4(1, 2, 3, 4)}, + Cidrs: []*net.IPNet{}, + }, + }, + { + name: "some cidr", + input: `{"cidrs": ["1.2.3.0/24"]}`, + expected: &CapiWhitelist{ + Ips: 
[]net.IP{}, + Cidrs: []*net.IPNet{mustParseCIDRNet(t, "1.2.3.0/24")}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + wl, err := parseCapiWhitelists(strings.NewReader(tc.input)) + cstest.RequireErrorContains(t, err, tc.expectedErr) + + if tc.expectedErr != "" { + return } - assert.Equal(t, test.expected, test.input.API.Server) - } + assert.Equal(t, tc.expected, wl) + }) } } diff --git a/pkg/csconfig/common.go b/pkg/csconfig/common.go index 9d80cd95ac1..7e1ef6e5c98 100644 --- a/pkg/csconfig/common.go +++ b/pkg/csconfig/common.go @@ -14,7 +14,7 @@ type CommonCfg struct { LogMedia string `yaml:"log_media"` LogDir string `yaml:"log_dir,omitempty"` //if LogMedia = file LogLevel *log.Level `yaml:"log_level"` - WorkingDir string `yaml:"working_dir,omitempty"` ///var/run + WorkingDir string `yaml:"working_dir,omitempty"` // TODO: This is just for backward compat. Remove this later CompressLogs *bool `yaml:"compress_logs,omitempty"` LogMaxSize int `yaml:"log_max_size,omitempty"` LogMaxAge int `yaml:"log_max_age,omitempty"` @@ -22,15 +22,18 @@ type CommonCfg struct { ForceColorLogs bool `yaml:"force_color_logs,omitempty"` } -func (c *Config) LoadCommon() error { +func (c *Config) loadCommon() error { var err error if c.Common == nil { - return fmt.Errorf("no common block provided in configuration file") + c.Common = &CommonCfg{} + } + + if c.Common.LogMedia == "" { + c.Common.LogMedia = "stdout" } var CommonCleanup = []*string{ &c.Common.LogDir, - &c.Common.WorkingDir, } for _, k := range CommonCleanup { if *k == "" { diff --git a/pkg/csconfig/common_test.go b/pkg/csconfig/common_test.go deleted file mode 100644 index 5666f2d7472..00000000000 --- a/pkg/csconfig/common_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package csconfig - -import ( - "fmt" - "path/filepath" - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestLoadCommon(t *testing.T) { - pidDirPath := "./tests" - LogDirFullPath, err := filepath.Abs("./tests/log/") - if err != nil { - t.Fatal(err) - } - - WorkingDirFullPath, err := filepath.Abs("./tests") - if err != nil { - t.Fatal(err) - } - - tests := []struct { - name string - Input *Config - expectedResult *CommonCfg - err string - }{ - { - name: "basic valid configuration", - Input: &Config{ - Common: &CommonCfg{ - Daemonize: true, - PidDir: "./tests", - LogMedia: "file", - LogDir: "./tests/log/", - WorkingDir: "./tests/", - }, - }, - expectedResult: &CommonCfg{ - Daemonize: true, - PidDir: pidDirPath, - LogMedia: "file", - LogDir: LogDirFullPath, - WorkingDir: WorkingDirFullPath, - }, - }, - { - name: "empty working dir", - Input: &Config{ - Common: &CommonCfg{ - Daemonize: true, - PidDir: "./tests", - LogMedia: "file", - LogDir: "./tests/log/", - }, - }, - expectedResult: &CommonCfg{ - Daemonize: true, - PidDir: pidDirPath, - LogMedia: "file", - LogDir: LogDirFullPath, - }, - }, - { - name: "no common", - Input: &Config{}, - expectedResult: nil, - }, - } - - for idx, test := range tests { - err := test.Input.LoadCommon() - if err == nil && test.err != "" { - fmt.Printf("TEST '%s': NOK\n", test.name) - t.Fatalf("%d/%d expected error, didn't get it", idx, len(tests)) - } else if test.err != "" { - if !strings.HasPrefix(fmt.Sprintf("%s", err), test.err) { - fmt.Printf("TEST '%s': NOK\n", test.name) - t.Fatalf("%d/%d expected '%s' got '%s'", idx, len(tests), - test.err, - fmt.Sprintf("%s", err)) - } - } - - isOk := assert.Equal(t, test.expectedResult, test.Input.Common) - if !isOk { - t.Fatalf("TEST '%s': NOK", test.name) - } 
else { - fmt.Printf("TEST '%s': OK\n", test.name) - } - } -} diff --git a/pkg/csconfig/config.go b/pkg/csconfig/config.go index d4b4aa4af3a..3bbdf607187 100644 --- a/pkg/csconfig/config.go +++ b/pkg/csconfig/config.go @@ -1,16 +1,22 @@ +// Package csconfig contains the configuration structures for crowdsec and cscli. package csconfig import ( + "errors" "fmt" + "io" "os" "path/filepath" + "strings" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" - "github.com/crowdsecurity/go-cs-lib/pkg/csstring" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" - "github.com/crowdsecurity/go-cs-lib/pkg/yamlpatch" + "github.com/crowdsecurity/go-cs-lib/csstring" + "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/go-cs-lib/yamlpatch" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" ) // defaultConfigDir is the base path to all configuration files, to be overridden in the Makefile */ @@ -19,9 +25,11 @@ var defaultConfigDir = "/etc/crowdsec" // defaultDataDir is the base path to all data files, to be overridden in the Makefile */ var defaultDataDir = "/var/lib/crowdsec/data/" +var globalConfig = Config{} + // Config contains top-level defaults -> overridden by configuration file -> overridden by CLI flags type Config struct { - //just a path to ourself :p + // just a path to ourselves :p FilePath *string `yaml:"-"` Self []byte `yaml:"-"` Common *CommonCfg `yaml:"common,omitempty"` @@ -34,25 +42,18 @@ type Config struct { PluginConfig *PluginCfg `yaml:"plugin_config,omitempty"` DisableAPI bool `yaml:"-"` DisableAgent bool `yaml:"-"` - Hub *Hub `yaml:"-"` -} - -func (c *Config) Dump() error { - out, err := yaml.Marshal(c) - if err != nil { - return fmt.Errorf("failed marshaling config: %w", err) - } - fmt.Printf("%s", string(out)) - return nil + Hub *LocalHubCfg `yaml:"-"` } -func NewConfig(configFile string, disableAgent bool, disableAPI bool, quiet bool) (*Config, string, error) { +func NewConfig(configFile string, disableAgent bool, disableAPI bool, inCli bool) (*Config, string, error) { patcher := yamlpatch.NewPatcher(configFile, ".local") - patcher.SetQuiet(quiet) + patcher.SetQuiet(inCli) + fcontent, err := patcher.MergedPatchContent() if err != nil { return nil, "", err } + configData := csstring.StrictExpand(string(fcontent), os.LookupEnv) cfg := Config{ FilePath: &configFile, @@ -60,27 +61,66 @@ func NewConfig(configFile string, disableAgent bool, disableAPI bool, quiet bool DisableAPI: disableAPI, } - err = yaml.UnmarshalStrict([]byte(configData), &cfg) + dec := yaml.NewDecoder(strings.NewReader(configData)) + dec.KnownFields(true) + + err = dec.Decode(&cfg) if err != nil { - // this is actually the "merged" yaml - return nil, "", fmt.Errorf("%s: %w", configFile, err) + if !errors.Is(err, io.EOF) { + // this is actually the "merged" yaml + return nil, "", fmt.Errorf("%s: %w", configFile, err) + } + } + + if cfg.Prometheus == nil { + cfg.Prometheus = &PrometheusCfg{} + } + + if cfg.Prometheus.ListenAddr == "" { + cfg.Prometheus.ListenAddr = "127.0.0.1" + log.Debugf("prometheus.listen_addr is empty, defaulting to %s", cfg.Prometheus.ListenAddr) + } + + if cfg.Prometheus.ListenPort == 0 { + cfg.Prometheus.ListenPort = 6060 + log.Debugf("prometheus.listen_port is empty or zero, defaulting to %d", cfg.Prometheus.ListenPort) } + + if err = cfg.loadCommon(); err != nil { + return nil, "", err + } + + if err = cfg.loadConfigurationPaths(); err != nil { + return nil, "", err + } + + if err = cfg.loadHub(); err != nil { + return nil, "", err + } + 
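	// The load* calls above run in a fixed order (common, then configuration
	// paths, then hub) because each stage fills in defaults that the next one
	// reads. A minimal usage sketch from a caller's point of view, illustration
	// only (the config path here is an assumption, not part of this patch):
	//
	//	cfg, merged, err := csconfig.NewConfig("/etc/crowdsec/config.yaml", false, false, false)
	//	if err != nil {
	//		log.Fatal(err)
	//	}
	//	_ = merged // the patched, env-expanded YAML; useful for troubleshooting
	//	fmt.Println(cfg.Cscli.PrometheusUrl) // filled in by loadCSCLI below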
+ if err = cfg.loadCSCLI(); err != nil { + return nil, "", err + } + + globalConfig = cfg + return &cfg, configData, nil } +func GetConfig() Config { + return globalConfig +} + func NewDefaultConfig() *Config { logLevel := log.InfoLevel commonCfg := CommonCfg{ Daemonize: false, - PidDir: "/tmp/", LogMedia: "stdout", - //LogDir unneeded - LogLevel: &logLevel, - WorkingDir: ".", + LogLevel: &logLevel, } prometheus := PrometheusCfg{ Enabled: true, - Level: "full", + Level: configuration.CFG_METRICS_FULL, } configPaths := ConfigurationPaths{ ConfigDir: DefaultConfigPath("."), @@ -118,7 +158,7 @@ func NewDefaultConfig() *Config { dbConfig := DatabaseCfg{ Type: "sqlite", DbPath: DefaultDataPath("crowdsec.db"), - MaxOpenConns: ptr.Of(DEFAULT_MAX_OPEN_CONNS), + MaxOpenConns: DEFAULT_MAX_OPEN_CONNS, } globalCfg := Config{ diff --git a/pkg/csconfig/config_paths.go b/pkg/csconfig/config_paths.go index 24ff454b78d..a8d39a664f3 100644 --- a/pkg/csconfig/config_paths.go +++ b/pkg/csconfig/config_paths.go @@ -1,6 +1,7 @@ package csconfig import ( + "errors" "fmt" "path/filepath" ) @@ -9,31 +10,36 @@ type ConfigurationPaths struct { ConfigDir string `yaml:"config_dir"` DataDir string `yaml:"data_dir,omitempty"` SimulationFilePath string `yaml:"simulation_path,omitempty"` - HubIndexFile string `yaml:"index_path,omitempty"` //path of the .index.json + HubIndexFile string `yaml:"index_path,omitempty"` // path of the .index.json HubDir string `yaml:"hub_dir,omitempty"` PluginDir string `yaml:"plugin_dir,omitempty"` NotificationDir string `yaml:"notification_dir,omitempty"` + PatternDir string `yaml:"pattern_dir,omitempty"` } -func (c *Config) LoadConfigurationPaths() error { +func (c *Config) loadConfigurationPaths() error { var err error if c.ConfigPaths == nil { - return fmt.Errorf("no configuration paths provided") + return errors.New("no configuration paths provided") } if c.ConfigPaths.DataDir == "" { - return fmt.Errorf("please provide a data directory with the 'data_dir' directive in the 'config_paths' section") + return errors.New("please provide a data directory with the 'data_dir' directive in the 'config_paths' section") } if c.ConfigPaths.HubDir == "" { - c.ConfigPaths.HubDir = filepath.Clean(c.ConfigPaths.ConfigDir + "/hub") + c.ConfigPaths.HubDir = filepath.Join(c.ConfigPaths.ConfigDir, "hub") } if c.ConfigPaths.HubIndexFile == "" { - c.ConfigPaths.HubIndexFile = filepath.Clean(c.ConfigPaths.HubDir + "/.index.json") + c.ConfigPaths.HubIndexFile = filepath.Join(c.ConfigPaths.HubDir, ".index.json") } - var configPathsCleanup = []*string{ + if c.ConfigPaths.PatternDir == "" { + c.ConfigPaths.PatternDir = filepath.Join(c.ConfigPaths.ConfigDir, "patterns") + } + + configPathsCleanup := []*string{ &c.ConfigPaths.HubDir, &c.ConfigPaths.HubIndexFile, &c.ConfigPaths.ConfigDir, @@ -41,6 +47,7 @@ func (c *Config) LoadConfigurationPaths() error { &c.ConfigPaths.SimulationFilePath, &c.ConfigPaths.PluginDir, &c.ConfigPaths.NotificationDir, + &c.ConfigPaths.PatternDir, } for _, k := range configPathsCleanup { if *k == "" { diff --git a/pkg/csconfig/config_test.go b/pkg/csconfig/config_test.go index 7a53d0e7257..b69954de178 100644 --- a/pkg/csconfig/config_test.go +++ b/pkg/csconfig/config_test.go @@ -5,42 +5,42 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestNormalLoad(t *testing.T) { - _, _, err := NewConfig("./tests/config.yaml", false, false, 
false) + _, _, err := NewConfig("./testdata/config.yaml", false, false, false) require.NoError(t, err) - _, _, err = NewConfig("./tests/xxx.yaml", false, false, false) - assert.EqualError(t, err, "while reading yaml file: open ./tests/xxx.yaml: "+cstest.FileNotFoundMessage) + _, _, err = NewConfig("./testdata/xxx.yaml", false, false, false) + require.EqualError(t, err, "while reading yaml file: open ./testdata/xxx.yaml: "+cstest.FileNotFoundMessage) - _, _, err = NewConfig("./tests/simulation.yaml", false, false, false) - assert.EqualError(t, err, "./tests/simulation.yaml: yaml: unmarshal errors:\n line 1: field simulation not found in type csconfig.Config") + _, _, err = NewConfig("./testdata/simulation.yaml", false, false, false) + require.EqualError(t, err, "./testdata/simulation.yaml: yaml: unmarshal errors:\n line 1: field simulation not found in type csconfig.Config") } func TestNewCrowdSecConfig(t *testing.T) { tests := []struct { - name string - expectedResult *Config + name string + expected *Config }{ { - name: "new configuration: basic", - expectedResult: &Config{}, + name: "new configuration: basic", + expected: &Config{}, }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { result := &Config{} - assert.Equal(t, tc.expectedResult, result) + assert.Equal(t, tc.expected, result) }) } } func TestDefaultConfig(t *testing.T) { x := NewDefaultConfig() - err := x.Dump() - require.NoError(t, err) + _, err := yaml.Marshal(x) + require.NoError(t, err, "failed to serialize config: %s", err) } diff --git a/pkg/csconfig/console.go b/pkg/csconfig/console.go index 5adf6ca3720..21ecbf3d736 100644 --- a/pkg/csconfig/console.go +++ b/pkg/csconfig/console.go @@ -5,11 +5,9 @@ import ( "os" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" - - "github.com/crowdsecurity/crowdsec/pkg/fflag" + "github.com/crowdsecurity/go-cs-lib/ptr" ) const ( @@ -21,6 +19,13 @@ const ( ) var CONSOLE_CONFIGS = []string{SEND_CUSTOM_SCENARIOS, SEND_MANUAL_SCENARIOS, SEND_TAINTED_SCENARIOS, SEND_CONTEXT, CONSOLE_MANAGEMENT} +var CONSOLE_CONFIGS_HELP = map[string]string{ + SEND_CUSTOM_SCENARIOS: "Forward alerts from custom scenarios to the console", + SEND_MANUAL_SCENARIOS: "Forward manual decisions to the console", + SEND_TAINTED_SCENARIOS: "Forward alerts from tainted scenarios to the console", + SEND_CONTEXT: "Forward context with alerts to the console", + CONSOLE_MANAGEMENT: "Receive decisions from console", +} var DefaultConsoleConfigFilePath = DefaultConfigPath("console.yaml") @@ -32,43 +37,83 @@ type ConsoleConfig struct { ShareContext *bool `yaml:"share_context"` } +func (c *ConsoleConfig) EnabledOptions() []string { + ret := []string{} + if c == nil { + return ret + } + + if c.ShareCustomScenarios != nil && *c.ShareCustomScenarios { + ret = append(ret, SEND_CUSTOM_SCENARIOS) + } + + if c.ShareTaintedScenarios != nil && *c.ShareTaintedScenarios { + ret = append(ret, SEND_TAINTED_SCENARIOS) + } + + if c.ShareManualDecisions != nil && *c.ShareManualDecisions { + ret = append(ret, SEND_MANUAL_SCENARIOS) + } + + if c.ConsoleManagement != nil && *c.ConsoleManagement { + ret = append(ret, CONSOLE_MANAGEMENT) + } + + if c.ShareContext != nil && *c.ShareContext { + ret = append(ret, SEND_CONTEXT) + } + + return ret +} + +func (c *ConsoleConfig) IsPAPIEnabled() bool { + if c == nil || c.ConsoleManagement == nil { + return false + } + + return *c.ConsoleManagement +} + func (c *LocalApiServerCfg) LoadConsoleConfig() error { 
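	// If console.yaml does not exist, the stat check below applies the
	// defaults: custom and tainted scenario sharing on; manual decisions,
	// console management and context sharing off. A short sketch of listing
	// the effective options after loading, illustration only (server is a
	// hypothetical *LocalApiServerCfg):
	//
	//	if err := server.LoadConsoleConfig(); err != nil {
	//		log.Fatal(err)
	//	}
	//	for _, opt := range server.ConsoleConfig.EnabledOptions() {
	//		fmt.Printf("%s: %s\n", opt, CONSOLE_CONFIGS_HELP[opt])
	//	}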
c.ConsoleConfig = &ConsoleConfig{} if _, err := os.Stat(c.ConsoleConfigPath); err != nil && os.IsNotExist(err) { log.Debugf("no console configuration to load") + c.ConsoleConfig.ShareCustomScenarios = ptr.Of(true) c.ConsoleConfig.ShareTaintedScenarios = ptr.Of(true) c.ConsoleConfig.ShareManualDecisions = ptr.Of(false) c.ConsoleConfig.ConsoleManagement = ptr.Of(false) c.ConsoleConfig.ShareContext = ptr.Of(false) + return nil } yamlFile, err := os.ReadFile(c.ConsoleConfigPath) if err != nil { - return fmt.Errorf("reading console config file '%s': %s", c.ConsoleConfigPath, err) + return fmt.Errorf("reading console config file '%s': %w", c.ConsoleConfigPath, err) } + err = yaml.Unmarshal(yamlFile, c.ConsoleConfig) if err != nil { - return fmt.Errorf("unmarshaling console config file '%s': %s", c.ConsoleConfigPath, err) + return fmt.Errorf("parsing console config file '%s': %w", c.ConsoleConfigPath, err) } if c.ConsoleConfig.ShareCustomScenarios == nil { log.Debugf("no share_custom scenarios found, setting to true") c.ConsoleConfig.ShareCustomScenarios = ptr.Of(true) } + if c.ConsoleConfig.ShareTaintedScenarios == nil { log.Debugf("no share_tainted scenarios found, setting to true") c.ConsoleConfig.ShareTaintedScenarios = ptr.Of(true) } + if c.ConsoleConfig.ShareManualDecisions == nil { log.Debugf("no share_manual scenarios found, setting to false") c.ConsoleConfig.ShareManualDecisions = ptr.Of(false) } - if !fflag.PapiClient.IsEnabled() { - c.ConsoleConfig.ConsoleManagement = ptr.Of(false) - } else if c.ConsoleConfig.ConsoleManagement == nil { + if c.ConsoleConfig.ConsoleManagement == nil { log.Debugf("no console_management found, setting to false") c.ConsoleConfig.ConsoleManagement = ptr.Of(false) } @@ -82,23 +127,3 @@ func (c *LocalApiServerCfg) LoadConsoleConfig() error { return nil } - -func (c *LocalApiServerCfg) DumpConsoleConfig() error { - var out []byte - var err error - - if out, err = yaml.Marshal(c.ConsoleConfig); err != nil { - return fmt.Errorf("while marshaling ConsoleConfig (for %s): %w", c.ConsoleConfigPath, err) - } - if c.ConsoleConfigPath == "" { - c.ConsoleConfigPath = DefaultConsoleConfigFilePath - log.Debugf("Empty console_path, defaulting to %s", c.ConsoleConfigPath) - - } - - if err := os.WriteFile(c.ConsoleConfigPath, out, 0600); err != nil { - return fmt.Errorf("while dumping console config to %s: %w", c.ConsoleConfigPath, err) - } - - return nil -} diff --git a/pkg/csconfig/crowdsec_service.go b/pkg/csconfig/crowdsec_service.go index 2642603cf04..cf796805dee 100644 --- a/pkg/csconfig/crowdsec_service.go +++ b/pkg/csconfig/crowdsec_service.go @@ -6,9 +6,9 @@ import ( "path/filepath" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/ptr" ) // CrowdsecServiceCfg contains the location of parsers/scenarios/... 
and acquisition files @@ -23,15 +23,10 @@ type CrowdsecServiceCfg struct { BucketsRoutinesCount int `yaml:"buckets_routines"` OutputRoutinesCount int `yaml:"output_routines"` SimulationConfig *SimulationConfig `yaml:"-"` - LintOnly bool `yaml:"-"` // if set to true, exit after loading configs BucketStateFile string `yaml:"state_input_file,omitempty"` // if we need to unserialize buckets at start BucketStateDumpDir string `yaml:"state_output_dir,omitempty"` // if we need to unserialize buckets on shutdown BucketsGCEnabled bool `yaml:"-"` // we need to garbage collect buckets when in forensic mode - HubDir string `yaml:"-"` - DataDir string `yaml:"-"` - ConfigDir string `yaml:"-"` - HubIndexFile string `yaml:"-"` SimulationFilePath string `yaml:"-"` ContextToSend map[string][]string `yaml:"-"` } @@ -101,11 +96,6 @@ func (c *Config) LoadCrowdsec() error { return fmt.Errorf("load error (simulation): %w", err) } - c.Crowdsec.ConfigDir = c.ConfigPaths.ConfigDir - c.Crowdsec.DataDir = c.ConfigPaths.DataDir - c.Crowdsec.HubDir = c.ConfigPaths.HubDir - c.Crowdsec.HubIndexFile = c.ConfigPaths.HubIndexFile - if c.Crowdsec.ParserRoutinesCount <= 0 { c.Crowdsec.ParserRoutinesCount = 1 } @@ -118,8 +108,9 @@ func (c *Config) LoadCrowdsec() error { c.Crowdsec.OutputRoutinesCount = 1 } - var crowdsecCleanup = []*string{ + crowdsecCleanup := []*string{ &c.Crowdsec.AcquisitionFilePath, + &c.Crowdsec.ConsoleContextPath, } for _, k := range crowdsecCleanup { @@ -141,54 +132,25 @@ func (c *Config) LoadCrowdsec() error { c.Crowdsec.AcquisitionFiles[i] = f } - if err := c.LoadAPIClient(); err != nil { - return fmt.Errorf("loading api client: %s", err) - } - - if err := c.LoadHub(); err != nil { - return fmt.Errorf("while loading hub: %w", err) - } - - c.Crowdsec.ContextToSend = make(map[string][]string, 0) - fallback := false - if c.Crowdsec.ConsoleContextPath == "" { - // fallback to default config file - c.Crowdsec.ConsoleContextPath = filepath.Join(c.Crowdsec.ConfigDir, "console", "context.yaml") - fallback = true - } - - f, err := filepath.Abs(c.Crowdsec.ConsoleContextPath) - if err != nil { - return fmt.Errorf("fail to get absolute path of %s: %s", c.Crowdsec.ConsoleContextPath, err) - } - - c.Crowdsec.ConsoleContextPath = f - yamlFile, err := os.ReadFile(c.Crowdsec.ConsoleContextPath) - if err != nil { - if fallback { - log.Debugf("Default context config file doesn't exist, will not use it") - } else { - return fmt.Errorf("failed to open context file: %s", err) - } - } else { - err = yaml.Unmarshal(yamlFile, c.Crowdsec.ContextToSend) - if err != nil { - return fmt.Errorf("unmarshaling labels console config file '%s': %s", c.Crowdsec.ConsoleContextPath, err) - } + if err = c.LoadAPIClient(); err != nil { + return fmt.Errorf("loading api client: %w", err) } return nil } func (c *CrowdsecServiceCfg) DumpContextConfigFile() error { - var out []byte - var err error + // XXX: MakeDirs + out, err := yaml.Marshal(c.ContextToSend) + if err != nil { + return fmt.Errorf("while serializing ConsoleConfig (for %s): %w", c.ConsoleContextPath, err) + } - if out, err = yaml.Marshal(c.ContextToSend); err != nil { - return fmt.Errorf("while marshaling ConsoleConfig (for %s): %w", c.ConsoleContextPath, err) + if err = os.MkdirAll(filepath.Dir(c.ConsoleContextPath), 0o700); err != nil { + return fmt.Errorf("while creating directories for %s: %w", c.ConsoleContextPath, err) } - if err := os.WriteFile(c.ConsoleContextPath, out, 0600); err != nil { + if err := os.WriteFile(c.ConsoleContextPath, out, 0o600); err != nil { return 
fmt.Errorf("while dumping console config to %s: %w", c.ConsoleContextPath, err) } diff --git a/pkg/csconfig/crowdsec_service_test.go b/pkg/csconfig/crowdsec_service_test.go index 5423d1a45f4..7570b63011e 100644 --- a/pkg/csconfig/crowdsec_service_test.go +++ b/pkg/csconfig/crowdsec_service_test.go @@ -1,85 +1,69 @@ package csconfig import ( - "fmt" "path/filepath" "testing" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" - "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" ) func TestLoadCrowdsec(t *testing.T) { - acquisFullPath, err := filepath.Abs("./tests/acquis.yaml") - require.NoError(t, err) - - acquisInDirFullPath, err := filepath.Abs("./tests/acquis/acquis.yaml") - require.NoError(t, err) - - acquisDirFullPath, err := filepath.Abs("./tests/acquis") - require.NoError(t, err) - - hubFullPath, err := filepath.Abs("./hub") - require.NoError(t, err) - - dataFullPath, err := filepath.Abs("./data") + acquisFullPath, err := filepath.Abs("./testdata/acquis.yaml") require.NoError(t, err) - configDirFullPath, err := filepath.Abs("./tests") + acquisInDirFullPath, err := filepath.Abs("./testdata/acquis/acquis.yaml") require.NoError(t, err) - hubIndexFileFullPath, err := filepath.Abs("./hub/.index.json") + acquisDirFullPath, err := filepath.Abs("./testdata/acquis") require.NoError(t, err) - contextFileFullPath, err := filepath.Abs("./tests/context.yaml") + contextFileFullPath, err := filepath.Abs("./testdata/context.yaml") require.NoError(t, err) tests := []struct { - name string - input *Config - expectedResult *CrowdsecServiceCfg - expectedErr string + name string + input *Config + expected *CrowdsecServiceCfg + expectedErr string }{ { name: "basic valid configuration", input: &Config{ ConfigPaths: &ConfigurationPaths{ - ConfigDir: "./tests", + ConfigDir: "./testdata", DataDir: "./data", HubDir: "./hub", }, API: &APICfg{ Client: &LocalApiClientCfg{ - CredentialsFilePath: "./tests/lapi-secrets.yaml", + CredentialsFilePath: "./testdata/lapi-secrets.yaml", }, }, Crowdsec: &CrowdsecServiceCfg{ - AcquisitionFilePath: "./tests/acquis.yaml", - SimulationFilePath: "./tests/simulation.yaml", - ConsoleContextPath: "./tests/context.yaml", + AcquisitionFilePath: "./testdata/acquis.yaml", + SimulationFilePath: "./testdata/simulation.yaml", + ConsoleContextPath: "./testdata/context.yaml", ConsoleContextValueLength: 2500, }, }, - expectedResult: &CrowdsecServiceCfg{ + expected: &CrowdsecServiceCfg{ Enable: ptr.Of(true), AcquisitionDirPath: "", ConsoleContextPath: contextFileFullPath, AcquisitionFilePath: acquisFullPath, - ConfigDir: configDirFullPath, - DataDir: dataFullPath, - HubDir: hubFullPath, - HubIndexFile: hubIndexFileFullPath, BucketsRoutinesCount: 1, ParserRoutinesCount: 1, OutputRoutinesCount: 1, ConsoleContextValueLength: 2500, AcquisitionFiles: []string{acquisFullPath}, - SimulationFilePath: "./tests/simulation.yaml", - ContextToSend: map[string][]string{ - "source_ip": {"evt.Parsed.source_ip"}, - }, + SimulationFilePath: "./testdata/simulation.yaml", + // context is loaded in pkg/alertcontext + // ContextToSend: map[string][]string{ + // "source_ip": {"evt.Parsed.source_ip"}, + // }, SimulationConfig: &SimulationConfig{ Simulation: ptr.Of(false), }, @@ -89,40 +73,37 @@ func TestLoadCrowdsec(t *testing.T) { name: "basic valid configuration with acquisition dir", input: &Config{ ConfigPaths: &ConfigurationPaths{ - ConfigDir: "./tests", + ConfigDir: "./testdata", DataDir: 
"./data", HubDir: "./hub", }, API: &APICfg{ Client: &LocalApiClientCfg{ - CredentialsFilePath: "./tests/lapi-secrets.yaml", + CredentialsFilePath: "./testdata/lapi-secrets.yaml", }, }, Crowdsec: &CrowdsecServiceCfg{ - AcquisitionFilePath: "./tests/acquis.yaml", - AcquisitionDirPath: "./tests/acquis/", - SimulationFilePath: "./tests/simulation.yaml", - ConsoleContextPath: "./tests/context.yaml", + AcquisitionFilePath: "./testdata/acquis.yaml", + AcquisitionDirPath: "./testdata/acquis/", + SimulationFilePath: "./testdata/simulation.yaml", + ConsoleContextPath: "./testdata/context.yaml", }, }, - expectedResult: &CrowdsecServiceCfg{ + expected: &CrowdsecServiceCfg{ Enable: ptr.Of(true), AcquisitionDirPath: acquisDirFullPath, AcquisitionFilePath: acquisFullPath, ConsoleContextPath: contextFileFullPath, - ConfigDir: configDirFullPath, - HubIndexFile: hubIndexFileFullPath, - DataDir: dataFullPath, - HubDir: hubFullPath, BucketsRoutinesCount: 1, ParserRoutinesCount: 1, OutputRoutinesCount: 1, ConsoleContextValueLength: 0, AcquisitionFiles: []string{acquisFullPath, acquisInDirFullPath}, - ContextToSend: map[string][]string{ - "source_ip": {"evt.Parsed.source_ip"}, - }, - SimulationFilePath: "./tests/simulation.yaml", + // context is loaded in pkg/alertcontext + // ContextToSend: map[string][]string{ + // "source_ip": {"evt.Parsed.source_ip"}, + // }, + SimulationFilePath: "./testdata/simulation.yaml", SimulationConfig: &SimulationConfig{ Simulation: ptr.Of(false), }, @@ -132,28 +113,24 @@ func TestLoadCrowdsec(t *testing.T) { name: "no acquisition file and dir", input: &Config{ ConfigPaths: &ConfigurationPaths{ - ConfigDir: "./tests", + ConfigDir: "./testdata", DataDir: "./data", HubDir: "./hub", }, API: &APICfg{ Client: &LocalApiClientCfg{ - CredentialsFilePath: "./tests/lapi-secrets.yaml", + CredentialsFilePath: "./testdata/lapi-secrets.yaml", }, }, Crowdsec: &CrowdsecServiceCfg{ - ConsoleContextPath: contextFileFullPath, + ConsoleContextPath: "./testdata/context.yaml", ConsoleContextValueLength: 10, }, }, - expectedResult: &CrowdsecServiceCfg{ + expected: &CrowdsecServiceCfg{ Enable: ptr.Of(true), AcquisitionDirPath: "", AcquisitionFilePath: "", - ConfigDir: configDirFullPath, - HubIndexFile: hubIndexFileFullPath, - DataDir: dataFullPath, - HubDir: hubFullPath, ConsoleContextPath: contextFileFullPath, BucketsRoutinesCount: 1, ParserRoutinesCount: 1, @@ -161,9 +138,10 @@ func TestLoadCrowdsec(t *testing.T) { ConsoleContextValueLength: 10, AcquisitionFiles: []string{}, SimulationFilePath: "", - ContextToSend: map[string][]string{ - "source_ip": {"evt.Parsed.source_ip"}, - }, + // context is loaded in pkg/alertcontext + // ContextToSend: map[string][]string{ + // "source_ip": {"evt.Parsed.source_ip"}, + // }, SimulationConfig: &SimulationConfig{ Simulation: ptr.Of(false), }, @@ -173,18 +151,18 @@ func TestLoadCrowdsec(t *testing.T) { name: "non existing acquisition file", input: &Config{ ConfigPaths: &ConfigurationPaths{ - ConfigDir: "./tests", + ConfigDir: "./testdata", DataDir: "./data", HubDir: "./hub", }, API: &APICfg{ Client: &LocalApiClientCfg{ - CredentialsFilePath: "./tests/lapi-secrets.yaml", + CredentialsFilePath: "./testdata/lapi-secrets.yaml", }, }, Crowdsec: &CrowdsecServiceCfg{ ConsoleContextPath: "", - AcquisitionFilePath: "./tests/acquis_not_exist.yaml", + AcquisitionFilePath: "./testdata/acquis_not_exist.yaml", }, }, expectedErr: cstest.FileNotFoundMessage, @@ -193,26 +171,25 @@ func TestLoadCrowdsec(t *testing.T) { name: "agent disabled", input: &Config{ ConfigPaths: 
&ConfigurationPaths{ - ConfigDir: "./tests", + ConfigDir: "./testdata", DataDir: "./data", HubDir: "./hub", }, }, - expectedResult: nil, + expected: nil, }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - fmt.Printf("TEST '%s'\n", tc.name) err := tc.input.LoadCrowdsec() cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { return } - require.Equal(t, tc.expectedResult, tc.input.Crowdsec) + require.Equal(t, tc.expected, tc.input.Crowdsec) }) } } diff --git a/pkg/csconfig/cscli.go b/pkg/csconfig/cscli.go index 6b0bf5ae4ad..9393156c0ed 100644 --- a/pkg/csconfig/cscli.go +++ b/pkg/csconfig/cscli.go @@ -1,31 +1,36 @@ package csconfig +import ( + "fmt" +) + /*cscli specific config, such as hub directory*/ type CscliCfg struct { - Output string `yaml:"output,omitempty"` - Color string `yaml:"color,omitempty"` - HubBranch string `yaml:"hub_branch"` - SimulationConfig *SimulationConfig `yaml:"-"` - DbConfig *DatabaseCfg `yaml:"-"` - HubDir string `yaml:"-"` - DataDir string `yaml:"-"` - ConfigDir string `yaml:"-"` - HubIndexFile string `yaml:"-"` - SimulationFilePath string `yaml:"-"` - PrometheusUrl string `yaml:"prometheus_uri"` + Output string `yaml:"output,omitempty"` + Color string `yaml:"color,omitempty"` + HubBranch string `yaml:"hub_branch"` + HubURLTemplate string `yaml:"__hub_url_template__,omitempty"` + SimulationConfig *SimulationConfig `yaml:"-"` + DbConfig *DatabaseCfg `yaml:"-"` + + SimulationFilePath string `yaml:"-"` + PrometheusUrl string `yaml:"prometheus_uri"` } -func (c *Config) LoadCSCLI() error { +const defaultHubURLTemplate = "https://cdn-hub.crowdsec.net/crowdsecurity/%s/%s" + +func (c *Config) loadCSCLI() error { if c.Cscli == nil { c.Cscli = &CscliCfg{} } - if err := c.LoadConfigurationPaths(); err != nil { - return err + + if c.Prometheus.ListenAddr != "" && c.Prometheus.ListenPort != 0 { + c.Cscli.PrometheusUrl = fmt.Sprintf("http://%s:%d/metrics", c.Prometheus.ListenAddr, c.Prometheus.ListenPort) + } + + if c.Cscli.HubURLTemplate == "" { + c.Cscli.HubURLTemplate = defaultHubURLTemplate } - c.Cscli.ConfigDir = c.ConfigPaths.ConfigDir - c.Cscli.DataDir = c.ConfigPaths.DataDir - c.Cscli.HubDir = c.ConfigPaths.HubDir - c.Cscli.HubIndexFile = c.ConfigPaths.HubIndexFile return nil } diff --git a/pkg/csconfig/cscli_test.go b/pkg/csconfig/cscli_test.go index 1f432f6e354..a58fdd6f857 100644 --- a/pkg/csconfig/cscli_test.go +++ b/pkg/csconfig/cscli_test.go @@ -1,84 +1,52 @@ package csconfig import ( - "fmt" - "path/filepath" - "strings" "testing" "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestLoadCSCLI(t *testing.T) { - hubFullPath, err := filepath.Abs("./hub") - if err != nil { - t.Fatal(err) - } - - dataFullPath, err := filepath.Abs("./data") - if err != nil { - t.Fatal(err) - } - - configDirFullPath, err := filepath.Abs("./tests") - if err != nil { - t.Fatal(err) - } - - hubIndexFileFullPath, err := filepath.Abs("./hub/.index.json") - if err != nil { - t.Fatal(err) - } - tests := []struct { - name string - Input *Config - expectedResult *CscliCfg - err string + name string + input *Config + expected *CscliCfg + expectedErr string }{ { name: "basic valid configuration", - Input: &Config{ + input: &Config{ ConfigPaths: &ConfigurationPaths{ - ConfigDir: "./tests", + ConfigDir: "./testdata", DataDir: "./data", HubDir: "./hub", HubIndexFile: "./hub/.index.json", }, + Prometheus: &PrometheusCfg{ + Enabled: true, + Level: "full", + ListenAddr: "127.0.0.1", + ListenPort: 6060, + 
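	// With these values, loadCSCLI should derive
	// PrometheusUrl = "http://127.0.0.1:6060/metrics", matching the
	// expected configuration below.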
}, }, - expectedResult: &CscliCfg{ - ConfigDir: configDirFullPath, - DataDir: dataFullPath, - HubDir: hubFullPath, - HubIndexFile: hubIndexFileFullPath, + expected: &CscliCfg{ + PrometheusUrl: "http://127.0.0.1:6060/metrics", + HubURLTemplate: defaultHubURLTemplate, }, }, - { - name: "no configuration path", - Input: &Config{}, - expectedResult: &CscliCfg{}, - }, } - for idx, test := range tests { - err := test.Input.LoadCSCLI() - if err == nil && test.err != "" { - fmt.Printf("TEST '%s': NOK\n", test.name) - t.Fatalf("%d/%d expected error, didn't get it", idx, len(tests)) - } else if test.err != "" { - if !strings.HasPrefix(fmt.Sprintf("%s", err), test.err) { - fmt.Printf("TEST '%s': NOK\n", test.name) - t.Fatalf("%d/%d expected '%s' got '%s'", idx, len(tests), - test.err, - fmt.Sprintf("%s", err)) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.input.loadCSCLI() + cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { + return } - } - isOk := assert.Equal(t, test.expectedResult, test.Input.Cscli) - if !isOk { - t.Fatalf("TEST '%s': NOK", test.name) - } else { - fmt.Printf("TEST '%s': OK\n", test.name) - } + assert.Equal(t, tc.expected, tc.input.Cscli) + }) } } diff --git a/pkg/csconfig/database.go b/pkg/csconfig/database.go index 0f886682486..4ca582cf576 100644 --- a/pkg/csconfig/database.go +++ b/pkg/csconfig/database.go @@ -1,30 +1,41 @@ package csconfig import ( + "errors" "fmt" + "path/filepath" "time" "entgo.io/ent/dialect" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/pkg/types" ) -var DEFAULT_MAX_OPEN_CONNS = 100 +const ( + DEFAULT_MAX_OPEN_CONNS = 100 + defaultDecisionBulkSize = 1000 + // we need an upper bound due to the sqlite limit of 32k variables in a query + // we have 15 variables per decision, so 32768/15 = 2184.5333 + maxDecisionBulkSize = 2000 +) type DatabaseCfg struct { - User string `yaml:"user"` - Password string `yaml:"password"` - DbName string `yaml:"db_name"` - Sslmode string `yaml:"sslmode"` - Host string `yaml:"host"` - Port int `yaml:"port"` - DbPath string `yaml:"db_path"` - Type string `yaml:"type"` - Flush *FlushDBCfg `yaml:"flush"` - LogLevel *log.Level `yaml:"log_level"` - MaxOpenConns *int `yaml:"max_open_conns,omitempty"` - UseWal *bool `yaml:"use_wal,omitempty"` + User string `yaml:"user"` + Password string `yaml:"password"` + DbName string `yaml:"db_name"` + Sslmode string `yaml:"sslmode"` + Host string `yaml:"host"` + Port int `yaml:"port"` + DbPath string `yaml:"db_path"` + Type string `yaml:"type"` + Flush *FlushDBCfg `yaml:"flush"` + LogLevel *log.Level `yaml:"log_level"` + MaxOpenConns int `yaml:"max_open_conns,omitempty"` + UseWal *bool `yaml:"use_wal,omitempty"` + DecisionBulkSize int `yaml:"decision_bulk_size,omitempty"` } type AuthGCCfg struct { @@ -37,15 +48,17 @@ type AuthGCCfg struct { } type FlushDBCfg struct { - MaxItems *int `yaml:"max_items,omitempty"` - MaxAge *string `yaml:"max_age,omitempty"` - BouncersGC *AuthGCCfg `yaml:"bouncers_autodelete,omitempty"` - AgentsGC *AuthGCCfg `yaml:"agents_autodelete,omitempty"` + MaxItems *int `yaml:"max_items,omitempty"` + // We could unmarshal as time.Duration, but alert filters right now are a map of strings + MaxAge *string `yaml:"max_age,omitempty"` + BouncersGC *AuthGCCfg `yaml:"bouncers_autodelete,omitempty"` + AgentsGC *AuthGCCfg `yaml:"agents_autodelete,omitempty"` + MetricsMaxAge *time.Duration 
`yaml:"metrics_max_age,omitempty"` } -func (c *Config) LoadDBConfig() error { +func (c *Config) LoadDBConfig(inCli bool) error { if c.DbConfig == nil { - return fmt.Errorf("no database configuration provided") + return errors.New("no database configuration provided") } if c.Cscli != nil { @@ -56,15 +69,48 @@ func (c *Config) LoadDBConfig() error { c.API.Server.DbConfig = c.DbConfig } - if c.DbConfig.MaxOpenConns == nil { - c.DbConfig.MaxOpenConns = ptr.Of(DEFAULT_MAX_OPEN_CONNS) + if c.DbConfig.MaxOpenConns == 0 { + c.DbConfig.MaxOpenConns = DEFAULT_MAX_OPEN_CONNS } - if c.DbConfig.Type == "sqlite" { + if !inCli && c.DbConfig.Type == "sqlite" { if c.DbConfig.UseWal == nil { - log.Warning("You are using sqlite without WAL, this can have a performance impact. If you do not store the database in a network share, set db_config.use_wal to true. Set explicitly to false to disable this warning.") + dbDir := filepath.Dir(c.DbConfig.DbPath) + isNetwork, fsType, err := types.IsNetworkFS(dbDir) + switch { + case err != nil: + log.Warnf("unable to determine if database is on network filesystem: %s", err) + log.Warning( + "You are using sqlite without WAL, this can have a performance impact. " + + "If you do not store the database in a network share, set db_config.use_wal to true. " + + "Set explicitly to false to disable this warning.") + case isNetwork: + log.Debugf("database is on network filesystem (%s), setting useWal to false", fsType) + c.DbConfig.UseWal = ptr.Of(false) + default: + log.Debugf("database is on local filesystem (%s), setting useWal to true", fsType) + c.DbConfig.UseWal = ptr.Of(true) + } + } else if *c.DbConfig.UseWal { + dbDir := filepath.Dir(c.DbConfig.DbPath) + isNetwork, fsType, err := types.IsNetworkFS(dbDir) + switch { + case err != nil: + log.Warnf("unable to determine if database is on network filesystem: %s", err) + case isNetwork: + log.Warnf("database seems to be stored on a network share (%s), but useWal is set to true. 
Proceed at your own risk.", fsType) + } } + } + if c.DbConfig.DecisionBulkSize == 0 { + log.Tracef("No decision_bulk_size value provided, using default value of %d", defaultDecisionBulkSize) + c.DbConfig.DecisionBulkSize = defaultDecisionBulkSize + } + + if c.DbConfig.DecisionBulkSize > maxDecisionBulkSize { + log.Warningf("decision_bulk_size too high (%d), setting to the maximum value of %d", c.DbConfig.DecisionBulkSize, maxDecisionBulkSize) + c.DbConfig.DecisionBulkSize = maxDecisionBulkSize } return nil @@ -72,6 +118,7 @@ func (c *Config) LoadDBConfig() error { func (d *DatabaseCfg) ConnectionString() string { connString := "" + switch d.Type { case "sqlite": var sqliteConnectionStringParameters string @@ -80,6 +127,7 @@ func (d *DatabaseCfg) ConnectionString() string { } else { sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1" } + connString = fmt.Sprintf("file:%s?%s", d.DbPath, sqliteConnectionStringParameters) case "mysql": if d.isSocketConfig() { @@ -87,6 +135,10 @@ func (d *DatabaseCfg) ConnectionString() string { } else { connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", d.User, d.Password, d.Host, d.Port, d.DbName) } + + if d.Sslmode != "" { + connString = fmt.Sprintf("%s&tls=%s", connString, d.Sslmode) + } case "postgres", "postgresql", "pgx": if d.isSocketConfig() { connString = fmt.Sprintf("host=%s user=%s dbname=%s password=%s", d.DbPath, d.User, d.DbName, d.Password) @@ -94,6 +146,7 @@ func (d *DatabaseCfg) ConnectionString() string { connString = fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", d.Host, d.Port, d.User, d.DbName, d.Password, d.Sslmode) } } + return connString } @@ -107,8 +160,10 @@ func (d *DatabaseCfg) ConnectionDialect() (string, string, error) { if d.Type != "pgx" { log.Debugf("database type '%s' is deprecated, switching to 'pgx' instead", d.Type) } + return "pgx", dialect.Postgres, nil } + return "", "", fmt.Errorf("unknown database type '%s'", d.Type) } diff --git a/pkg/csconfig/database_test.go b/pkg/csconfig/database_test.go index d33c54424ae..4a1ef807f97 100644 --- a/pkg/csconfig/database_test.go +++ b/pkg/csconfig/database_test.go @@ -1,66 +1,60 @@ package csconfig import ( - "fmt" - "strings" "testing" "github.com/stretchr/testify/assert" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" ) func TestLoadDBConfig(t *testing.T) { tests := []struct { - name string - Input *Config - expectedResult *DatabaseCfg - err string + name string + input *Config + expected *DatabaseCfg + expectedErr string }{ { name: "basic valid configuration", - Input: &Config{ + input: &Config{ DbConfig: &DatabaseCfg{ Type: "sqlite", - DbPath: "./tests/test.db", - MaxOpenConns: ptr.Of(10), + DbPath: "./testdata/test.db", + MaxOpenConns: 10, }, Cscli: &CscliCfg{}, API: &APICfg{ Server: &LocalApiServerCfg{}, }, }, - expectedResult: &DatabaseCfg{ - Type: "sqlite", - DbPath: "./tests/test.db", - MaxOpenConns: ptr.Of(10), + expected: &DatabaseCfg{ + Type: "sqlite", + DbPath: "./testdata/test.db", + MaxOpenConns: 10, + UseWal: ptr.Of(true), + DecisionBulkSize: defaultDecisionBulkSize, }, }, { - name: "no configuration path", - Input: &Config{}, - expectedResult: nil, + name: "no configuration path", + input: &Config{}, + expected: nil, + expectedErr: "no database configuration provided", }, } - for idx, test := range tests { - err := test.Input.LoadDBConfig() - if err == nil && test.err != "" { - fmt.Printf("TEST '%s': NOK\n", test.name) - 
t.Fatalf("%d/%d expected error, didn't get it", idx, len(tests)) - } else if test.err != "" { - if !strings.HasPrefix(fmt.Sprintf("%s", err), test.err) { - fmt.Printf("TEST '%s': NOK\n", test.name) - t.Fatalf("%d/%d expected '%s' got '%s'", idx, len(tests), - test.err, - fmt.Sprintf("%s", err)) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.input.LoadDBConfig(false) + cstest.RequireErrorContains(t, err, tc.expectedErr) + + if tc.expectedErr != "" { + return } - } - isOk := assert.Equal(t, test.expectedResult, test.Input.DbConfig) - if !isOk { - t.Fatalf("TEST '%s': NOK", test.name) - } else { - fmt.Printf("TEST '%s': OK\n", test.name) - } + + assert.Equal(t, tc.expected, tc.input.DbConfig) + }) } } diff --git a/pkg/csconfig/fflag.go b/pkg/csconfig/fflag.go index e9110649d52..c86686889eb 100644 --- a/pkg/csconfig/fflag.go +++ b/pkg/csconfig/fflag.go @@ -10,22 +10,23 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/fflag" ) - // LoadFeatureFlagsEnv parses the environment variables to enable feature flags. func LoadFeatureFlagsEnv(logger *log.Logger) error { - if err := fflag.Crowdsec.SetFromEnv(logger); err != nil { - return err - } - return nil + return fflag.Crowdsec.SetFromEnv(logger) } - -// LoadFeatureFlags parses feature.yaml to enable feature flags. +// GetFeatureFilePath returns the path to the feature.yaml file. // The file is in the same directory as config.yaml, which is provided // as the fist parameter. This can be different than ConfigPaths.ConfigDir -func LoadFeatureFlagsFile(configPath string, logger *log.Logger) error { +// because we have not read config.yaml yet so we don't know the value of ConfigDir. +func GetFeatureFilePath(configPath string) string { dir := filepath.Dir(configPath) - featurePath := filepath.Join(dir, "feature.yaml") + return filepath.Join(dir, "feature.yaml") +} + +// LoadFeatureFlags parses feature.yaml to enable feature flags. +func LoadFeatureFlagsFile(configPath string, logger *log.Logger) error { + featurePath := GetFeatureFilePath(configPath) if err := fflag.Crowdsec.SetFromYamlFile(featurePath, logger); err != nil { return fmt.Errorf("file %s: %s", featurePath, err) @@ -33,7 +34,6 @@ func LoadFeatureFlagsFile(configPath string, logger *log.Logger) error { return nil } - // ListFeatureFlags returns a list of the enabled feature flags. func ListFeatureFlags() string { enabledFeatures := fflag.Crowdsec.GetEnabledFeatures() diff --git a/pkg/csconfig/hub.go b/pkg/csconfig/hub.go index eb3bd7c429e..ca3750e5812 100644 --- a/pkg/csconfig/hub.go +++ b/pkg/csconfig/hub.go @@ -1,23 +1,19 @@ package csconfig -/*cscli specific config, such as hub directory*/ -type Hub struct { - HubDir string `yaml:"-"` - ConfigDir string `yaml:"-"` - HubIndexFile string `yaml:"-"` - DataDir string `yaml:"-"` +// LocalHubCfg holds the configuration for a local hub: where to download etc. 
+type LocalHubCfg struct { + HubIndexFile string // Path to the local index file + HubDir string // Where the hub items are downloaded + InstallDir string // Where to install items + InstallDataDir string // Where to install data } -func (c *Config) LoadHub() error { - if err := c.LoadConfigurationPaths(); err != nil { - return err - } - - c.Hub = &Hub{ - HubIndexFile: c.ConfigPaths.HubIndexFile, - ConfigDir: c.ConfigPaths.ConfigDir, - HubDir: c.ConfigPaths.HubDir, - DataDir: c.ConfigPaths.DataDir, +func (c *Config) loadHub() error { + c.Hub = &LocalHubCfg{ + HubIndexFile: c.ConfigPaths.HubIndexFile, + HubDir: c.ConfigPaths.HubDir, + InstallDir: c.ConfigPaths.ConfigDir, + InstallDataDir: c.ConfigPaths.DataDir, } return nil diff --git a/pkg/csconfig/hub_test.go b/pkg/csconfig/hub_test.go index 136790d5f4a..49d010a04f4 100644 --- a/pkg/csconfig/hub_test.go +++ b/pkg/csconfig/hub_test.go @@ -1,94 +1,48 @@ package csconfig import ( - "fmt" - "path/filepath" - "strings" "testing" "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestLoadHub(t *testing.T) { - hubFullPath, err := filepath.Abs("./hub") - if err != nil { - t.Fatal(err) - } - - dataFullPath, err := filepath.Abs("./data") - if err != nil { - t.Fatal(err) - } - - configDirFullPath, err := filepath.Abs("./tests") - if err != nil { - t.Fatal(err) - } - - hubIndexFileFullPath, err := filepath.Abs("./hub/.index.json") - if err != nil { - t.Fatal(err) - } - tests := []struct { - name string - Input *Config - expectedResult *Hub - err string + name string + input *Config + expected *LocalHubCfg + expectedErr string }{ { name: "basic valid configuration", - Input: &Config{ + input: &Config{ ConfigPaths: &ConfigurationPaths{ - ConfigDir: "./tests", + ConfigDir: "./testdata", DataDir: "./data", HubDir: "./hub", HubIndexFile: "./hub/.index.json", }, }, - expectedResult: &Hub{ - ConfigDir: configDirFullPath, - DataDir: dataFullPath, - HubDir: hubFullPath, - HubIndexFile: hubIndexFileFullPath, + expected: &LocalHubCfg{ + HubDir: "./hub", + HubIndexFile: "./hub/.index.json", + InstallDir: "./testdata", + InstallDataDir: "./data", }, }, - { - name: "no data dir", - Input: &Config{ - ConfigPaths: &ConfigurationPaths{ - ConfigDir: "./tests", - HubDir: "./hub", - HubIndexFile: "./hub/.index.json", - }, - }, - expectedResult: nil, - }, - { - name: "no configuration path", - Input: &Config{}, - expectedResult: nil, - }, } - for idx, test := range tests { - err := test.Input.LoadHub() - if err == nil && test.err != "" { - fmt.Printf("TEST '%s': NOK\n", test.name) - t.Fatalf("%d/%d expected error, didn't get it", idx, len(tests)) - } else if test.err != "" { - if !strings.HasPrefix(fmt.Sprintf("%s", err), test.err) { - fmt.Printf("TEST '%s': NOK\n", test.name) - t.Fatalf("%d/%d expected '%s' got '%s'", idx, len(tests), - test.err, - fmt.Sprintf("%s", err)) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.input.loadHub() + cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { + return } - } - isOk := assert.Equal(t, test.expectedResult, test.Input.Hub) - if !isOk { - t.Fatalf("TEST '%s': NOK", test.name) - } else { - fmt.Printf("TEST '%s': OK\n", test.name) - } + + assert.Equal(t, tc.expected, tc.input.Hub) + }) } } diff --git a/pkg/csconfig/profiles.go b/pkg/csconfig/profiles.go index ec70fb459ae..6fbb8ed8b21 100644 --- a/pkg/csconfig/profiles.go +++ b/pkg/csconfig/profiles.go @@ -6,10 +6,11 @@ import ( "fmt" "io" - 
"github.com/crowdsecurity/go-cs-lib/pkg/yamlpatch" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/yamlpatch" "github.com/crowdsecurity/crowdsec/pkg/models" - "gopkg.in/yaml.v2" ) // var OnErrorDefault = OnErrorIgnore @@ -22,44 +23,50 @@ import ( type ProfileCfg struct { Name string `yaml:"name,omitempty"` Debug *bool `yaml:"debug,omitempty"` - Filters []string `yaml:"filters,omitempty"` //A list of OR'ed expressions. the models.Alert object + Filters []string `yaml:"filters,omitempty"` // A list of OR'ed expressions. the models.Alert object Decisions []models.Decision `yaml:"decisions,omitempty"` DurationExpr string `yaml:"duration_expr,omitempty"` - OnSuccess string `yaml:"on_success,omitempty"` //continue or break - OnFailure string `yaml:"on_failure,omitempty"` //continue or break - OnError string `yaml:"on_error,omitempty"` //continue, break, error, report, apply, ignore + OnSuccess string `yaml:"on_success,omitempty"` // continue or break + OnFailure string `yaml:"on_failure,omitempty"` // continue or break + OnError string `yaml:"on_error,omitempty"` // continue, break, error, report, apply, ignore Notifications []string `yaml:"notifications,omitempty"` } func (c *LocalApiServerCfg) LoadProfiles() error { if c.ProfilesPath == "" { - return fmt.Errorf("empty profiles path") + return errors.New("empty profiles path") } patcher := yamlpatch.NewPatcher(c.ProfilesPath, ".local") + fcontent, err := patcher.PrependedPatchContent() if err != nil { return err } + reader := bytes.NewReader(fcontent) - //process the yaml dec := yaml.NewDecoder(reader) - dec.SetStrict(true) + dec.KnownFields(true) + for { t := ProfileCfg{} + err = dec.Decode(&t) if err != nil { if errors.Is(err, io.EOF) { break } + return fmt.Errorf("while decoding %s: %w", c.ProfilesPath, err) } + c.Profiles = append(c.Profiles, &t) } if len(c.Profiles) == 0 { - return fmt.Errorf("zero profiles loaded for LAPI") + return errors.New("zero profiles loaded for LAPI") } + return nil } diff --git a/pkg/csconfig/prometheus.go b/pkg/csconfig/prometheus.go index eea768ab798..9b80fe39838 100644 --- a/pkg/csconfig/prometheus.go +++ b/pkg/csconfig/prometheus.go @@ -1,19 +1,8 @@ package csconfig -import "fmt" - type PrometheusCfg struct { Enabled bool `yaml:"enabled"` Level string `yaml:"level"` //aggregated|full ListenAddr string `yaml:"listen_addr"` ListenPort int `yaml:"listen_port"` } - -func (c *Config) LoadPrometheus() error { - if c.Cscli != nil && c.Cscli.PrometheusUrl == "" && c.Prometheus != nil { - if c.Prometheus.ListenAddr != "" && c.Prometheus.ListenPort != 0 { - c.Cscli.PrometheusUrl = fmt.Sprintf("http://%s:%d", c.Prometheus.ListenAddr, c.Prometheus.ListenPort) - } - } - return nil -} diff --git a/pkg/csconfig/prometheus_test.go b/pkg/csconfig/prometheus_test.go deleted file mode 100644 index 3df9c298b94..00000000000 --- a/pkg/csconfig/prometheus_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package csconfig - -import ( - "testing" - - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" - - "github.com/stretchr/testify/require" -) - -func TestLoadPrometheus(t *testing.T) { - tests := []struct { - name string - Input *Config - expectedURL string - expectedErr string - }{ - { - name: "basic valid configuration", - Input: &Config{ - Prometheus: &PrometheusCfg{ - Enabled: true, - Level: "full", - ListenAddr: "127.0.0.1", - ListenPort: 6060, - }, - Cscli: &CscliCfg{}, - }, - expectedURL: "http://127.0.0.1:6060", - }, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - err := 
tc.Input.LoadPrometheus() - cstest.RequireErrorContains(t, err, tc.expectedErr) - - require.Equal(t, tc.expectedURL, tc.Input.Cscli.PrometheusUrl) - }) - } -} diff --git a/pkg/csconfig/simulation.go b/pkg/csconfig/simulation.go index f291a4e1651..c9041df464a 100644 --- a/pkg/csconfig/simulation.go +++ b/pkg/csconfig/simulation.go @@ -1,11 +1,15 @@ package csconfig import ( + "bytes" + "errors" "fmt" + "io" "path/filepath" - "github.com/crowdsecurity/go-cs-lib/pkg/yamlpatch" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/yamlpatch" ) type SimulationConfig struct { @@ -19,42 +23,50 @@ func (s *SimulationConfig) IsSimulated(scenario string) bool { if s.Simulation != nil && *s.Simulation { simulated = true } + for _, excluded := range s.Exclusions { if excluded == scenario { - simulated = !simulated - break + return !simulated } } + return simulated } func (c *Config) LoadSimulation() error { - - if err := c.LoadConfigurationPaths(); err != nil { - return err - } - simCfg := SimulationConfig{} + if c.ConfigPaths.SimulationFilePath == "" { - c.ConfigPaths.SimulationFilePath = filepath.Clean(c.ConfigPaths.ConfigDir + "/simulation.yaml") + c.ConfigPaths.SimulationFilePath = filepath.Join(c.ConfigPaths.ConfigDir, "simulation.yaml") } patcher := yamlpatch.NewPatcher(c.ConfigPaths.SimulationFilePath, ".local") + rcfg, err := patcher.MergedPatchContent() if err != nil { return err } - if err := yaml.UnmarshalStrict(rcfg, &simCfg); err != nil { - return fmt.Errorf("while unmarshaling simulation file '%s' : %s", c.ConfigPaths.SimulationFilePath, err) + + dec := yaml.NewDecoder(bytes.NewReader(rcfg)) + dec.KnownFields(true) + + if err := dec.Decode(&simCfg); err != nil { + if !errors.Is(err, io.EOF) { + return fmt.Errorf("while parsing simulation file '%s': %w", c.ConfigPaths.SimulationFilePath, err) + } } + if simCfg.Simulation == nil { simCfg.Simulation = new(bool) } + if c.Crowdsec != nil { c.Crowdsec.SimulationConfig = &simCfg } + if c.Cscli != nil { c.Cscli.SimulationConfig = &simCfg } + return nil } diff --git a/pkg/csconfig/simulation_test.go b/pkg/csconfig/simulation_test.go index 8b202599345..a1e5f0a5b02 100644 --- a/pkg/csconfig/simulation_test.go +++ b/pkg/csconfig/simulation_test.go @@ -2,93 +2,85 @@ package csconfig import ( "fmt" - "path/filepath" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestSimulationLoading(t *testing.T) { - testXXFullPath, err := filepath.Abs("./tests/xxx.yaml") - require.NoError(t, err) - - badYamlFullPath, err := filepath.Abs("./tests/config.yaml") - require.NoError(t, err) - tests := []struct { - name string - Input *Config - expectedResult *SimulationConfig - expectedErr string + name string + input *Config + expected *SimulationConfig + expectedErr string }{ { name: "basic valid simulation", - Input: &Config{ + input: &Config{ ConfigPaths: &ConfigurationPaths{ - SimulationFilePath: "./tests/simulation.yaml", + SimulationFilePath: "./testdata/simulation.yaml", DataDir: "./data", }, Crowdsec: &CrowdsecServiceCfg{}, Cscli: &CscliCfg{}, }, - expectedResult: &SimulationConfig{Simulation: new(bool)}, + expected: &SimulationConfig{Simulation: new(bool)}, }, { name: "basic nil config", - Input: &Config{ + input: &Config{ ConfigPaths: &ConfigurationPaths{ SimulationFilePath: "", DataDir: "./data", }, Crowdsec: &CrowdsecServiceCfg{}, }, - expectedErr: "simulation.yaml: 
"+cstest.FileNotFoundMessage, + expectedErr: "simulation.yaml: " + cstest.FileNotFoundMessage, }, { name: "basic bad file name", - Input: &Config{ + input: &Config{ ConfigPaths: &ConfigurationPaths{ - SimulationFilePath: "./tests/xxx.yaml", + SimulationFilePath: "./testdata/xxx.yaml", DataDir: "./data", }, Crowdsec: &CrowdsecServiceCfg{}, }, - expectedErr: fmt.Sprintf("while reading yaml file: open %s: %s", testXXFullPath, cstest.FileNotFoundMessage), + expectedErr: fmt.Sprintf("while reading yaml file: open ./testdata/xxx.yaml: %s", cstest.FileNotFoundMessage), }, { name: "basic bad file content", - Input: &Config{ + input: &Config{ ConfigPaths: &ConfigurationPaths{ - SimulationFilePath: "./tests/config.yaml", + SimulationFilePath: "./testdata/config.yaml", DataDir: "./data", }, Crowdsec: &CrowdsecServiceCfg{}, }, - expectedErr: fmt.Sprintf("while unmarshaling simulation file '%s' : yaml: unmarshal errors", badYamlFullPath), + expectedErr: "while parsing simulation file './testdata/config.yaml': yaml: unmarshal errors", }, { name: "basic bad file content", - Input: &Config{ + input: &Config{ ConfigPaths: &ConfigurationPaths{ - SimulationFilePath: "./tests/config.yaml", + SimulationFilePath: "./testdata/config.yaml", DataDir: "./data", }, Crowdsec: &CrowdsecServiceCfg{}, }, - expectedErr: fmt.Sprintf("while unmarshaling simulation file '%s' : yaml: unmarshal errors", badYamlFullPath), + expectedErr: "while parsing simulation file './testdata/config.yaml': yaml: unmarshal errors", }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - err := tc.Input.LoadSimulation() + err := tc.input.LoadSimulation() cstest.RequireErrorContains(t, err, tc.expectedErr) - assert.Equal(t, tc.expectedResult, tc.Input.Crowdsec.SimulationConfig) + assert.Equal(t, tc.expected, tc.input.Crowdsec.SimulationConfig) }) } } @@ -109,32 +101,31 @@ func TestIsSimulated(t *testing.T) { name string SimulationConfig *SimulationConfig Input string - expectedResult bool + expected bool }{ { name: "No simulation except (in exclusion)", SimulationConfig: simCfgOff, Input: "test", - expectedResult: true, + expected: true, }, { name: "All simulation (not in exclusion)", SimulationConfig: simCfgOn, Input: "toto", - expectedResult: true, + expected: true, }, { name: "All simulation (in exclusion)", SimulationConfig: simCfgOn, Input: "test", - expectedResult: false, + expected: false, }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - IsSimulated := tc.SimulationConfig.IsSimulated(tc.Input) - require.Equal(t, tc.expectedResult, IsSimulated) + isSimulated := tc.SimulationConfig.IsSimulated(tc.Input) + require.Equal(t, tc.expected, isSimulated) }) } } diff --git a/pkg/csconfig/tests/acquis.yaml b/pkg/csconfig/testdata/acquis.yaml similarity index 100% rename from pkg/csconfig/tests/acquis.yaml rename to pkg/csconfig/testdata/acquis.yaml diff --git a/pkg/csconfig/tests/acquis/acquis.yaml b/pkg/csconfig/testdata/acquis/acquis.yaml similarity index 100% rename from pkg/csconfig/tests/acquis/acquis.yaml rename to pkg/csconfig/testdata/acquis/acquis.yaml diff --git a/pkg/csconfig/tests/bad_lapi-secrets.yaml b/pkg/csconfig/testdata/bad_lapi-secrets.yaml similarity index 100% rename from pkg/csconfig/tests/bad_lapi-secrets.yaml rename to pkg/csconfig/testdata/bad_lapi-secrets.yaml diff --git a/pkg/csconfig/tests/bad_online-api-secrets.yaml b/pkg/csconfig/testdata/bad_online-api-secrets.yaml similarity index 100% rename from pkg/csconfig/tests/bad_online-api-secrets.yaml rename to 
pkg/csconfig/testdata/bad_online-api-secrets.yaml diff --git a/pkg/csconfig/tests/config.yaml b/pkg/csconfig/testdata/config.yaml similarity index 54% rename from pkg/csconfig/tests/config.yaml rename to pkg/csconfig/testdata/config.yaml index 3148659e321..17975b10501 100644 --- a/pkg/csconfig/tests/config.yaml +++ b/pkg/csconfig/testdata/config.yaml @@ -2,12 +2,11 @@ common: daemonize: false log_media: stdout log_level: info - working_dir: . prometheus: enabled: true level: full crowdsec_service: - acquisition_path: ./tests/acquis.yaml + acquisition_path: ./testdata/acquis.yaml parser_routines: 1 cscli: output: human @@ -21,17 +20,17 @@ db_config: type: sqlite api: client: - credentials_path: ./tests/lapi-secrets.yaml + credentials_path: ./testdata/lapi-secrets.yaml server: - profiles_path: ./tests/profiles.yaml + profiles_path: ./testdata/profiles.yaml listen_uri: 127.0.0.1:8080 tls: null online_client: - credentials_path: ./tests/online-api-secrets.yaml + credentials_path: ./testdata/online-api-secrets.yaml config_paths: - config_dir: ./tests + config_dir: ./testdata data_dir: . - simulation_path: ./tests/simulation.yaml - index_path: ./tests/hub/.index.json - hub_dir: ./tests/hub + simulation_path: ./testdata/simulation.yaml + index_path: ./testdata/hub/.index.json + hub_dir: ./testdata/hub diff --git a/pkg/csconfig/tests/context.yaml b/pkg/csconfig/testdata/context.yaml similarity index 100% rename from pkg/csconfig/tests/context.yaml rename to pkg/csconfig/testdata/context.yaml diff --git a/pkg/csconfig/tests/lapi-secrets.yaml b/pkg/csconfig/testdata/lapi-secrets.yaml similarity index 100% rename from pkg/csconfig/tests/lapi-secrets.yaml rename to pkg/csconfig/testdata/lapi-secrets.yaml diff --git a/pkg/csconfig/tests/online-api-secrets.yaml b/pkg/csconfig/testdata/online-api-secrets.yaml similarity index 100% rename from pkg/csconfig/tests/online-api-secrets.yaml rename to pkg/csconfig/testdata/online-api-secrets.yaml diff --git a/pkg/csconfig/tests/profiles.yaml b/pkg/csconfig/testdata/profiles.yaml similarity index 100% rename from pkg/csconfig/tests/profiles.yaml rename to pkg/csconfig/testdata/profiles.yaml diff --git a/pkg/csconfig/tests/simulation.yaml b/pkg/csconfig/testdata/simulation.yaml similarity index 100% rename from pkg/csconfig/tests/simulation.yaml rename to pkg/csconfig/testdata/simulation.yaml diff --git a/pkg/csconfig/tls.go b/pkg/csconfig/tls.go new file mode 100644 index 00000000000..897112a757f --- /dev/null +++ b/pkg/csconfig/tls.go @@ -0,0 +1,87 @@ +package csconfig + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "time" + + log "github.com/sirupsen/logrus" +) + +type TLSCfg struct { + CertFilePath string `yaml:"cert_file"` + KeyFilePath string `yaml:"key_file"` + ClientVerification string `yaml:"client_verification,omitempty"` + ServerName string `yaml:"server_name"` + CACertPath string `yaml:"ca_cert_path"` + AllowedAgentsOU []string `yaml:"agents_allowed_ou"` + AllowedBouncersOU []string `yaml:"bouncers_allowed_ou"` + CRLPath string `yaml:"crl_path"` + CacheExpiration *time.Duration `yaml:"cache_expiration,omitempty"` +} + +func (t *TLSCfg) GetAuthType() (tls.ClientAuthType, error) { + if t.ClientVerification == "" { + // sounds like a sane default: verify client cert if given, but don't make it mandatory + return tls.VerifyClientCertIfGiven, nil + } + + switch t.ClientVerification { + case "NoClientCert": + return tls.NoClientCert, nil + case "RequestClientCert": + log.Warn("RequestClientCert is insecure, please use VerifyClientCertIfGiven or 
RequireAndVerifyClientCert instead") + return tls.RequestClientCert, nil + case "RequireAnyClientCert": + log.Warn("RequireAnyClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead") + return tls.RequireAnyClientCert, nil + case "VerifyClientCertIfGiven": + return tls.VerifyClientCertIfGiven, nil + case "RequireAndVerifyClientCert": + return tls.RequireAndVerifyClientCert, nil + default: + return 0, fmt.Errorf("unknown TLS client_verification value: %s", t.ClientVerification) + } +} + +func (t *TLSCfg) GetTLSConfig() (*tls.Config, error) { + if t == nil { + return &tls.Config{}, nil + } + + clientAuthType, err := t.GetAuthType() + if err != nil { + return nil, err + } + + caCertPool, err := x509.SystemCertPool() + if err != nil { + log.Warnf("Error loading system CA certificates: %s", err) + } + + if caCertPool == nil { + caCertPool = x509.NewCertPool() + } + + // the > condition below is a weird way to say "if a client certificate is required" + // see https://pkg.go.dev/crypto/tls#ClientAuthType + if clientAuthType > tls.RequestClientCert && t.CACertPath != "" { + log.Infof("(tls) Client Auth Type set to %s", clientAuthType.String()) + + caCert, err := os.ReadFile(t.CACertPath) + if err != nil { + return nil, fmt.Errorf("while opening cert file: %w", err) + } + + caCertPool.AppendCertsFromPEM(caCert) + } + + return &tls.Config{ + ServerName: t.ServerName, //should it be removed ? + ClientAuth: clientAuthType, + ClientCAs: caCertPool, + MinVersion: tls.VersionTLS12, // TLS versions below 1.2 are considered insecure - see https://www.rfc-editor.org/rfc/rfc7525.txt for details + }, nil +} diff --git a/pkg/csplugin/broker.go b/pkg/csplugin/broker.go index 208ce933691..e996fa9b68c 100644 --- a/pkg/csplugin/broker.go +++ b/pkg/csplugin/broker.go @@ -19,8 +19,8 @@ import ( "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/csstring" - "github.com/crowdsecurity/go-cs-lib/pkg/slicetools" + "github.com/crowdsecurity/go-cs-lib/csstring" + "github.com/crowdsecurity/go-cs-lib/slicetools" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -45,7 +45,7 @@ type PluginBroker struct { pluginConfigByName map[string]PluginConfig pluginMap map[string]plugin.Plugin notificationConfigsByPluginType map[string][][]byte // "slack" -> []{config1, config2} - notificationPluginByName map[string]Notifier + notificationPluginByName map[string]protobufs.NotifierServer watcher PluginWatcher pluginKillMethods []func() pluginProcConfig *csconfig.PluginCfg @@ -72,10 +72,10 @@ type ProfileAlert struct { Alert *models.Alert } -func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error { +func (pb *PluginBroker) Init(ctx context.Context, pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error { pb.PluginChannel = make(chan ProfileAlert) pb.notificationConfigsByPluginType = make(map[string][][]byte) - pb.notificationPluginByName = make(map[string]Notifier) + pb.notificationPluginByName = make(map[string]protobufs.NotifierServer) pb.pluginMap = make(map[string]plugin.Plugin) pb.pluginConfigByName = make(map[string]PluginConfig) pb.alertsByPluginName = make(map[string][]*models.Alert) @@ -85,7 +85,7 @@ func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*cs if err := pb.loadConfig(configPaths.NotificationDir); err != nil { return 
fmt.Errorf("while loading plugin config: %w", err) } - if err := pb.loadPlugins(configPaths.PluginDir); err != nil { + if err := pb.loadPlugins(ctx, configPaths.PluginDir); err != nil { return fmt.Errorf("while loading plugin: %w", err) } pb.watcher = PluginWatcher{} @@ -103,14 +103,13 @@ func (pb *PluginBroker) Kill() { func (pb *PluginBroker) Run(pluginTomb *tomb.Tomb) { //we get signaled via the channel when notifications need to be delivered to plugin (via the watcher) pb.watcher.Start(&tomb.Tomb{}) -loop: for { select { case profileAlert := <-pb.PluginChannel: pb.addProfileAlert(profileAlert) case pluginName := <-pb.watcher.PluginEvents: - // this can be ran in goroutine, but then locks will be needed + // this can be run in goroutine, but then locks will be needed pluginMutex.Lock() log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName) tmpAlerts := pb.alertsByPluginName[pluginName] @@ -137,9 +136,9 @@ loop: case <-pb.watcher.tomb.Dead(): log.Info("killing all plugins") pb.Kill() - break loop + return case pluginName := <-pb.watcher.PluginEvents: - // this can be ran in goroutine, but then locks will be needed + // this can be run in goroutine, but then locks will be needed pluginMutex.Lock() log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName) tmpAlerts := pb.alertsByPluginName[pluginName] @@ -192,7 +191,7 @@ func (pb *PluginBroker) loadConfig(path string) error { return err } for _, pluginConfig := range pluginConfigs { - setRequiredFields(&pluginConfig) + SetRequiredFields(&pluginConfig) if _, ok := pb.pluginConfigByName[pluginConfig.Name]; ok { log.Warningf("notification '%s' is defined multiple times", pluginConfig.Name) } @@ -206,7 +205,7 @@ func (pb *PluginBroker) loadConfig(path string) error { return err } -// checks whether every notification in profile has it's own config file +// checks whether every notification in profile has its own config file func (pb *PluginBroker) verifyPluginConfigsWithProfile() error { for _, profileCfg := range pb.profileConfigs { for _, pluginName := range profileCfg.Notifications { @@ -219,7 +218,7 @@ func (pb *PluginBroker) verifyPluginConfigsWithProfile() error { return nil } -// check whether each plugin in profile has it's own binary +// check whether each plugin in profile has its own binary func (pb *PluginBroker) verifyPluginBinaryWithProfile() error { for _, profileCfg := range pb.profileConfigs { for _, pluginName := range profileCfg.Notifications { @@ -231,7 +230,7 @@ func (pb *PluginBroker) verifyPluginBinaryWithProfile() error { return nil } -func (pb *PluginBroker) loadPlugins(path string) error { +func (pb *PluginBroker) loadPlugins(ctx context.Context, path string) error { binaryPaths, err := listFilesAtPath(path) if err != nil { return err @@ -266,7 +265,7 @@ func (pb *PluginBroker) loadPlugins(path string) error { return err } data = []byte(csstring.StrictExpand(string(data), os.LookupEnv)) - _, err = pluginClient.Configure(context.Background(), &protobufs.Config{Config: data}) + _, err = pluginClient.Configure(ctx, &protobufs.Config{Config: data}) if err != nil { return fmt.Errorf("while configuring %s: %w", pc.Name, err) } @@ -277,7 +276,7 @@ func (pb *PluginBroker) loadPlugins(path string) error { return pb.verifyPluginBinaryWithProfile() } -func (pb *PluginBroker) loadNotificationPlugin(name string, binaryPath string) (Notifier, error) { +func (pb *PluginBroker) loadNotificationPlugin(name string, binaryPath string) 
(protobufs.NotifierServer, error) { handshake, err := getHandshake() if err != nil { @@ -314,7 +313,7 @@ func (pb *PluginBroker) loadNotificationPlugin(name string, binaryPath string) ( return nil, err } pb.pluginKillMethods = append(pb.pluginKillMethods, c.Kill) - return raw.(Notifier), nil + return raw.(protobufs.NotifierServer), nil } func (pb *PluginBroker) pushNotificationsToPlugin(pluginName string, alerts []*models.Alert) error { @@ -376,7 +375,7 @@ func ParsePluginConfigFile(path string) ([]PluginConfig, error) { return parsedConfigs, nil } -func setRequiredFields(pluginCfg *PluginConfig) { +func SetRequiredFields(pluginCfg *PluginConfig) { if pluginCfg.MaxRetry == 0 { pluginCfg.MaxRetry++ } diff --git a/pkg/csplugin/broker_suite_test.go b/pkg/csplugin/broker_suite_test.go index 6e3e51407ea..1210c67058a 100644 --- a/pkg/csplugin/broker_suite_test.go +++ b/pkg/csplugin/broker_suite_test.go @@ -1,10 +1,11 @@ package csplugin import ( + "context" "io" "os" "os/exec" - "path" + "path/filepath" "runtime" "testing" @@ -43,13 +44,13 @@ func (s *PluginSuite) SetupSuite() { s.buildDir, err = os.MkdirTemp("", "cs_plugin_test_build") require.NoError(t, err) - s.builtBinary = path.Join(s.buildDir, "notification-dummy") + s.builtBinary = filepath.Join(s.buildDir, "notification-dummy") if runtime.GOOS == "windows" { s.builtBinary += ".exe" } - cmd := exec.Command("go", "build", "-o", s.builtBinary, "../../plugins/notifications/dummy/") + cmd := exec.Command("go", "build", "-o", s.builtBinary, "../../cmd/notification-dummy/") err = cmd.Run() require.NoError(t, err, "while building dummy plugin") } @@ -96,20 +97,21 @@ func (s *PluginSuite) TearDownTest() { func (s *PluginSuite) SetupSubTest() { var err error + t := s.T() s.runDir, err = os.MkdirTemp("", "cs_plugin_test") require.NoError(t, err) - s.pluginDir = path.Join(s.runDir, "bin") - err = os.MkdirAll(path.Join(s.runDir, "bin"), 0o755) + s.pluginDir = filepath.Join(s.runDir, "bin") + err = os.MkdirAll(filepath.Join(s.runDir, "bin"), 0o755) require.NoError(t, err, "while creating bin dir") - s.notifDir = path.Join(s.runDir, "config") + s.notifDir = filepath.Join(s.runDir, "config") err = os.MkdirAll(s.notifDir, 0o755) require.NoError(t, err, "while creating config dir") - s.pluginBinary = path.Join(s.pluginDir, "notification-dummy") + s.pluginBinary = filepath.Join(s.pluginDir, "notification-dummy") if runtime.GOOS == "windows" { s.pluginBinary += ".exe" @@ -120,13 +122,14 @@ func (s *PluginSuite) SetupSubTest() { err = os.Chmod(s.pluginBinary, 0o744) require.NoError(t, err, "chmod 0744 %s", s.pluginBinary) - s.pluginConfig = path.Join(s.notifDir, "dummy.yaml") + s.pluginConfig = filepath.Join(s.notifDir, "dummy.yaml") err = copyFile("testdata/dummy.yaml", s.pluginConfig) require.NoError(t, err, "while copying plugin config") } func (s *PluginSuite) TearDownSubTest() { t := s.T() + if s.pluginBroker != nil { s.pluginBroker.Kill() s.pluginBroker = nil @@ -140,19 +143,24 @@ func (s *PluginSuite) TearDownSubTest() { os.Remove("./out") } -func (s *PluginSuite) InitBroker(procCfg *csconfig.PluginCfg) (*PluginBroker, error) { +func (s *PluginSuite) InitBroker(ctx context.Context, procCfg *csconfig.PluginCfg) (*PluginBroker, error) { pb := PluginBroker{} + if procCfg == nil { procCfg = &csconfig.PluginCfg{} } + profiles := csconfig.NewDefaultConfig().API.Server.Profiles profiles = append(profiles, &csconfig.ProfileCfg{ Notifications: []string{"dummy_default"}, }) - err := pb.Init(procCfg, profiles, &csconfig.ConfigurationPaths{ + + err := 
pb.Init(ctx, procCfg, profiles, &csconfig.ConfigurationPaths{ PluginDir: s.pluginDir, NotificationDir: s.notifDir, }) + s.pluginBroker = &pb + return s.pluginBroker, err } diff --git a/pkg/csplugin/broker_test.go b/pkg/csplugin/broker_test.go index 991b89ed20c..ae5a615b489 100644 --- a/pkg/csplugin/broker_test.go +++ b/pkg/csplugin/broker_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -14,9 +15,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -31,13 +32,14 @@ func (s *PluginSuite) permissionSetter(perm os.FileMode) func(*testing.T) { func (s *PluginSuite) readconfig() PluginConfig { var config PluginConfig + t := s.T() orig, err := os.ReadFile(s.pluginConfig) require.NoError(t, err, "unable to read config file %s", s.pluginConfig) err = yaml.Unmarshal(orig, &config) - require.NoError(t, err, "unable to unmarshal config file") + require.NoError(t, err, "unable to parse config file") return config } @@ -45,13 +47,14 @@ func (s *PluginSuite) readconfig() PluginConfig { func (s *PluginSuite) writeconfig(config PluginConfig) { t := s.T() data, err := yaml.Marshal(&config) - require.NoError(t, err, "unable to marshal config file") + require.NoError(t, err, "unable to serialize config file") - err = os.WriteFile(s.pluginConfig, data, 0644) + err = os.WriteFile(s.pluginConfig, data, 0o644) require.NoError(t, err, "unable to write config file %s", s.pluginConfig) } func (s *PluginSuite) TestBrokerInit() { + ctx := context.Background() tests := []struct { name string action func(*testing.T) @@ -111,7 +114,7 @@ func (s *PluginSuite) TestBrokerInit() { }, { name: "Invalid user and group", - expectedErr: "unknown user toto1234", + expectedErr: "toto1234", procCfg: csconfig.PluginCfg{ User: "toto1234", Group: "toto1234", @@ -119,7 +122,7 @@ func (s *PluginSuite) TestBrokerInit() { }, { name: "Valid user and invalid group", - expectedErr: "unknown group toto1234", + expectedErr: "toto1234", procCfg: csconfig.PluginCfg{ User: "nobody", Group: "toto1234", @@ -128,32 +131,36 @@ func (s *PluginSuite) TestBrokerInit() { } for _, tc := range tests { - tc := tc s.Run(tc.name, func() { t := s.T() if tc.action != nil { tc.action(t) } - _, err := s.InitBroker(&tc.procCfg) + + _, err := s.InitBroker(ctx, &tc.procCfg) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } func (s *PluginSuite) TestBrokerNoThreshold() { + ctx := context.Background() + var alerts []models.Alert + DefaultEmptyTicker = 50 * time.Millisecond t := s.T() - pb, err := s.InitBroker(nil) - assert.NoError(t, err) + pb, err := s.InitBroker(ctx, nil) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) // send one item, it should be processed right now pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} + time.Sleep(200 * time.Millisecond) // we expect one now @@ -170,6 +177,7 @@ func (s *PluginSuite) TestBrokerNoThreshold() { // and another one log.Printf("second send") pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} + time.Sleep(200 * time.Millisecond) // we expect one again, as we cleaned the file @@ -178,11 +186,13 @@ func (s *PluginSuite) TestBrokerNoThreshold() { err = json.Unmarshal(content, &alerts) log.Printf("content-> %s", 
content) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) } func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { + ctx := context.Background() + // test grouping by "time" DefaultEmptyTicker = 50 * time.Millisecond @@ -194,8 +204,8 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { cfg.GroupWait = 1 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) - assert.NoError(t, err) + pb, err := s.InitBroker(ctx, nil) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -204,21 +214,23 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} + time.Sleep(500 * time.Millisecond) // because of group threshold, we shouldn't have data yet assert.NoFileExists(t, "./out") time.Sleep(1 * time.Second) // after 1 seconds, we should have data content, err := os.ReadFile("./out") - assert.NoError(t, err) + require.NoError(t, err) var alerts []models.Alert err = json.Unmarshal(content, &alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 3) } func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() @@ -229,8 +241,8 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { cfg.GroupWait = 4 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) - assert.NoError(t, err) + pb, err := s.InitBroker(ctx, nil) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -239,11 +251,13 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} + time.Sleep(100 * time.Millisecond) // because of group threshold, we shouldn't have data yet assert.NoFileExists(t, "./out") pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} + time.Sleep(100 * time.Millisecond) // and now we should @@ -252,11 +266,12 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { var alerts []models.Alert err = json.Unmarshal(content, &alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 4) } func (s *PluginSuite) TestBrokerRunGroupThreshold() { + ctx := context.Background() // test grouping by "size" DefaultEmptyTicker = 50 * time.Millisecond @@ -267,8 +282,8 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { cfg.GroupThreshold = 4 s.writeconfig(cfg) - pb, err := s.InitBroker(nil) - assert.NoError(t, err) + pb, err := s.InitBroker(ctx, nil) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -277,6 +292,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} + time.Sleep(time.Second) // because of group threshold, we shouldn't have data yet @@ -284,6 +300,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} 
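	// (this is the fourth alert overall: it reaches the GroupThreshold of 4 set
	// above, flushing the first batch; the two sends below accumulate into the
	// second, smaller notification asserted at the end of the test)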
pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} + time.Sleep(time.Second) // and now we should @@ -297,11 +314,11 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { // two notifications, one with 4 alerts, one with 2 alerts err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 4) err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 2) err = decoder.Decode(&alerts) @@ -309,6 +326,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { } func (s *PluginSuite) TestBrokerRunTimeThreshold() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() @@ -318,14 +336,15 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { cfg.GroupWait = 1 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) - assert.NoError(t, err) + pb, err := s.InitBroker(ctx, nil) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) // send data pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} + time.Sleep(200 * time.Millisecond) // we shouldn't have data yet @@ -338,17 +357,18 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { var alerts []models.Alert err = json.Unmarshal(content, &alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) } func (s *PluginSuite) TestBrokerRunSimple() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() - pb, err := s.InitBroker(nil) - assert.NoError(t, err) + pb, err := s.InitBroker(ctx, nil) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -372,11 +392,11 @@ func (s *PluginSuite) TestBrokerRunSimple() { // two notifications, one alert each err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) diff --git a/pkg/csplugin/broker_win_test.go b/pkg/csplugin/broker_win_test.go index 01262b1fd0c..570f23e5015 100644 --- a/pkg/csplugin/broker_win_test.go +++ b/pkg/csplugin/broker_win_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -14,7 +15,7 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -26,6 +27,7 @@ not if it will actually reject plugins with invalid permissions */ func (s *PluginSuite) TestBrokerInit() { + ctx := context.Background() tests := []struct { name string action func(*testing.T) @@ -54,23 +56,23 @@ func (s *PluginSuite) TestBrokerInit() { } for _, tc := range tests { - tc := tc s.Run(tc.name, func() { t := s.T() if tc.action != nil { tc.action(t) } - _, err := s.InitBroker(&tc.procCfg) + _, err := s.InitBroker(ctx, &tc.procCfg) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } func (s *PluginSuite) TestBrokerRun() { + ctx := context.Background() t := s.T() - pb, err := s.InitBroker(nil) - assert.NoError(t, err) + pb, err := s.InitBroker(ctx, nil) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -94,11 +96,11 @@ func (s *PluginSuite) TestBrokerRun() { // two notifications, one alert each err = decoder.Decode(&alerts) - assert.NoError(t, err) + 
require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) diff --git a/pkg/csplugin/hclog_adapter.go b/pkg/csplugin/hclog_adapter.go index 9550e4b4539..44a22463709 100644 --- a/pkg/csplugin/hclog_adapter.go +++ b/pkg/csplugin/hclog_adapter.go @@ -221,14 +221,13 @@ func merge(dst map[string]interface{}, k, v interface{}) { func safeString(str fmt.Stringer) (s string) { defer func() { if panicVal := recover(); panicVal != nil { - if v := reflect.ValueOf(str); v.Kind() == reflect.Ptr && v.IsNil() { - s = "NULL" - } else { + if v := reflect.ValueOf(str); v.Kind() != reflect.Ptr || !v.IsNil() { panic(panicVal) } + s = "NULL" } }() s = str.String() - return + return //nolint:revive // bare return for the defer } diff --git a/pkg/csplugin/helpers.go b/pkg/csplugin/helpers.go index 297742e8dd9..915f17e5dd3 100644 --- a/pkg/csplugin/helpers.go +++ b/pkg/csplugin/helpers.go @@ -1,9 +1,12 @@ package csplugin import ( + "html" "os" "text/template" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -20,8 +23,15 @@ var helpers = template.FuncMap{ } return metaValues }, - "CrowdsecCTI": exprhelpers.CrowdsecCTI, - "Hostname": os.Hostname, + "CrowdsecCTI": func(x string) any { + ret, err := exprhelpers.CrowdsecCTI(x) + if err != nil { + log.Warningf("error while calling CrowdsecCTI : %s", err) + } + return ret + }, + "Hostname": os.Hostname, + "HTMLEscape": html.EscapeString, } func funcMap() template.FuncMap { diff --git a/pkg/csplugin/listfiles_test.go b/pkg/csplugin/listfiles_test.go index 09102ef0da1..c476d7a4e4a 100644 --- a/pkg/csplugin/listfiles_test.go +++ b/pkg/csplugin/listfiles_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestListFilesAtPath(t *testing.T) { @@ -21,7 +21,7 @@ func TestListFilesAtPath(t *testing.T) { require.NoError(t, err) _, err = os.Create(filepath.Join(dir, "slack")) require.NoError(t, err) - err = os.Mkdir(filepath.Join(dir, "somedir"), 0755) + err = os.Mkdir(filepath.Join(dir, "somedir"), 0o755) require.NoError(t, err) _, err = os.Create(filepath.Join(dir, "somedir", "inner")) require.NoError(t, err) @@ -47,7 +47,6 @@ func TestListFilesAtPath(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { got, err := listFilesAtPath(tc.path) cstest.RequireErrorContains(t, err, tc.expectedErr) diff --git a/pkg/csplugin/notifier.go b/pkg/csplugin/notifier.go index 8ab1aa923b8..615322ac0c3 100644 --- a/pkg/csplugin/notifier.go +++ b/pkg/csplugin/notifier.go @@ -2,7 +2,7 @@ package csplugin import ( "context" - "fmt" + "errors" plugin "github.com/hashicorp/go-plugin" "google.golang.org/grpc" @@ -10,23 +10,21 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) -type Notifier interface { - Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) - Configure(ctx context.Context, cfg *protobufs.Config) (*protobufs.Empty, error) -} - type NotifierPlugin struct { plugin.Plugin - Impl Notifier + Impl protobufs.NotifierServer } -type GRPCClient struct{ client protobufs.NotifierClient } +type GRPCClient struct{ + protobufs.UnimplementedNotifierServer + client protobufs.NotifierClient +} func (m *GRPCClient) Notify(ctx context.Context, 
notification *protobufs.Notification) (*protobufs.Empty, error) { done := make(chan error) go func() { _, err := m.client.Notify( - context.Background(), &protobufs.Notification{Text: notification.Text, Name: notification.Name}, + ctx, &protobufs.Notification{Text: notification.Text, Name: notification.Name}, ) done <- err }() @@ -35,19 +33,17 @@ func (m *GRPCClient) Notify(ctx context.Context, notification *protobufs.Notific return &protobufs.Empty{}, err case <-ctx.Done(): - return &protobufs.Empty{}, fmt.Errorf("timeout exceeded") + return &protobufs.Empty{}, errors.New("timeout exceeded") } } func (m *GRPCClient) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { - _, err := m.client.Configure( - context.Background(), config, - ) + _, err := m.client.Configure(ctx, config) return &protobufs.Empty{}, err } type GRPCServer struct { - Impl Notifier + Impl protobufs.NotifierServer } func (p *NotifierPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { diff --git a/pkg/csplugin/utils.go b/pkg/csplugin/utils.go index 216a079d457..571d78add56 100644 --- a/pkg/csplugin/utils.go +++ b/pkg/csplugin/utils.go @@ -51,7 +51,7 @@ func getUID(username string) (uint32, error) { return 0, err } if uid < 0 || uid > math.MaxInt32 { - return 0, fmt.Errorf("out of bound uid") + return 0, errors.New("out of bound uid") } return uint32(uid), nil } @@ -66,7 +66,7 @@ func getGID(groupname string) (uint32, error) { return 0, err } if gid < 0 || gid > math.MaxInt32 { - return 0, fmt.Errorf("out of bound gid") + return 0, errors.New("out of bound gid") } return uint32(gid), nil } @@ -123,10 +123,10 @@ func pluginIsValid(path string) error { mode := details.Mode() perm := uint32(mode) - if (perm & 00002) != 0 { + if (perm & 0o0002) != 0 { return fmt.Errorf("plugin at %s is world writable, world writable plugins are invalid", path) } - if (perm & 00020) != 0 { + if (perm & 0o0020) != 0 { return fmt.Errorf("plugin at %s is group writable, group writable plugins are invalid", path) } if (mode & os.ModeSetgid) != 0 { diff --git a/pkg/csplugin/utils_test.go b/pkg/csplugin/utils_test.go index b4ac1e7e7cb..7fa9a77acd5 100644 --- a/pkg/csplugin/utils_test.go +++ b/pkg/csplugin/utils_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestGetPluginNameAndTypeFromPath(t *testing.T) { @@ -37,7 +37,6 @@ func TestGetPluginNameAndTypeFromPath(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { got, got1, err := getPluginTypeAndSubtypeFromPath(tc.path) cstest.RequireErrorContains(t, err, tc.expectedErr) diff --git a/pkg/csplugin/utils_windows.go b/pkg/csplugin/utils_windows.go index dfb11aff548..91002079398 100644 --- a/pkg/csplugin/utils_windows.go +++ b/pkg/csplugin/utils_windows.go @@ -3,6 +3,7 @@ package csplugin import ( + "errors" "fmt" "os" "os/exec" @@ -77,14 +78,14 @@ func CheckPerms(path string) error { return fmt.Errorf("while getting owner security info: %w", err) } if !sd.IsValid() { - return fmt.Errorf("security descriptor is invalid") + return errors.New("security descriptor is invalid") } owner, _, err := sd.Owner() if err != nil { return fmt.Errorf("while getting owner: %w", err) } if !owner.IsValid() { - return fmt.Errorf("owner is invalid") + return errors.New("owner is invalid") } if !owner.Equals(systemSid) && !owner.Equals(currentUserSid) && !owner.Equals(adminSid) { @@ -100,10 +101,6 @@ func 
CheckPerms(path string) error { return fmt.Errorf("no DACL found on plugin, meaning fully permissive access on plugin %s", path) } - if err != nil { - return fmt.Errorf("while looking up current user sid: %w", err) - } - rs := reflect.ValueOf(dacl).Elem() /* @@ -119,7 +116,7 @@ func CheckPerms(path string) error { */ aceCount := rs.Field(3).Uint() - for i := uint64(0); i < aceCount; i++ { + for i := range aceCount { ace := &AccessAllowedAce{} ret, _, _ := procGetAce.Call(uintptr(unsafe.Pointer(dacl)), uintptr(i), uintptr(unsafe.Pointer(&ace))) if ret == 0 { diff --git a/pkg/csplugin/utils_windows_test.go b/pkg/csplugin/utils_windows_test.go index 9161fd45b8b..1eb4dfb9033 100644 --- a/pkg/csplugin/utils_windows_test.go +++ b/pkg/csplugin/utils_windows_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestGetPluginNameAndTypeFromPath(t *testing.T) { @@ -37,7 +37,6 @@ func TestGetPluginNameAndTypeFromPath(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { got, got1, err := getPluginTypeAndSubtypeFromPath(tc.path) cstest.RequireErrorContains(t, err, tc.expectedErr) diff --git a/pkg/csplugin/watcher_test.go b/pkg/csplugin/watcher_test.go index 391a94810f5..84e63ec6493 100644 --- a/pkg/csplugin/watcher_test.go +++ b/pkg/csplugin/watcher_test.go @@ -10,16 +10,15 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/crowdsec/pkg/models" ) -var ctx = context.Background() - func resetTestTomb(testTomb *tomb.Tomb, pw *PluginWatcher) { testTomb.Kill(nil) <-pw.PluginEvents + if err := testTomb.Wait(); err != nil { log.Fatal(err) } @@ -34,7 +33,7 @@ func resetWatcherAlertCounter(pw *PluginWatcher) { } func insertNAlertsToPlugin(pw *PluginWatcher, n int, pluginName string) { - for i := 0; i < n; i++ { + for range n { pw.Inserts <- pluginName } } @@ -46,13 +45,17 @@ func listenChannelWithTimeout(ctx context.Context, channel chan string) error { case <-ctx.Done(): return ctx.Err() } + return nil } func TestPluginWatcherInterval(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows because timing is not reliable") } + pw := PluginWatcher{} alertsByPluginName := make(map[string][]*models.Alert) testTomb := tomb.Tomb{} @@ -66,6 +69,7 @@ func TestPluginWatcherInterval(t *testing.T) { ct, cancel := context.WithTimeout(ctx, time.Microsecond) defer cancel() + err := listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") resetTestTomb(&testTomb, &pw) @@ -74,6 +78,7 @@ func TestPluginWatcherInterval(t *testing.T) { ct, cancel = context.WithTimeout(ctx, time.Millisecond*5) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) resetTestTomb(&testTomb, &pw) @@ -81,9 +86,12 @@ func TestPluginWatcherInterval(t *testing.T) { } func TestPluginAlertCountWatcher(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows because timing is not reliable") } + pw := PluginWatcher{} alertsByPluginName := make(map[string][]*models.Alert) configs := map[string]PluginConfig{ @@ -92,28 +100,34 @@ func TestPluginAlertCountWatcher(t *testing.T) { }, } testTomb := tomb.Tomb{} + pw.Init(configs, alertsByPluginName) pw.Start(&testTomb) 
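	// (three phases follow: with no inserts and then with only 4, the watcher
	// stays quiet and listenChannelWithTimeout times out; only the batch of 5
	// reaches the configured count threshold and emits on pw.PluginEvents)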
// Channel won't contain any events since threshold is not crossed. ct, cancel := context.WithTimeout(ctx, time.Second) defer cancel() + err := listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") // Channel won't contain any events since threshold is not crossed. resetWatcherAlertCounter(&pw) insertNAlertsToPlugin(&pw, 4, "testPlugin") + ct, cancel = context.WithTimeout(ctx, time.Second) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") // Channel will contain an event since threshold is crossed. resetWatcherAlertCounter(&pw) insertNAlertsToPlugin(&pw, 5, "testPlugin") + ct, cancel = context.WithTimeout(ctx, time.Second) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) resetTestTomb(&testTomb, &pw) diff --git a/pkg/csprofiles/csprofiles.go b/pkg/csprofiles/csprofiles.go index 7668e70cb5d..52cda1ed2e1 100644 --- a/pkg/csprofiles/csprofiles.go +++ b/pkg/csprofiles/csprofiles.go @@ -4,9 +4,8 @@ import ( "fmt" "time" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" - "github.com/pkg/errors" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -16,28 +15,29 @@ import ( ) type Runtime struct { - RuntimeFilters []*vm.Program `json:"-" yaml:"-"` - DebugFilters []*exprhelpers.ExprDebugger `json:"-" yaml:"-"` - RuntimeDurationExpr *vm.Program `json:"-" yaml:"-"` - DebugDurationExpr *exprhelpers.ExprDebugger `json:"-" yaml:"-"` - Cfg *csconfig.ProfileCfg `json:"-" yaml:"-"` - Logger *log.Entry `json:"-" yaml:"-"` + RuntimeFilters []*vm.Program `json:"-" yaml:"-"` + RuntimeDurationExpr *vm.Program `json:"-" yaml:"-"` + Cfg *csconfig.ProfileCfg `json:"-" yaml:"-"` + Logger *log.Entry `json:"-" yaml:"-"` } -var defaultDuration = "4h" +const defaultDuration = "4h" func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) { var err error + profilesRuntime := make([]*Runtime, 0) for _, profile := range profilesCfg { var runtimeFilter, runtimeDurationExpr *vm.Program - var debugFilter, debugDurationExpr *exprhelpers.ExprDebugger + runtime := &Runtime{} + xlog := log.New() if err := types.ConfigureLogger(xlog); err != nil { - log.Fatalf("While creating profiles-specific logger : %s", err) + return nil, fmt.Errorf("while configuring profiles-specific logger: %w", err) } + xlog.SetLevel(log.InfoLevel) runtime.Logger = xlog.WithFields(log.Fields{ "type": "profile", @@ -45,55 +45,54 @@ func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) { }) runtime.RuntimeFilters = make([]*vm.Program, len(profile.Filters)) - runtime.DebugFilters = make([]*exprhelpers.ExprDebugger, len(profile.Filters)) runtime.Cfg = profile + if runtime.Cfg.OnSuccess != "" && runtime.Cfg.OnSuccess != "continue" && runtime.Cfg.OnSuccess != "break" { - return []*Runtime{}, fmt.Errorf("invalid 'on_success' for '%s': %s", profile.Name, runtime.Cfg.OnSuccess) + return nil, fmt.Errorf("invalid 'on_success' for '%s': %s", profile.Name, runtime.Cfg.OnSuccess) } + if runtime.Cfg.OnFailure != "" && runtime.Cfg.OnFailure != "continue" && runtime.Cfg.OnFailure != "break" && runtime.Cfg.OnFailure != "apply" { - return []*Runtime{}, fmt.Errorf("invalid 'on_failure' for '%s' : %s", profile.Name, runtime.Cfg.OnFailure) + return nil, fmt.Errorf("invalid 'on_failure' for '%s' : %s", profile.Name, runtime.Cfg.OnFailure) } - for fIdx, 
filter := range profile.Filters { + for fIdx, filter := range profile.Filters { if runtimeFilter, err = expr.Compile(filter, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil { - return []*Runtime{}, errors.Wrapf(err, "error compiling filter of '%s'", profile.Name) + return nil, fmt.Errorf("error compiling filter of '%s': %w", profile.Name, err) } + runtime.RuntimeFilters[fIdx] = runtimeFilter if profile.Debug != nil && *profile.Debug { - if debugFilter, err = exprhelpers.NewDebugger(filter, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil { - log.Debugf("Error compiling debug filter of %s : %s", profile.Name, err) - // Don't fail if we can't compile the filter - for now - // return errors.Wrapf(err, "Error compiling debug filter of %s", profile.Name) - } - runtime.DebugFilters[fIdx] = debugFilter runtime.Logger.Logger.SetLevel(log.DebugLevel) } } if profile.DurationExpr != "" { if runtimeDurationExpr, err = expr.Compile(profile.DurationExpr, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil { - return []*Runtime{}, errors.Wrapf(err, "error compiling duration_expr of %s", profile.Name) + return nil, fmt.Errorf("error compiling duration_expr of %s: %w", profile.Name, err) } runtime.RuntimeDurationExpr = runtimeDurationExpr - if profile.Debug != nil && *profile.Debug { - if debugDurationExpr, err = exprhelpers.NewDebugger(profile.DurationExpr, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil { - log.Debugf("Error compiling debug duration_expr of %s : %s", profile.Name, err) - } - runtime.DebugDurationExpr = debugDurationExpr - } } for _, decision := range profile.Decisions { if runtime.RuntimeDurationExpr == nil { - if _, err := time.ParseDuration(*decision.Duration); err != nil { - return []*Runtime{}, errors.Wrapf(err, "error parsing duration '%s' of %s", *decision.Duration, profile.Name) + var duration string + if decision.Duration != nil { + duration = *decision.Duration + } else { + runtime.Logger.Warningf("No duration specified for %s, using default duration %s", profile.Name, defaultDuration) + duration = defaultDuration + } + + if _, err := time.ParseDuration(duration); err != nil { + return nil, fmt.Errorf("error parsing duration '%s' of %s: %w", duration, profile.Name, err) } } } profilesRuntime = append(profilesRuntime, runtime) } + return profilesRuntime, nil } @@ -120,26 +119,29 @@ func (Profile *Runtime) GenerateDecisionFromProfile(Alert *models.Alert) ([]*mod *decision.Scope = *Alert.Source.Scope } /*some fields are populated from the reference object : duration, scope, type*/ + decision.Duration = new(string) + if refDecision.Duration != nil { + *decision.Duration = *refDecision.Duration + } + if Profile.Cfg.DurationExpr != "" && Profile.RuntimeDurationExpr != nil { - duration, err := expr.Run(Profile.RuntimeDurationExpr, map[string]interface{}{"Alert": Alert}) + profileDebug := false + if Profile.Cfg.Debug != nil && *Profile.Cfg.Debug { + profileDebug = true + } + + duration, err := exprhelpers.Run(Profile.RuntimeDurationExpr, map[string]interface{}{"Alert": Alert}, Profile.Logger, profileDebug) if err != nil { Profile.Logger.Warningf("Failed to run duration_expr : %v", err) - *decision.Duration = *refDecision.Duration } else { durationStr := fmt.Sprint(duration) if _, err := time.ParseDuration(durationStr); err != nil { Profile.Logger.Warningf("Failed to parse expr duration result '%s'", duration) - 
*decision.Duration = *refDecision.Duration } else { *decision.Duration = durationStr } } - } else { - if refDecision.Duration == nil { - *decision.Duration = defaultDuration - } - *decision.Duration = *refDecision.Duration } decision.Type = new(string) @@ -150,13 +152,16 @@ func (Profile *Runtime) GenerateDecisionFromProfile(Alert *models.Alert) ([]*mod *decision.Value = *Alert.Source.Value decision.Origin = new(string) *decision.Origin = types.CrowdSecOrigin + if refDecision.Origin != nil { *decision.Origin = fmt.Sprintf("%s/%s", *decision.Origin, *refDecision.Origin) } + decision.Scenario = new(string) *decision.Scenario = *Alert.Scenario decisions = append(decisions, &decision) } + return decisions, nil } @@ -165,28 +170,33 @@ func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision var decisions []*models.Decision matched := false + for eIdx, expression := range Profile.RuntimeFilters { - output, err := expr.Run(expression, map[string]interface{}{"Alert": Alert}) + debugProfile := false + if Profile.Cfg.Debug != nil && *Profile.Cfg.Debug { + debugProfile = true + } + + output, err := exprhelpers.Run(expression, map[string]interface{}{"Alert": Alert}, Profile.Logger, debugProfile) if err != nil { - Profile.Logger.Warningf("failed to run whitelist expr : %v", err) - return nil, matched, errors.Wrapf(err, "while running expression %s", Profile.Cfg.Filters[eIdx]) + Profile.Logger.Warningf("failed to run profile expr for %s: %v", Profile.Cfg.Name, err) + return nil, matched, fmt.Errorf("while running expression %s: %w", Profile.Cfg.Filters[eIdx], err) } + switch out := output.(type) { case bool: - if Profile.Cfg.Debug != nil && *Profile.Cfg.Debug { - Profile.DebugFilters[eIdx].Run(Profile.Logger, out, map[string]interface{}{"Alert": Alert}) - } if out { matched = true /*the expression matched, create the associated decision*/ subdecisions, err := Profile.GenerateDecisionFromProfile(Alert) if err != nil { - return nil, matched, errors.Wrapf(err, "while generating decision from profile %s", Profile.Cfg.Name) + return nil, matched, fmt.Errorf("while generating decision from profile %s: %w", Profile.Cfg.Name, err) } decisions = append(decisions, subdecisions...) 
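				// a matching filter both records its decisions and marks the profile
				// as matched; whether later profiles still run is decided by the
				// caller, which honors on_success: "break"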
} else { Profile.Logger.Debugf("Profile %s filter is unsuccessful", Profile.Cfg.Name) + if Profile.Cfg.OnFailure == "break" { break } @@ -194,9 +204,7 @@ func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision default: return nil, matched, fmt.Errorf("unexpected type %t (%v) while running '%s'", output, output, Profile.Cfg.Filters[eIdx]) - } - } return decisions, matched, nil diff --git a/pkg/csprofiles/csprofiles_test.go b/pkg/csprofiles/csprofiles_test.go index 8adf6829134..0247243ddd3 100644 --- a/pkg/csprofiles/csprofiles_test.go +++ b/pkg/csprofiles/csprofiles_test.go @@ -86,10 +86,22 @@ func TestNewProfile(t *testing.T) { }, expectedNbProfile: 1, }, + { + name: "filter ok and no duration", + profileCfg: &csconfig.ProfileCfg{ + Filters: []string{ + "1==1", + }, + Debug: &boolTrue, + Decisions: []models.Decision{ + {Type: &typ, Scope: &scope, Simulated: &boolFalse}, + }, + }, + expectedNbProfile: 1, + }, } for _, test := range tests { - test := test t.Run(test.name, func(t *testing.T) { profilesCfg := []*csconfig.ProfileCfg{ test.profileCfg, @@ -183,7 +195,6 @@ func TestEvaluateProfile(t *testing.T) { }, } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { profilesCfg := []*csconfig.ProfileCfg{ tt.args.profileCfg, diff --git a/pkg/cticlient/client.go b/pkg/cticlient/client.go index 16876026a4c..90112d80abf 100644 --- a/pkg/cticlient/client.go +++ b/pkg/cticlient/client.go @@ -9,6 +9,8 @@ import ( "strings" log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" ) const ( @@ -43,7 +45,10 @@ func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map if err != nil { return nil, err } - req.Header.Set("x-api-key", c.apiKey) + + req.Header.Set("X-Api-Key", c.apiKey) + req.Header.Set("User-Agent", useragent.Default()) + resp, err := c.httpClient.Do(req) if err != nil { return nil, err @@ -71,7 +76,7 @@ func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map func (c *CrowdsecCTIClient) GetIPInfo(ip string) (*SmokeItem, error) { body, err := c.doRequest(http.MethodGet, smokeEndpoint+"/"+ip, nil) if err != nil { - if err == ErrNotFound { + if errors.Is(err, ErrNotFound) { return &SmokeItem{}, nil } return nil, err diff --git a/pkg/cticlient/client_test.go b/pkg/cticlient/client_test.go index a8f22e09465..cdbbd0c9732 100644 --- a/pkg/cticlient/client_test.go +++ b/pkg/cticlient/client_test.go @@ -12,8 +12,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/ptr" ) const validApiKey = "my-api-key" @@ -36,25 +37,30 @@ func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { // wip func fireHandler(req *http.Request) *http.Response { var err error - apiKey := req.Header.Get("x-api-key") + + apiKey := req.Header.Get("X-Api-Key") if apiKey != validApiKey { log.Warningf("invalid api key: %s", apiKey) + return &http.Response{ StatusCode: http.StatusForbidden, Body: nil, Header: make(http.Header), } } + //unmarshal data if fireResponses == nil { page1, err := os.ReadFile("tests/fire-page1.json") if err != nil { panic("can't read file") } + page2, err := os.ReadFile("tests/fire-page2.json") if err != nil { panic("can't read file") } + fireResponses = []string{string(page1), string(page2)} } //let's assume we have two valid pages. 
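For context on the test plumbing above and below: the cticlient tests never hit the network; they swap the HTTP client's Transport for a plain function so handlers like fireHandler and smokeHandler can serve canned responses. A minimal, self-contained sketch of that pattern (the stubTransport type and newStubClient helper are illustrative names, not part of this diff):

    package cticlient_sketch

    import (
    	"io"
    	"net/http"
    	"strings"
    )

    // stubTransport lets an ordinary function satisfy http.RoundTripper, so a
    // client can be pointed at canned responses instead of a live server.
    type stubTransport func(req *http.Request) *http.Response

    func (f stubTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    	return f(req), nil
    }

    // newStubClient mimics the API-key check performed by the handlers above.
    func newStubClient(validKey string) *http.Client {
    	return &http.Client{Transport: stubTransport(func(req *http.Request) *http.Response {
    		if req.Header.Get("X-Api-Key") != validKey {
    			return &http.Response{StatusCode: http.StatusForbidden, Body: io.NopCloser(strings.NewReader(""))}
    		}
    		return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{"items": []}`))}
    	})}
    }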
@@ -70,6 +76,7 @@ func fireHandler(req *http.Request) *http.Response { //how to react if you give a page number that is too big ? if page > len(fireResponses) { log.Warningf(" page too big %d vs %d", page, len(fireResponses)) + emptyResponse := `{ "_links": { "first": { @@ -82,8 +89,10 @@ func fireHandler(req *http.Request) *http.Response { "items": [] } ` + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(emptyResponse))} } + reader := io.NopCloser(strings.NewReader(fireResponses[page-1])) //we should care about limit too return &http.Response{ @@ -96,7 +105,7 @@ func fireHandler(req *http.Request) *http.Response { } func smokeHandler(req *http.Request) *http.Response { - apiKey := req.Header.Get("x-api-key") + apiKey := req.Header.Get("X-Api-Key") if apiKey != validApiKey { return &http.Response{ StatusCode: http.StatusForbidden, @@ -106,6 +115,7 @@ func smokeHandler(req *http.Request) *http.Response { } requestedIP := strings.Split(req.URL.Path, "/")[3] + response, ok := smokeResponses[requestedIP] if !ok { return &http.Response{ @@ -127,7 +137,7 @@ func smokeHandler(req *http.Request) *http.Response { } func rateLimitedHandler(req *http.Request) *http.Response { - apiKey := req.Header.Get("x-api-key") + apiKey := req.Header.Get("X-Api-Key") if apiKey != validApiKey { return &http.Response{ StatusCode: http.StatusForbidden, @@ -135,6 +145,7 @@ func rateLimitedHandler(req *http.Request) *http.Response { Header: make(http.Header), } } + return &http.Response{ StatusCode: http.StatusTooManyRequests, Body: nil, @@ -143,7 +154,7 @@ func rateLimitedHandler(req *http.Request) *http.Response { } func searchHandler(req *http.Request) *http.Response { - apiKey := req.Header.Get("x-api-key") + apiKey := req.Header.Get("X-Api-Key") if apiKey != validApiKey { return &http.Response{ StatusCode: http.StatusForbidden, @@ -151,7 +162,9 @@ func searchHandler(req *http.Request) *http.Response { Header: make(http.Header), } } + url, _ := url.Parse(req.URL.String()) + ipsParam := url.Query().Get("ips") if ipsParam == "" { return &http.Response{ @@ -163,6 +176,7 @@ func searchHandler(req *http.Request) *http.Response { totalIps := 0 notFound := 0 + ips := strings.Split(ipsParam, ",") for _, ip := range ips { _, ok := smokeResponses[ip] @@ -172,12 +186,15 @@ func searchHandler(req *http.Request) *http.Response { notFound++ } } + response := fmt.Sprintf(`{"total": %d, "not_found": %d, "items": [`, totalIps, notFound) for _, ip := range ips { response += smokeResponses[ip] } + response += "]}" reader := io.NopCloser(strings.NewReader(response)) + return &http.Response{ StatusCode: http.StatusOK, Body: reader, @@ -190,7 +207,7 @@ func TestBadFireAuth(t *testing.T) { Transport: RoundTripFunc(fireHandler), })) _, err := ctiClient.Fire(FireParams{}) - assert.EqualError(t, err, ErrUnauthorized.Error()) + require.EqualError(t, err, ErrUnauthorized.Error()) } func TestFireOk(t *testing.T) { @@ -198,19 +215,19 @@ func TestFireOk(t *testing.T) { Transport: RoundTripFunc(fireHandler), })) data, err := cticlient.Fire(FireParams{}) - assert.Equal(t, err, nil) - assert.Equal(t, len(data.Items), 3) - assert.Equal(t, data.Items[0].Ip, "1.2.3.4") + require.NoError(t, err) + assert.Len(t, data.Items, 3) + assert.Equal(t, "1.2.3.4", data.Items[0].Ip) //page 1 is the default data, err = cticlient.Fire(FireParams{Page: ptr.Of(1)}) - assert.Equal(t, err, nil) - assert.Equal(t, len(data.Items), 3) - assert.Equal(t, data.Items[0].Ip, "1.2.3.4") + require.NoError(t, err) + assert.Len(t, data.Items, 
3) + assert.Equal(t, "1.2.3.4", data.Items[0].Ip) //page 2 data, err = cticlient.Fire(FireParams{Page: ptr.Of(2)}) - assert.Equal(t, err, nil) - assert.Equal(t, len(data.Items), 3) - assert.Equal(t, data.Items[0].Ip, "4.2.3.4") + require.NoError(t, err) + assert.Len(t, data.Items, 3) + assert.Equal(t, "4.2.3.4", data.Items[0].Ip) } func TestFirePaginator(t *testing.T) { @@ -219,17 +236,16 @@ func TestFirePaginator(t *testing.T) { })) paginator := NewFirePaginator(cticlient, FireParams{}) items, err := paginator.Next() - assert.Equal(t, err, nil) - assert.Equal(t, len(items), 3) - assert.Equal(t, items[0].Ip, "1.2.3.4") + require.NoError(t, err) + assert.Len(t, items, 3) + assert.Equal(t, "1.2.3.4", items[0].Ip) items, err = paginator.Next() - assert.Equal(t, err, nil) - assert.Equal(t, len(items), 3) - assert.Equal(t, items[0].Ip, "4.2.3.4") + require.NoError(t, err) + assert.Len(t, items, 3) + assert.Equal(t, "4.2.3.4", items[0].Ip) items, err = paginator.Next() - assert.Equal(t, err, nil) - assert.Equal(t, len(items), 0) - + require.NoError(t, err) + assert.Empty(t, items) } func TestBadSmokeAuth(t *testing.T) { @@ -237,13 +253,14 @@ func TestBadSmokeAuth(t *testing.T) { Transport: RoundTripFunc(smokeHandler), })) _, err := ctiClient.GetIPInfo("1.1.1.1") - assert.EqualError(t, err, ErrUnauthorized.Error()) + require.EqualError(t, err, ErrUnauthorized.Error()) } func TestSmokeInfoValidIP(t *testing.T) { ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ Transport: RoundTripFunc(smokeHandler), })) + resp, err := ctiClient.GetIPInfo("1.1.1.1") if err != nil { t.Fatalf("failed to get ip info: %s", err) @@ -257,6 +274,7 @@ func TestSmokeUnknownIP(t *testing.T) { ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ Transport: RoundTripFunc(smokeHandler), })) + resp, err := ctiClient.GetIPInfo("42.42.42.42") if err != nil { t.Fatalf("failed to get ip info: %s", err) @@ -270,20 +288,22 @@ func TestRateLimit(t *testing.T) { Transport: RoundTripFunc(rateLimitedHandler), })) _, err := ctiClient.GetIPInfo("1.1.1.1") - assert.EqualError(t, err, ErrLimit.Error()) + require.EqualError(t, err, ErrLimit.Error()) } func TestSearchIPs(t *testing.T) { ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ Transport: RoundTripFunc(searchHandler), })) + resp, err := ctiClient.SearchIPs([]string{"1.1.1.1", "42.42.42.42"}) if err != nil { t.Fatalf("failed to search ips: %s", err) } + assert.Equal(t, 1, resp.Total) assert.Equal(t, 1, resp.NotFound) - assert.Equal(t, 1, len(resp.Items)) + assert.Len(t, resp.Items, 1) assert.Equal(t, "1.1.1.1", resp.Items[0].Ip) } diff --git a/pkg/cticlient/types_test.go b/pkg/cticlient/types_test.go index 1ec58cc78b0..a7308af35e0 100644 --- a/pkg/cticlient/types_test.go +++ b/pkg/cticlient/types_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/ptr" ) //func (c *SmokeItem) GetAttackDetails() []string { @@ -88,27 +88,28 @@ func getSampleSmokeItem() SmokeItem { }, }, } + return emptyItem } func TestBasicSmokeItem(t *testing.T) { item := getSampleSmokeItem() - assert.Equal(t, item.GetAttackDetails(), []string{"ssh:bruteforce"}) - assert.Equal(t, item.GetBehaviors(), []string{"ssh:bruteforce"}) - assert.Equal(t, item.GetMaliciousnessScore(), float32(0.1)) - assert.Equal(t, item.IsPartOfCommunityBlocklist(), false) - assert.Equal(t, item.GetBackgroundNoiseScore(), int(3)) - 
assert.Equal(t, item.GetFalsePositives(), []string{}) - assert.Equal(t, item.IsFalsePositive(), false) + assert.Equal(t, []string{"ssh:bruteforce"}, item.GetAttackDetails()) + assert.Equal(t, []string{"ssh:bruteforce"}, item.GetBehaviors()) + assert.InDelta(t, 0.1, item.GetMaliciousnessScore(), 0.000001) + assert.False(t, item.IsPartOfCommunityBlocklist()) + assert.Equal(t, 3, item.GetBackgroundNoiseScore()) + assert.Equal(t, []string{}, item.GetFalsePositives()) + assert.False(t, item.IsFalsePositive()) } func TestEmptySmokeItem(t *testing.T) { item := SmokeItem{} - assert.Equal(t, item.GetAttackDetails(), []string{}) - assert.Equal(t, item.GetBehaviors(), []string{}) - assert.Equal(t, item.GetMaliciousnessScore(), float32(0.0)) - assert.Equal(t, item.IsPartOfCommunityBlocklist(), false) - assert.Equal(t, item.GetBackgroundNoiseScore(), int(0)) - assert.Equal(t, item.GetFalsePositives(), []string{}) - assert.Equal(t, item.IsFalsePositive(), false) + assert.Equal(t, []string{}, item.GetAttackDetails()) + assert.Equal(t, []string{}, item.GetBehaviors()) + assert.InDelta(t, 0.0, item.GetMaliciousnessScore(), 0) + assert.False(t, item.IsPartOfCommunityBlocklist()) + assert.Equal(t, 0, item.GetBackgroundNoiseScore()) + assert.Equal(t, []string{}, item.GetFalsePositives()) + assert.False(t, item.IsFalsePositive()) } diff --git a/pkg/cwhub/cwhub.go b/pkg/cwhub/cwhub.go index bdd03c89ac5..683f1853b43 100644 --- a/pkg/cwhub/cwhub.go +++ b/pkg/cwhub/cwhub.go @@ -1,369 +1,46 @@ package cwhub import ( - "crypto/sha256" "fmt" - "io" - "os" + "net/http" "path/filepath" - "sort" "strings" + "time" - "github.com/enescakir/emoji" - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" - "golang.org/x/mod/semver" + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" ) -/*managed configuration types*/ -var PARSERS = "parsers" -var PARSERS_OVFLW = "postoverflows" -var SCENARIOS = "scenarios" -var COLLECTIONS = "collections" -var ItemTypes = []string{PARSERS, PARSERS_OVFLW, SCENARIOS, COLLECTIONS} - -var hubIdx map[string]map[string]Item - -var RawFileURLTemplate = "https://hub-cdn.crowdsec.net/%s/%s" -var HubBranch = "master" -var HubIndexFile = ".index.json" - -type ItemVersion struct { - Digest string `json:"digest,omitempty"` - Deprecated bool `json:"deprecated,omitempty"` -} - -type ItemHubStatus struct { - Name string `json:"name"` - LocalVersion string `json:"local_version"` - LocalPath string `json:"local_path"` - Description string `json:"description"` - UTF8_Status string `json:"utf8_status"` - Status string `json:"status"` -} - -//Item can be : parsed, scenario, collection -type Item struct { - /*descriptive info*/ - Type string `yaml:"type,omitempty" json:"type,omitempty"` //parser|postoverflows|scenario|collection(|enrich) - Stage string `json:"stage,omitempty" yaml:"stage,omitempty,omitempty"` //Stage for parser|postoverflow : s00-raw/s01-... - Name string `json:"name,omitempty"` //as seen in .config.json, usually "author/name" - FileName string `json:"file_name,omitempty"` //the filename, ie. 
apache2-logs.yaml - Description string `yaml:"description,omitempty" json:"description,omitempty"` //as seen in .config.json - Author string `json:"author,omitempty"` //as seen in .config.json - References []string `yaml:"references,omitempty" json:"references,omitempty"` //as seen in .config.json - BelongsToCollections []string `yaml:"belongs_to_collections,omitempty" json:"belongs_to_collections,omitempty"` /*if it's part of collections, track name here*/ - - /*remote (hub) infos*/ - RemoteURL string `yaml:"remoteURL,omitempty" json:"remoteURL,omitempty"` //the full remote uri of file in http - RemotePath string `json:"path,omitempty" yaml:"remote_path,omitempty"` //the path relative to git ie. /parsers/stage/author/file.yaml - RemoteHash string `yaml:"hash,omitempty" json:"hash,omitempty"` //the meow - Version string `json:"version,omitempty"` //the last version - Versions map[string]ItemVersion `json:"versions,omitempty" yaml:"-"` //the list of existing versions - - /*local (deployed) infos*/ - LocalPath string `yaml:"local_path,omitempty" json:"local_path,omitempty"` //the local path relative to ${CFG_DIR} - //LocalHubPath string - LocalVersion string `json:"local_version,omitempty"` - LocalHash string `json:"local_hash,omitempty"` //the local meow - Installed bool `json:"installed,omitempty"` - Downloaded bool `json:"downloaded,omitempty"` - UpToDate bool `json:"up_to_date,omitempty"` - Tainted bool `json:"tainted,omitempty"` //has it been locally modified - Local bool `json:"local,omitempty"` //if it's a non versioned control one - - /*if it's a collection, it not a single file*/ - Parsers []string `yaml:"parsers,omitempty" json:"parsers,omitempty"` - PostOverflows []string `yaml:"postoverflows,omitempty" json:"postoverflows,omitempty"` - Scenarios []string `yaml:"scenarios,omitempty" json:"scenarios,omitempty"` - Collections []string `yaml:"collections,omitempty" json:"collections,omitempty"` -} - -func (i *Item) toHubStatus() ItemHubStatus { - hubStatus := ItemHubStatus{} - hubStatus.Name = i.Name - hubStatus.LocalVersion = i.LocalVersion - hubStatus.LocalPath = i.LocalPath - hubStatus.Description = i.Description - - status, ok, warning, managed := ItemStatus(*i) - hubStatus.Status = status - if !managed { - hubStatus.UTF8_Status = fmt.Sprintf("%v %s", emoji.House, status) - } else if !i.Installed { - hubStatus.UTF8_Status = fmt.Sprintf("%v %s", emoji.Prohibited, status) - } else if warning { - hubStatus.UTF8_Status = fmt.Sprintf("%v %s", emoji.Warning, status) - } else if ok { - hubStatus.UTF8_Status = fmt.Sprintf("%v %s", emoji.CheckMark, status) - } - return hubStatus -} - -var skippedLocal = 0 -var skippedTainted = 0 - -/*To be used when reference(s) (is/are) missing in a collection*/ -var ReferenceMissingError = errors.New("Reference(s) missing in collection") -var MissingHubIndex = errors.New("hub index can't be found") - -//GetVersionStatus : semver requires 'v' prefix -func GetVersionStatus(v *Item) int { - return semver.Compare("v"+v.Version, "v"+v.LocalVersion) -} - -// calculate sha256 of a file -func getSHA256(filepath string) (string, error) { - /* Digest of file */ - f, err := os.Open(filepath) - if err != nil { - return "", fmt.Errorf("unable to open '%s' : %s", filepath, err) - } - - defer f.Close() - - h := sha256.New() - if _, err := io.Copy(h, f); err != nil { - return "", fmt.Errorf("unable to calculate sha256 of '%s': %s", filepath, err) - } - - return fmt.Sprintf("%x", h.Sum(nil)), nil -} - -func GetItemMap(itemType string) map[string]Item { - var m 
map[string]Item - var ok bool - - if m, ok = hubIdx[itemType]; !ok { - return nil - } - return m -} - -//GetItemByPath retrieves the item from hubIdx based on the path. To achieve this it will resolve symlink to find associated hub item. -func GetItemByPath(itemType string, itemPath string) (*Item, error) { - /*try to resolve symlink*/ - finalName := "" - f, err := os.Lstat(itemPath) - if err != nil { - return nil, fmt.Errorf("while performing lstat on %s: %w", itemPath, err) - } - - if f.Mode()&os.ModeSymlink == 0 { - /*it's not a symlink, it should be the filename itsef the key*/ - finalName = filepath.Base(itemPath) - } else { - /*resolve the symlink to hub file*/ - pathInHub, err := os.Readlink(itemPath) - if err != nil { - return nil, fmt.Errorf("while reading symlink of %s: %w", itemPath, err) - } - //extract author from path - fname := filepath.Base(pathInHub) - author := filepath.Base(filepath.Dir(pathInHub)) - //trim yaml suffix - fname = strings.TrimSuffix(fname, ".yaml") - fname = strings.TrimSuffix(fname, ".yml") - finalName = fmt.Sprintf("%s/%s", author, fname) - } - - /*it's not a symlink, it should be the filename itsef the key*/ - if m := GetItemMap(itemType); m != nil { - if v, ok := m[finalName]; ok { - return &v, nil - } - return nil, fmt.Errorf("%s not found in %s", finalName, itemType) - } - return nil, fmt.Errorf("item type %s doesn't exist", itemType) -} - -func GetItem(itemType string, itemName string) *Item { - if m, ok := GetItemMap(itemType)[itemName]; ok { - return &m - } - return nil -} - -func AddItem(itemType string, item Item) error { - in := false - for _, itype := range ItemTypes { - if itype == itemType { - in = true - } - } - if !in { - return fmt.Errorf("ItemType %s is unknown", itemType) - } - hubIdx[itemType][item.Name] = item - return nil -} - -func DisplaySummary() { - log.Printf("Loaded %d collecs, %d parsers, %d scenarios, %d post-overflow parsers", len(hubIdx[COLLECTIONS]), - len(hubIdx[PARSERS]), len(hubIdx[SCENARIOS]), len(hubIdx[PARSERS_OVFLW])) - if skippedLocal > 0 || skippedTainted > 0 { - log.Printf("unmanaged items : %d local, %d tainted", skippedLocal, skippedTainted) - } -} - -//returns: human-text, Enabled, Warning, Unmanaged -func ItemStatus(v Item) (string, bool, bool, bool) { - strret := "disabled" - Ok := false - if v.Installed { - Ok = true - strret = "enabled" - } - - Managed := true - if v.Local { - Managed = false - strret += ",local" - } - - //tainted or out of date - Warning := false - if v.Tainted { - Warning = true - strret += ",tainted" - } else if !v.UpToDate && !v.Local { - strret += ",update-available" - Warning = true - } - return strret, Ok, Warning, Managed -} - -func GetInstalledScenariosAsString() ([]string, error) { - var retStr []string - - items, err := GetInstalledScenarios() - if err != nil { - return nil, fmt.Errorf("while fetching scenarios: %w", err) - } - for _, it := range items { - retStr = append(retStr, it.Name) - } - return retStr, nil -} - -func GetInstalledScenarios() ([]Item, error) { - var retItems []Item - - if _, ok := hubIdx[SCENARIOS]; !ok { - return nil, fmt.Errorf("no scenarios in hubIdx") - } - for _, item := range hubIdx[SCENARIOS] { - if item.Installed { - retItems = append(retItems, item) - } - } - return retItems, nil -} - -func GetInstalledParsers() ([]Item, error) { - var retItems []Item - - if _, ok := hubIdx[PARSERS]; !ok { - return nil, fmt.Errorf("no parsers in hubIdx") - } - for _, item := range hubIdx[PARSERS] { - if item.Installed { - retItems = append(retItems, item) - } - } - 
return retItems, nil +// hubTransport wraps a Transport to set a custom User-Agent. +type hubTransport struct { + http.RoundTripper } -func GetInstalledParsersAsString() ([]string, error) { - var retStr []string - - items, err := GetInstalledParsers() - if err != nil { - return nil, fmt.Errorf("while fetching parsers: %w", err) - } - for _, it := range items { - retStr = append(retStr, it.Name) - } - return retStr, nil +func (t *hubTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", useragent.Default()) + return t.RoundTripper.RoundTrip(req) } -func GetInstalledPostOverflows() ([]Item, error) { - var retItems []Item - - if _, ok := hubIdx[PARSERS_OVFLW]; !ok { - return nil, fmt.Errorf("no post overflows in hubIdx") - } - for _, item := range hubIdx[PARSERS_OVFLW] { - if item.Installed { - retItems = append(retItems, item) - } - } - return retItems, nil +// hubClient is the HTTP client used to communicate with the CrowdSec Hub. +var hubClient = &http.Client{ + Timeout: 120 * time.Second, + Transport: &hubTransport{http.DefaultTransport}, } -func GetInstalledPostOverflowsAsString() ([]string, error) { - var retStr []string - - items, err := GetInstalledPostOverflows() +// safePath returns a joined path and ensures that it does not escape the base directory. +func safePath(dir, filePath string) (string, error) { + absBaseDir, err := filepath.Abs(filepath.Clean(dir)) if err != nil { - return nil, fmt.Errorf("while fetching post overflows: %w", err) - } - for _, it := range items { - retStr = append(retStr, it.Name) + return "", err } - return retStr, nil -} -func GetInstalledCollectionsAsString() ([]string, error) { - var retStr []string - - items, err := GetInstalledCollections() + absFilePath, err := filepath.Abs(filepath.Join(dir, filePath)) if err != nil { - return nil, fmt.Errorf("while fetching collections: %w", err) + return "", err } - for _, it := range items { - retStr = append(retStr, it.Name) + if !strings.HasPrefix(absFilePath, absBaseDir) { + return "", fmt.Errorf("path %s escapes base directory %s", filePath, dir) } - return retStr, nil -} - -func GetInstalledCollections() ([]Item, error) { - var retItems []Item - if _, ok := hubIdx[COLLECTIONS]; !ok { - return nil, fmt.Errorf("no collection in hubIdx") - } - for _, item := range hubIdx[COLLECTIONS] { - if item.Installed { - retItems = append(retItems, item) - } - } - return retItems, nil -} - -//Returns a list of entries for packages : name, status, local_path, local_version, utf8_status (fancy) -func GetHubStatusForItemType(itemType string, name string, all bool) []ItemHubStatus { - if _, ok := hubIdx[itemType]; !ok { - log.Errorf("type %s doesn't exist", itemType) - - return nil - } - - var ret = make([]ItemHubStatus, 0) - /*remember, you do it for the user :)*/ - for _, item := range hubIdx[itemType] { - if name != "" && name != item.Name { - //user has requested a specific name - continue - } - //Only enabled items ? 
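
The new safePath helper above joins a base directory and a relative path, then rejects any result that escapes the base. A self-contained sketch of the same guard; containedPath and the sample paths are illustrative, and the separator-suffix detail in the comparison is an extra precaution added here, not something the patch does:

package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// containedPath mirrors the safePath guard: join, normalize, then refuse
// any result outside the base directory. Appending the path separator to
// the base before the prefix check avoids matching sibling directories
// that merely share a name prefix (/etc/crowdsec vs /etc/crowdsec-evil);
// that tweak is this sketch's addition, not the patch's.
func containedPath(dir, filePath string) (string, error) {
	absBase, err := filepath.Abs(filepath.Clean(dir))
	if err != nil {
		return "", err
	}

	absFile, err := filepath.Abs(filepath.Join(dir, filePath))
	if err != nil {
		return "", err
	}

	if absFile != absBase && !strings.HasPrefix(absFile, absBase+string(filepath.Separator)) {
		return "", fmt.Errorf("path %s escapes base directory %s", filePath, dir)
	}

	return absFile, nil
}

func main() {
	p, err := containedPath("/var/lib/crowdsec/data", "GeoLite2-City.mmdb")
	fmt.Println(p, err) // joined path, no error

	_, err = containedPath("/var/lib/crowdsec/data", "../../../etc/passwd")
	fmt.Println(err) // escape detected
}
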
- if !all && !item.Installed { - continue - } - //Check the item status - ret = append(ret, item.toHubStatus()) - } - sort.Slice(ret, func(i, j int) bool { return ret[i].Name < ret[j].Name }) - return ret + return absFilePath, nil } diff --git a/pkg/cwhub/cwhub_test.go b/pkg/cwhub/cwhub_test.go index f91b0dcedff..17e7a0dc723 100644 --- a/pkg/cwhub/cwhub_test.go +++ b/pkg/cwhub/cwhub_test.go @@ -1,6 +1,7 @@ package cwhub import ( + "context" "fmt" "io" "net/http" @@ -9,10 +10,14 @@ import ( "strings" "testing" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) +const mockURLTemplate = "https://cdn-hub.crowdsec.net/crowdsecurity/%s/%s" + /* To test : - Download 'first' hub index @@ -24,360 +29,72 @@ import ( var responseByPath map[string]string -func TestItemStatus(t *testing.T) { - cfg := envSetup(t) - defer envTearDown(cfg) - - err := UpdateHubIdx(cfg.Hub) - //DownloadHubIdx() - if err != nil { - t.Fatalf("failed to download index : %s", err) - } - if err := GetHubIdx(cfg.Hub); err != nil { - t.Fatalf("failed to load hub index : %s", err) - } - - //get existing map - x := GetItemMap(COLLECTIONS) - if len(x) == 0 { - t.Fatalf("expected non empty result") - } - - //Get item : good and bad - for k := range x { - item := GetItem(COLLECTIONS, k) - if item == nil { - t.Fatalf("expected item") - } - item.Installed = true - item.UpToDate = false - item.Local = false - item.Tainted = false - txt, _, _, _ := ItemStatus(*item) - if txt != "enabled,update-available" { - t.Fatalf("got '%s'", txt) - } - - item.Installed = false - item.UpToDate = false - item.Local = true - item.Tainted = false - txt, _, _, _ = ItemStatus(*item) - if txt != "disabled,local" { - t.Fatalf("got '%s'", txt) - } - - break - } - DisplaySummary() -} - -func TestGetters(t *testing.T) { - cfg := envSetup(t) - defer envTearDown(cfg) - - err := UpdateHubIdx(cfg.Hub) - //DownloadHubIdx() - if err != nil { - t.Fatalf("failed to download index : %s", err) - } - if err := GetHubIdx(cfg.Hub); err != nil { - t.Fatalf("failed to load hub index : %s", err) - } - - //get non existing map - empty := GetItemMap("ratata") - if empty != nil { - t.Fatalf("expected nil result") - } - //get existing map - x := GetItemMap(COLLECTIONS) - if len(x) == 0 { - t.Fatalf("expected non empty result") - } +// testHub initializes a temporary hub with an empty json file, optionally updating it. 
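
The test rewrites in this patch follow two testify conventions: require.* aborts the test on the spot where assert.* records the failure and keeps going, and both take the expected value before the actual one. A minimal sketch; fetch is a made-up stand-in for a call like cticlient.Fire:

package demo

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// fetch stands in for a fallible call whose result the test inspects.
func fetch() ([]string, error) {
	return []string{"1.2.3.4", "2.2.3.4", "3.2.3.4"}, nil
}

func TestRequireVsAssert(t *testing.T) {
	data, err := fetch()

	// require.* is fatal: if the call failed, nothing below is meaningful.
	require.NoError(t, err)

	// assert.* is non-fatal: both checks run, and both failures get reported.
	assert.Len(t, data, 3)
	assert.Equal(t, "1.2.3.4", data[0]) // testify takes the expected value first
}
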
+func testHub(t *testing.T, update bool) *Hub { + tmpDir, err := os.MkdirTemp("", "testhub") + require.NoError(t, err) - //Get item : good and bad - for k := range x { - empty := GetItem(COLLECTIONS, k+"nope") - if empty != nil { - t.Fatalf("expected empty item") - } - - item := GetItem(COLLECTIONS, k) - if item == nil { - t.Fatalf("expected non empty item") - } - - //Add item and get it - item.Name += "nope" - if err := AddItem(COLLECTIONS, *item); err != nil { - t.Fatalf("didn't expect error : %s", err) - } - - newitem := GetItem(COLLECTIONS, item.Name) - if newitem == nil { - t.Fatalf("expected non empty item") - } - - //Add bad item - if err := AddItem("ratata", *item); err != nil { - if fmt.Sprintf("%s", err) != "ItemType ratata is unknown" { - t.Fatalf("unexpected error") - } - } else { - t.Fatalf("Expected error") - } - - break + local := &csconfig.LocalHubCfg{ + HubDir: filepath.Join(tmpDir, "crowdsec", "hub"), + HubIndexFile: filepath.Join(tmpDir, "crowdsec", "hub", ".index.json"), + InstallDir: filepath.Join(tmpDir, "crowdsec"), + InstallDataDir: filepath.Join(tmpDir, "installed-data"), } -} - -func TestIndexDownload(t *testing.T) { - cfg := envSetup(t) - defer envTearDown(cfg) + err = os.MkdirAll(local.HubDir, 0o700) + require.NoError(t, err) - err := UpdateHubIdx(cfg.Hub) - //DownloadHubIdx() - if err != nil { - t.Fatalf("failed to download index : %s", err) - } - if err := GetHubIdx(cfg.Hub); err != nil { - t.Fatalf("failed to load hub index : %s", err) - } -} + err = os.MkdirAll(local.InstallDir, 0o700) + require.NoError(t, err) -func getTestCfg() (cfg *csconfig.Config) { - cfg = &csconfig.Config{Hub: &csconfig.Hub{}} - cfg.Hub.ConfigDir, _ = filepath.Abs("./install") - cfg.Hub.HubDir, _ = filepath.Abs("./hubdir") - cfg.Hub.HubIndexFile = filepath.Clean("./hubdir/.index.json") - return -} + err = os.MkdirAll(local.InstallDataDir, 0o700) + require.NoError(t, err) -func envSetup(t *testing.T) *csconfig.Config { - resetResponseByPath() - log.SetLevel(log.DebugLevel) - cfg := getTestCfg() + err = os.WriteFile(local.HubIndexFile, []byte("{}"), 0o644) + require.NoError(t, err) - defaultTransport := http.DefaultClient.Transport t.Cleanup(func() { - http.DefaultClient.Transport = defaultTransport + os.RemoveAll(tmpDir) }) - //Mock the http client - http.DefaultClient.Transport = newMockTransport() - - if err := os.MkdirAll(cfg.Hub.ConfigDir, 0700); err != nil { - log.Fatalf("mkdir : %s", err) + remote := &RemoteHubCfg{ + Branch: "master", + URLTemplate: mockURLTemplate, + IndexPath: ".index.json", } - if err := os.MkdirAll(cfg.Hub.HubDir, 0700); err != nil { - log.Fatalf("failed to mkdir %s : %s", cfg.Hub.HubDir, err) - } - - if err := UpdateHubIdx(cfg.Hub); err != nil { - log.Fatalf("failed to download index : %s", err) - } - - // if err := os.RemoveAll(cfg.Hub.InstallDir); err != nil { - // log.Fatalf("failed to remove %s : %s", cfg.Hub.InstallDir, err) - // } - // if err := os.MkdirAll(cfg.Hub.InstallDir, 0700); err != nil { - // log.Fatalf("failed to mkdir %s : %s", cfg.Hub.InstallDir, err) - // } - return cfg -} - + hub, err := NewHub(local, remote, log.StandardLogger()) + require.NoError(t, err) -func envTearDown(cfg *csconfig.Config) { - if err := os.RemoveAll(cfg.Hub.ConfigDir); err != nil { - log.Fatalf("failed to remove %s : %s", cfg.Hub.ConfigDir, err) + if update { + ctx := context.Background() + err := hub.Update(ctx) + require.NoError(t, err) } - if err := os.RemoveAll(cfg.Hub.HubDir); err != nil { - log.Fatalf("failed to remove %s : %s", cfg.Hub.HubDir, err) - } -} - - 
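
testHub above builds its directory tree under os.MkdirTemp and registers removal with t.Cleanup. The stdlib's t.TempDir collapses those two steps into one call; a sketch of the same fixture shape under that assumption (makeHubLayout and its layout are illustrative):

package demo

import (
	"os"
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/require"
)

// makeHubLayout shows the testHub fixture shape with t.TempDir(), which
// creates the directory and registers its removal automatically.
func makeHubLayout(t *testing.T) string {
	t.Helper()

	tmpDir := t.TempDir() // removed when the test (and its subtests) finish

	hubDir := filepath.Join(tmpDir, "crowdsec", "hub")
	require.NoError(t, os.MkdirAll(hubDir, 0o700))

	// An empty JSON object is enough for a loadable, empty index.
	indexFile := filepath.Join(hubDir, ".index.json")
	require.NoError(t, os.WriteFile(indexFile, []byte("{}"), 0o644))

	return tmpDir
}
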
-func testInstallItem(cfg *csconfig.Hub, t *testing.T, item Item) { - - //Install the parser - item, err := DownloadLatest(cfg, item, false, false) - if err != nil { - t.Fatalf("error while downloading %s : %v", item.Name, err) - } - if err, _ := LocalSync(cfg); err != nil { - t.Fatalf("taint: failed to run localSync : %s", err) - } - if !hubIdx[item.Type][item.Name].UpToDate { - t.Fatalf("download: %s should be up-to-date", item.Name) - } - if hubIdx[item.Type][item.Name].Installed { - t.Fatalf("download: %s should not be installed", item.Name) - } - if hubIdx[item.Type][item.Name].Tainted { - t.Fatalf("download: %s should not be tainted", item.Name) - } + err = hub.Load() + require.NoError(t, err) - item, err = EnableItem(cfg, item) - if err != nil { - t.Fatalf("error while enabling %s : %v.", item.Name, err) - } - if err, _ := LocalSync(cfg); err != nil { - t.Fatalf("taint: failed to run localSync : %s", err) - } - if !hubIdx[item.Type][item.Name].Installed { - t.Fatalf("install: %s should be installed", item.Name) - } + return hub } -func testTaintItem(cfg *csconfig.Hub, t *testing.T, item Item) { - if hubIdx[item.Type][item.Name].Tainted { - t.Fatalf("pre-taint: %s should not be tainted", item.Name) - } - f, err := os.OpenFile(item.LocalPath, os.O_APPEND|os.O_WRONLY, 0600) - if err != nil { - t.Fatalf("(taint) opening %s (%s) : %s", item.LocalPath, item.Name, err) - } - defer f.Close() - - if _, err = f.WriteString("tainted"); err != nil { - t.Fatalf("tainting %s : %s", item.Name, err) - } - //Local sync and check status - if err, _ := LocalSync(cfg); err != nil { - t.Fatalf("taint: failed to run localSync : %s", err) - } - if !hubIdx[item.Type][item.Name].Tainted { - t.Fatalf("taint: %s should be tainted", item.Name) - } -} +// envSetup initializes the temporary hub and mocks the http client. 
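
The envSetup helper that follows saves hubClient's Transport and restores it in t.Cleanup before installing the mock. That save-and-restore shape generalizes to any package-level client; a sketch (withMockTransport is illustrative, not part of the patch):

package demo

import (
	"net/http"
	"testing"
)

// withMockTransport swaps a client's RoundTripper for the duration of a
// test and restores the original afterwards, the same shape as envSetup.
func withMockTransport(t *testing.T, client *http.Client, rt http.RoundTripper) {
	t.Helper()

	saved := client.Transport
	t.Cleanup(func() { client.Transport = saved })

	client.Transport = rt
}
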
+func envSetup(t *testing.T) *Hub { + setResponseByPath() + log.SetLevel(log.DebugLevel) -func testUpdateItem(cfg *csconfig.Hub, t *testing.T, item Item) { + defaultTransport := hubClient.Transport - if hubIdx[item.Type][item.Name].UpToDate { - t.Fatalf("update: %s should NOT be up-to-date", item.Name) - } - //Update it + check status - item, err := DownloadLatest(cfg, item, true, true) - if err != nil { - t.Fatalf("failed to update %s : %s", item.Name, err) - } - //Local sync and check status - if err, _ := LocalSync(cfg); err != nil { - t.Fatalf("failed to run localSync : %s", err) - } - if !hubIdx[item.Type][item.Name].UpToDate { - t.Fatalf("update: %s should be up-to-date", item.Name) - } - if hubIdx[item.Type][item.Name].Tainted { - t.Fatalf("update: %s should not be tainted anymore", item.Name) - } -} + t.Cleanup(func() { + hubClient.Transport = defaultTransport + }) -func testDisableItem(cfg *csconfig.Hub, t *testing.T, item Item) { - if !item.Installed { - t.Fatalf("disable: %s should be installed", item.Name) - } - //Remove - item, err := DisableItem(cfg, item, false, false) - if err != nil { - t.Fatalf("failed to disable item : %v", err) - } - //Local sync and check status - if err, warns := LocalSync(cfg); err != nil || len(warns) > 0 { - t.Fatalf("failed to run localSync : %s (%+v)", err, warns) - } - if hubIdx[item.Type][item.Name].Tainted { - t.Fatalf("disable: %s should not be tainted anymore", item.Name) - } - if hubIdx[item.Type][item.Name].Installed { - t.Fatalf("disable: %s should not be installed anymore", item.Name) - } - if !hubIdx[item.Type][item.Name].Downloaded { - t.Fatalf("disable: %s should still be downloaded", item.Name) - } - //Purge - item, err = DisableItem(cfg, item, true, false) - if err != nil { - t.Fatalf("failed to purge item : %v", err) - } - //Local sync and check status - if err, warns := LocalSync(cfg); err != nil || len(warns) > 0 { - t.Fatalf("failed to run localSync : %s (%+v)", err, warns) - } - if hubIdx[item.Type][item.Name].Installed { - t.Fatalf("disable: %s should not be installed anymore", item.Name) - } - if hubIdx[item.Type][item.Name].Downloaded { - t.Fatalf("disable: %s should not be downloaded", item.Name) - } -} + // Mock the http client + hubClient.Transport = newMockTransport() -func TestInstallParser(t *testing.T) { - - /* - - install a random parser - - check its status - - taint it - - check its status - - force update it - - check its status - - remove it - */ - cfg := envSetup(t) - defer envTearDown(cfg) - - getHubIdxOrFail(t) - //map iteration is random by itself - for _, it := range hubIdx[PARSERS] { - testInstallItem(cfg.Hub, t, it) - it = hubIdx[PARSERS][it.Name] - _ = GetHubStatusForItemType(PARSERS, it.Name, false) - testTaintItem(cfg.Hub, t, it) - it = hubIdx[PARSERS][it.Name] - _ = GetHubStatusForItemType(PARSERS, it.Name, false) - testUpdateItem(cfg.Hub, t, it) - it = hubIdx[PARSERS][it.Name] - testDisableItem(cfg.Hub, t, it) - it = hubIdx[PARSERS][it.Name] - - break - } -} + hub := testHub(t, true) -func TestInstallCollection(t *testing.T) { - - /* - - install a random parser - - check its status - - taint it - - check its status - - force update it - - check its status - - remove it - */ - cfg := envSetup(t) - defer envTearDown(cfg) - - getHubIdxOrFail(t) - //map iteration is random by itself - for _, it := range hubIdx[COLLECTIONS] { - testInstallItem(cfg.Hub, t, it) - it = hubIdx[COLLECTIONS][it.Name] - testTaintItem(cfg.Hub, t, it) - it = hubIdx[COLLECTIONS][it.Name] - testUpdateItem(cfg.Hub, t, it) - it = 
hubIdx[COLLECTIONS][it.Name] - testDisableItem(cfg.Hub, t, it) - - it = hubIdx[COLLECTIONS][it.Name] - x := GetHubStatusForItemType(COLLECTIONS, it.Name, false) - log.Printf("%+v", x) - break - } + return hub } type mockTransport struct{} @@ -386,7 +103,7 @@ func newMockTransport() http.RoundTripper { return &mockTransport{} } -// Implement http.RoundTripper +// Implement http.RoundTripper. func (t *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { // Create mocked http.Response response := &http.Response{ @@ -395,47 +112,49 @@ func (t *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { StatusCode: http.StatusOK, } response.Header.Set("Content-Type", "application/json") - responseBody := "" - log.Printf("---> %s", req.URL.Path) - - /*FAKE PARSER*/ - if resp, ok := responseByPath[req.URL.Path]; ok { - responseBody = resp - } else { - log.Fatalf("unexpected url :/ %s", req.URL.Path) + + log.Infof("---> %s", req.URL.Path) + + // FAKE PARSER + resp, ok := responseByPath[req.URL.Path] + if !ok { + return nil, fmt.Errorf("unexpected url: %s", req.URL.Path) } - response.Body = io.NopCloser(strings.NewReader(responseBody)) + response.Body = io.NopCloser(strings.NewReader(resp)) + return response, nil } func fileToStringX(path string) string { - if f, err := os.Open(path); err == nil { - defer f.Close() - if data, err := io.ReadAll(f); err == nil { - return strings.ReplaceAll(string(data), "\r\n", "\n") - } else { - panic(err) - } - } else { + f, err := os.Open(path) + if err != nil { panic(err) } + defer f.Close() + + data, err := io.ReadAll(f) + if err != nil { + panic(err) + } + + return strings.ReplaceAll(string(data), "\r\n", "\n") } -func resetResponseByPath() { +func setResponseByPath() { responseByPath = map[string]string{ - "/master/parsers/s01-parse/crowdsecurity/foobar_parser.yaml": fileToStringX("./tests/foobar_parser.yaml"), - "/master/parsers/s01-parse/crowdsecurity/foobar_subparser.yaml": fileToStringX("./tests/foobar_parser.yaml"), - "/master/collections/crowdsecurity/test_collection.yaml": fileToStringX("./tests/collection_v1.yaml"), - "/master/.index.json": fileToStringX("./tests/index1.json"), - "/master/scenarios/crowdsecurity/foobar_scenario.yaml": `filter: true + "/crowdsecurity/master/parsers/s01-parse/crowdsecurity/foobar_parser.yaml": fileToStringX("./testdata/foobar_parser.yaml"), + "/crowdsecurity/master/parsers/s01-parse/crowdsecurity/foobar_subparser.yaml": fileToStringX("./testdata/foobar_parser.yaml"), + "/crowdsecurity/master/collections/crowdsecurity/test_collection.yaml": fileToStringX("./testdata/collection_v1.yaml"), + "/crowdsecurity/master/.index.json": fileToStringX("./testdata/index1.json"), + "/crowdsecurity/master/scenarios/crowdsecurity/foobar_scenario.yaml": `filter: true name: crowdsecurity/foobar_scenario`, - "/master/scenarios/crowdsecurity/barfoo_scenario.yaml": `filter: true + "/crowdsecurity/master/scenarios/crowdsecurity/barfoo_scenario.yaml": `filter: true name: crowdsecurity/foobar_scenario`, - "/master/collections/crowdsecurity/foobar_subcollection.yaml": ` + "/crowdsecurity/master/collections/crowdsecurity/foobar_subcollection.yaml": ` blah: blalala qwe: jejwejejw`, - "/master/collections/crowdsecurity/foobar.yaml": ` + "/crowdsecurity/master/collections/crowdsecurity/foobar.yaml": ` blah: blalala qwe: jejwejejw`, } diff --git a/pkg/cwhub/dataset.go b/pkg/cwhub/dataset.go index 848686be69d..90bc9e057f9 100644 --- a/pkg/cwhub/dataset.go +++ b/pkg/cwhub/dataset.go @@ -1,68 +1,70 @@ package cwhub import ( + 
"context" + "errors" "fmt" "io" - "net/http" - "os" - "path" + "time" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/downloader" "github.com/crowdsecurity/crowdsec/pkg/types" ) +// The DataSet is a list of data sources required by an item (built from the data: section in the yaml). type DataSet struct { - Data []*types.DataSource `yaml:"data,omitempty"` + Data []types.DataSource `yaml:"data,omitempty"` } -func downloadFile(url string, destPath string) error { - log.Debugf("downloading %s in %s", url, destPath) - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return err - } +// downloadDataSet downloads all the data files for an item. +func downloadDataSet(ctx context.Context, dataFolder string, force bool, reader io.Reader, logger *logrus.Logger) error { + dec := yaml.NewDecoder(reader) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() + for { + data := &DataSet{} - body, err := io.ReadAll(resp.Body) - if err != nil { - return err - } + if err := dec.Decode(data); err != nil { + if errors.Is(err, io.EOF) { + break + } - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("download response 'HTTP %d' : %s", resp.StatusCode, string(body)) - } + return fmt.Errorf("while reading file: %w", err) + } - file, err := os.OpenFile(destPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - return err - } + for _, dataS := range data.Data { + destPath, err := safePath(dataFolder, dataS.DestPath) + if err != nil { + return err + } - _, err = file.WriteString(string(body)) - if err != nil { - return err - } + d := downloader. + New(). + WithHTTPClient(hubClient). + ToFile(destPath). + CompareContent(). + WithLogger(logrus.WithField("url", dataS.SourceURL)) - err = file.Sync() - if err != nil { - return err - } + if !force { + d = d.WithLastModified(). 
+ WithShelfLife(7 * 24 * time.Hour) + } - return nil -} + downloaded, err := d.Download(ctx, dataS.SourceURL) + if err != nil { + return fmt.Errorf("while getting data: %w", err) + } -func GetData(data []*types.DataSource, dataDir string) error { - for _, dataS := range data { - destPath := path.Join(dataDir, dataS.DestPath) - log.Infof("downloading data '%s' in '%s'", dataS.SourceURL, destPath) - err := downloadFile(dataS.SourceURL, destPath) - if err != nil { - return err + if downloaded { + logger.Infof("Downloaded %s", destPath) + // a check on stdout is used while scripting to know if the hub has been upgraded + // and a configuration reload is required + // TODO: use a better way to communicate this + fmt.Printf("updated %s\n", destPath) + } } } diff --git a/pkg/cwhub/dataset_test.go b/pkg/cwhub/dataset_test.go deleted file mode 100644 index 106268c01b6..00000000000 --- a/pkg/cwhub/dataset_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package cwhub - -import ( - "os" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/jarcoal/httpmock" -) - -func TestDownloadFile(t *testing.T) { - examplePath := "./example.txt" - defer os.Remove(examplePath) - - httpmock.Activate() - defer httpmock.DeactivateAndReset() - //OK - httpmock.RegisterResponder( - "GET", - "https://example.com/xx", - httpmock.NewStringResponder(200, "example content oneoneone"), - ) - httpmock.RegisterResponder( - "GET", - "https://example.com/x", - httpmock.NewStringResponder(404, "not found"), - ) - err := downloadFile("https://example.com/xx", examplePath) - assert.NoError(t, err) - content, err := os.ReadFile(examplePath) - assert.Equal(t, "example content oneoneone", string(content)) - assert.NoError(t, err) - //bad uri - err = downloadFile("https://zz.com", examplePath) - assert.Error(t, err) - //404 - err = downloadFile("https://example.com/x", examplePath) - assert.Error(t, err) - //bad target - err = downloadFile("https://example.com/xx", "") - assert.Error(t, err) -} diff --git a/pkg/cwhub/doc.go b/pkg/cwhub/doc.go new file mode 100644 index 00000000000..f86b95c6454 --- /dev/null +++ b/pkg/cwhub/doc.go @@ -0,0 +1,126 @@ +// Package cwhub is responsible for installing and upgrading the local hub files for CrowdSec. +// +// # Definitions +// +// - A hub ITEM is a file that defines a parser, a scenario, a collection... in the case of a collection, it has dependencies on other hub items. +// - The hub INDEX is a JSON file that contains a tree of available hub items. +// - A REMOTE HUB is an HTTP server that hosts the hub index and the hub items. It can serve from several branches, usually linked to the CrowdSec version. +// - A LOCAL HUB is a directory that contains a copy of the hub index and the downloaded hub items. +// +// Once downloaded, hub items can be installed by linking to them from the configuration directory. +// If an item is present in the configuration directory but it's not a link to the local hub, it is +// considered as a LOCAL ITEM and won't be removed or upgraded. 
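
downloadDataSet above reads the item file as a stream of YAML documents, decoding until io.EOF. A standalone sketch of that loop; the dataSet struct and the two-document input are illustrative, not the patch's types:

package main

import (
	"errors"
	"fmt"
	"io"
	"strings"

	"gopkg.in/yaml.v3"
)

type dataSet struct {
	Data []struct {
		SourceURL string `yaml:"source_url"`
		DestPath  string `yaml:"dest_path"`
	} `yaml:"data"`
}

func main() {
	// Two YAML documents in one stream, as a hub item file may contain.
	input := `data:
  - source_url: https://example.com/one.txt
    dest_path: one.txt
---
data:
  - source_url: https://example.com/two.txt
    dest_path: two.txt
`

	dec := yaml.NewDecoder(strings.NewReader(input))

	for {
		var ds dataSet
		if err := dec.Decode(&ds); err != nil {
			if errors.Is(err, io.EOF) {
				break // end of stream, not a failure
			}
			panic(err)
		}

		for _, d := range ds.Data {
			fmt.Println(d.SourceURL, "->", d.DestPath)
		}
	}
}
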
+//
+// # Directory Structure
+//
+// A typical directory layout is the following:
+//
+// For the local hub (HubDir = /etc/crowdsec/hub):
+//
+//   - /etc/crowdsec/hub/.index.json
+//   - /etc/crowdsec/hub/parsers/{stage}/{author}/{parser-name}.yaml
+//   - /etc/crowdsec/hub/scenarios/{author}/{scenario-name}.yaml
+//
+// For the configuration directory (InstallDir = /etc/crowdsec):
+//
+//   - /etc/crowdsec/parsers/{stage}/{parser-name}.yaml -> /etc/crowdsec/hub/parsers/{stage}/{author}/{parser-name}.yaml
+//   - /etc/crowdsec/scenarios/{scenario-name}.yaml -> /etc/crowdsec/hub/scenarios/{author}/{scenario-name}.yaml
+//   - /etc/crowdsec/scenarios/local-scenario.yaml
+//
+// Note that installed items are not grouped by author; this may change in the future if we want to
+// support items with the same name from different authors.
+//
+// Only parsers and postoverflows have the concept of stage.
+//
+// Additionally, an item can reference a DATA SET that is installed in a different location than
+// the item itself. These files are stored in the data directory (InstallDataDir = /var/lib/crowdsec/data).
+//
+//   - /var/lib/crowdsec/data/http_path_traversal.txt
+//   - /var/lib/crowdsec/data/jira_cve_2021-26086.txt
+//   - /var/lib/crowdsec/data/log4j2_cve_2021_44228.txt
+//   - /var/lib/crowdsec/data/sensitive_data.txt
+//
+// # Using the package
+//
+// The main entry point is the Hub struct. You can create a new instance with NewHub().
+// This constructor takes three parameters, but only the LOCAL HUB configuration is required:
+//
+//	import (
+//		"fmt"
+//		"github.com/crowdsecurity/crowdsec/pkg/csconfig"
+//		"github.com/crowdsecurity/crowdsec/pkg/cwhub"
+//	)
+//
+//	localHub := csconfig.LocalHubCfg{
+//		HubIndexFile:   "/etc/crowdsec/hub/.index.json",
+//		HubDir:         "/etc/crowdsec/hub",
+//		InstallDir:     "/etc/crowdsec",
+//		InstallDataDir: "/var/lib/crowdsec/data",
+//	}
+//
+//	hub, err := cwhub.NewHub(&localHub, nil, logger)
+//	if err != nil {
+//		return fmt.Errorf("unable to initialize hub: %w", err)
+//	}
+//
+// If the logger is nil, the item-by-item messages will be discarded, including warnings.
+// After configuring the hub, you must sync its state with items on disk.
+//
+//	err := hub.Load()
+//	if err != nil {
+//		return fmt.Errorf("unable to load hub: %w", err)
+//	}
+//
+// Now you can use the hub object to access the existing items:
+//
+//	// list all the parsers
+//	for _, parser := range hub.GetItemsByType(cwhub.PARSERS, false) {
+//		fmt.Printf("parser: %s\n", parser.Name)
+//	}
+//
+//	// retrieve a specific collection
+//	coll := hub.GetItem(cwhub.COLLECTIONS, "crowdsecurity/linux")
+//	if coll == nil {
+//		return fmt.Errorf("collection not found")
+//	}
+//
+// You can also install items if they have already been downloaded:
+//
+//	// install a parser
+//	force := false
+//	downloadOnly := false
+//	err := parser.Install(force, downloadOnly)
+//	if err != nil {
+//		return fmt.Errorf("unable to install parser: %w", err)
+//	}
+//
+// As soon as you try to install an item that is not downloaded or is not up-to-date (meaning its computed hash
+// does not correspond to the latest version available in the index), a download will be attempted and you'll
+// get the error "remote hub configuration is not provided".
+//
+// To provide the remote hub configuration, use the second parameter of NewHub():
+//
+//	remoteHub := cwhub.RemoteHubCfg{
+//		URLTemplate: "https://cdn-hub.crowdsec.net/crowdsecurity/%s/%s",
+//		Branch: "master",
+//		IndexPath: ".index.json",
+//	}
+//
+//	hub, err := cwhub.NewHub(&localHub, &remoteHub, logger)
+//	if err != nil {
+//		return fmt.Errorf("unable to initialize hub: %w", err)
+//	}
+//
+// The URLTemplate is a string that will be used to build the URL of the remote hub. It must contain two
+// placeholders: the branch and the file path (it will be an index or an item).
+//
+// Before calling hub.Load(), you can update the index file by calling the Update() method:
+//
+//	err := hub.Update(context.Background())
+//	if err != nil {
+//		return fmt.Errorf("unable to update hub index: %w", err)
+//	}
+//
+// Note that the command will fail if the hub has already been synced. If you need to update it again
+// (e.g. after a configuration change, when the application is notified with SIGHUP), you have to
+// instantiate a new hub object and dispose of the old one.
+package cwhub
diff --git a/pkg/cwhub/download.go b/pkg/cwhub/download.go
deleted file mode 100644
index 0ba74c7720d..00000000000
--- a/pkg/cwhub/download.go
+++ /dev/null
@@ -1,281 +0,0 @@
-package cwhub
-
-import (
-	"bytes"
-	"crypto/sha256"
-	"errors"
-	"fmt"
-	"io"
-	"net/http"
-	"os"
-	"path"
-	"path/filepath"
-	"strings"
-
-	log "github.com/sirupsen/logrus"
-	"gopkg.in/yaml.v2"
-
-	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
-)
-
-var ErrIndexNotFound = fmt.Errorf("index not found")
-
-func UpdateHubIdx(hub *csconfig.Hub) error {
-	bidx, err := DownloadHubIdx(hub)
-	if err != nil {
-		return fmt.Errorf("failed to download index: %w", err)
-	}
-	ret, err := LoadPkgIndex(bidx)
-	if err != nil {
-		if !errors.Is(err, ReferenceMissingError) {
-			return fmt.Errorf("failed to read index: %w", err)
-		}
-	}
-	hubIdx = ret
-	if err, _ := LocalSync(hub); err != nil {
-		return fmt.Errorf("failed to sync: %w", err)
-	}
-	return nil
-}
-
-func DownloadHubIdx(hub *csconfig.Hub) ([]byte, error) {
-	log.Debugf("fetching index from branch %s (%s)", HubBranch, fmt.Sprintf(RawFileURLTemplate, HubBranch, HubIndexFile))
-	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf(RawFileURLTemplate, HubBranch, HubIndexFile), nil)
-	if err != nil {
-		return nil, fmt.Errorf("failed to build request for hub index: %w", err)
-	}
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return nil, fmt.Errorf("failed http request for hub index: %w", err)
-	}
-	defer resp.Body.Close()
-	if resp.StatusCode != http.StatusOK {
-		if resp.StatusCode == http.StatusNotFound {
-			return nil, ErrIndexNotFound
-		}
-		return nil, fmt.Errorf("bad http code %d while requesting %s", resp.StatusCode, req.URL.String())
-	}
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("failed to read request answer for hub index: %w", err)
-	}
-
-	oldContent, err := os.ReadFile(hub.HubIndexFile)
-	if err != nil {
-		if !os.IsNotExist(err) {
-			log.Warningf("failed to read hub index: %s", err)
-		}
-	} else if bytes.Equal(body, oldContent) {
-		log.Info("hub index is up to date")
-		// write it anyway, can't hurt
-	}
-
-	file, err := os.OpenFile(hub.HubIndexFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
-
-	if err != nil {
-		return nil, fmt.Errorf("while opening hub index file: %w", err)
-	}
-	defer file.Close()
-
-	wsize, err := file.WriteString(string(body))
-	if err != nil {
-		return nil, fmt.Errorf("while writing hub index file: %w", err)
-	}
-	log.Infof("Wrote new %d bytes index to %s", wsize, hub.HubIndexFile)
-	return body, nil
-}
-
-// DownloadLatest will download the latest version of Item to the tdir directory
-func DownloadLatest(hub *csconfig.Hub, target Item, overwrite bool, updateOnly bool) (Item, error) {
-	var err error
-
-	log.Debugf("Downloading %s %s", target.Type, target.Name)
-	if target.Type != COLLECTIONS {
-		if !target.Installed && updateOnly && target.Downloaded {
-			log.Debugf("skipping upgrade of %s : not installed", target.Name)
-			return target, nil
-		}
-		return DownloadItem(hub, target, overwrite)
-	}
-
-	// collection
-	var tmp = [][]string{target.Parsers, target.PostOverflows, target.Scenarios, target.Collections}
-	for idx, ptr := range tmp {
-		ptrtype := ItemTypes[idx]
-		for _, p := range ptr {
-			val, ok := hubIdx[ptrtype][p]
-			if !ok {
-				return target, fmt.Errorf("required %s %s of %s doesn't exist, abort", ptrtype, p, target.Name)
-			}
-
-			if !val.Installed && updateOnly && val.Downloaded {
-				log.Debugf("skipping upgrade of %s : not installed", target.Name)
-				continue
-			}
-
-			log.Debugf("Download %s sub-item : %s %s (%t -> %t)", target.Name, ptrtype, p, target.Installed, updateOnly)
-			//recurse as it's a collection
-			if ptrtype == COLLECTIONS {
-				log.Tracef("collection, recurse")
-				hubIdx[ptrtype][p], err = DownloadLatest(hub, val, overwrite, updateOnly)
-				if err != nil {
-					return target, fmt.Errorf("while downloading %s: %w", val.Name, err)
-				}
-			}
-			item, err := DownloadItem(hub, val, overwrite)
-			if err != nil {
-				return target, fmt.Errorf("while downloading %s: %w", val.Name, err)
-			}
-
-			// We need to enable an item when it has been added to a collection since latest release of the collection.
-			// We check if val.Downloaded is false because maybe the item has been disabled by the user.
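
The deleted DownloadHubIdx compared the fresh index against the on-disk copy with bytes.Equal before writing. A sketch of the skip-identical-writes variant that comparison enables; writeIfChanged is illustrative, and returning early is this sketch's choice, not what the deleted code did (it logged and wrote anyway):

package demo

import (
	"bytes"
	"os"
)

// writeIfChanged writes body to path only when the on-disk content differs,
// reporting whether a write happened.
func writeIfChanged(path string, body []byte) (bool, error) {
	old, err := os.ReadFile(path)
	if err == nil && bytes.Equal(old, body) {
		return false, nil // identical content, nothing to do
	}

	if err != nil && !os.IsNotExist(err) {
		return false, err // a real read error, not just a missing file
	}

	return true, os.WriteFile(path, body, 0o644)
}
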
- if !item.Installed && !val.Downloaded { - if item, err = EnableItem(hub, item); err != nil { - return target, fmt.Errorf("enabling '%s': %w", item.Name, err) - } - } - hubIdx[ptrtype][p] = item - } - } - target, err = DownloadItem(hub, target, overwrite) - if err != nil { - return target, fmt.Errorf("failed to download item : %s", err) - } - return target, nil -} - -func DownloadItem(hub *csconfig.Hub, target Item, overwrite bool) (Item, error) { - var tdir = hub.HubDir - var dataFolder = hub.DataDir - /*if user didn't --force, don't overwrite local, tainted, up-to-date files*/ - if !overwrite { - if target.Tainted { - log.Debugf("%s : tainted, not updated", target.Name) - return target, nil - } - if target.UpToDate { - log.Debugf("%s : up-to-date, not updated", target.Name) - // We still have to check if data files are present - } - } - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf(RawFileURLTemplate, HubBranch, target.RemotePath), nil) - if err != nil { - return target, fmt.Errorf("while downloading %s: %w", req.URL.String(), err) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return target, fmt.Errorf("while downloading %s: %w", req.URL.String(), err) - } - if resp.StatusCode != http.StatusOK { - return target, fmt.Errorf("bad http code %d for %s", resp.StatusCode, req.URL.String()) - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return target, fmt.Errorf("while reading %s: %w", req.URL.String(), err) - } - h := sha256.New() - if _, err := h.Write(body); err != nil { - return target, fmt.Errorf("while hashing %s: %w", target.Name, err) - } - meow := fmt.Sprintf("%x", h.Sum(nil)) - if meow != target.Versions[target.Version].Digest { - log.Errorf("Downloaded version doesn't match index, please 'hub update'") - log.Debugf("got %s, expected %s", meow, target.Versions[target.Version].Digest) - return target, fmt.Errorf("invalid download hash for %s", target.Name) - } - //all good, install - //check if parent dir exists - tmpdirs := strings.Split(tdir+"/"+target.RemotePath, "/") - parent_dir := strings.Join(tmpdirs[:len(tmpdirs)-1], "/") - - /*ensure that target file is within target dir*/ - finalPath, err := filepath.Abs(tdir + "/" + target.RemotePath) - if err != nil { - return target, fmt.Errorf("filepath.Abs error on %s: %w", tdir+"/"+target.RemotePath, err) - } - if !strings.HasPrefix(finalPath, tdir) { - return target, fmt.Errorf("path %s escapes %s, abort", target.RemotePath, tdir) - } - /*check dir*/ - if _, err = os.Stat(parent_dir); os.IsNotExist(err) { - log.Debugf("%s doesn't exist, create", parent_dir) - if err := os.MkdirAll(parent_dir, os.ModePerm); err != nil { - return target, fmt.Errorf("while creating parent directories: %w", err) - } - } - /*check actual file*/ - if _, err = os.Stat(finalPath); !os.IsNotExist(err) { - log.Warningf("%s : overwrite", target.Name) - log.Debugf("target: %s/%s", tdir, target.RemotePath) - } else { - log.Infof("%s : OK", target.Name) - } - - f, err := os.OpenFile(tdir+"/"+target.RemotePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return target, fmt.Errorf("while opening file: %w", err) - } - defer f.Close() - _, err = f.WriteString(string(body)) - if err != nil { - return target, fmt.Errorf("while writing file: %w", err) - } - target.Downloaded = true - target.Tainted = false - target.UpToDate = true - - if err = downloadData(dataFolder, overwrite, bytes.NewReader(body)); err != nil { - return target, fmt.Errorf("while downloading data for %s: %w", 
target.FileName, err) - } - - hubIdx[target.Type][target.Name] = target - return target, nil -} - -func DownloadDataIfNeeded(hub *csconfig.Hub, target Item, force bool) error { - var ( - dataFolder = hub.DataDir - itemFile *os.File - err error - ) - itemFilePath := fmt.Sprintf("%s/%s/%s/%s", hub.ConfigDir, target.Type, target.Stage, target.FileName) - if itemFile, err = os.Open(itemFilePath); err != nil { - return fmt.Errorf("while opening %s: %w", itemFilePath, err) - } - defer itemFile.Close() - if err = downloadData(dataFolder, force, itemFile); err != nil { - return fmt.Errorf("while downloading data for %s: %w", itemFilePath, err) - } - return nil -} - -func downloadData(dataFolder string, force bool, reader io.Reader) error { - var err error - dec := yaml.NewDecoder(reader) - - for { - data := &DataSet{} - err = dec.Decode(data) - if err != nil { - if errors.Is(err, io.EOF) { - break - } - return fmt.Errorf("while reading file: %w", err) - } - - download := false - for _, dataS := range data.Data { - if _, err := os.Stat(path.Join(dataFolder, dataS.DestPath)); os.IsNotExist(err) { - download = true - } - } - if download || force { - err = GetData(data.Data, dataFolder) - if err != nil { - return fmt.Errorf("while getting data: %w", err) - } - } - } - return nil -} diff --git a/pkg/cwhub/download_test.go b/pkg/cwhub/download_test.go deleted file mode 100644 index 156c4132279..00000000000 --- a/pkg/cwhub/download_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package cwhub - -import ( - "fmt" - "strings" - "testing" - - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - log "github.com/sirupsen/logrus" -) - -func TestDownloadHubIdx(t *testing.T) { - back := RawFileURLTemplate - //bad url template - fmt.Println("Test 'bad URL'") - RawFileURLTemplate = "x" - ret, err := DownloadHubIdx(&csconfig.Hub{}) - if err == nil || !strings.HasPrefix(fmt.Sprintf("%s", err), "failed to build request for hub index: parse ") { - log.Errorf("unexpected error %s", err) - } - fmt.Printf("->%+v", ret) - - //bad domain - fmt.Println("Test 'bad domain'") - RawFileURLTemplate = "https://baddomain/%s/%s" - ret, err = DownloadHubIdx(&csconfig.Hub{}) - if err == nil || !strings.HasPrefix(fmt.Sprintf("%s", err), "failed http request for hub index: Get") { - log.Errorf("unexpected error %s", err) - } - fmt.Printf("->%+v", ret) - - //bad target path - fmt.Println("Test 'bad target path'") - RawFileURLTemplate = back - ret, err = DownloadHubIdx(&csconfig.Hub{HubIndexFile: "/does/not/exist/index.json"}) - if err == nil || !strings.HasPrefix(fmt.Sprintf("%s", err), "while opening hub index file: open /does/not/exist/index.json:") { - log.Errorf("unexpected error %s", err) - } - - RawFileURLTemplate = back - fmt.Printf("->%+v", ret) -} diff --git a/pkg/cwhub/errors.go b/pkg/cwhub/errors.go new file mode 100644 index 00000000000..b0be444fcba --- /dev/null +++ b/pkg/cwhub/errors.go @@ -0,0 +1,19 @@ +package cwhub + +import ( + "errors" + "fmt" +) + +// ErrNilRemoteHub is returned when trying to download with a local-only configuration. +var ErrNilRemoteHub = errors.New("remote hub configuration is not provided. Please report this issue to the developers") + +// IndexNotFoundError is returned when the remote hub index is not found. +type IndexNotFoundError struct { + URL string + Branch string +} + +func (e IndexNotFoundError) Error() string { + return fmt.Sprintf("index not found at %s, branch '%s'. 
Please check the .cscli.hub_branch value if you specified it in config.yaml, or use 'master' if not sure", e.URL, e.Branch) +} diff --git a/pkg/cwhub/helpers.go b/pkg/cwhub/helpers.go deleted file mode 100644 index 4133e227250..00000000000 --- a/pkg/cwhub/helpers.go +++ /dev/null @@ -1,228 +0,0 @@ -package cwhub - -import ( - "fmt" - "path/filepath" - - "github.com/enescakir/emoji" - log "github.com/sirupsen/logrus" - "golang.org/x/mod/semver" - - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" -) - -// pick a hub branch corresponding to the current crowdsec version. -func chooseHubBranch() (string, error) { - latest, err := cwversion.Latest() - if err != nil { - log.Warningf("Unable to retrieve latest crowdsec version: %s, defaulting to master", err) - //lint:ignore nilerr reason - return "master", nil // ignore - } - - csVersion := cwversion.VersionStrip() - if csVersion == latest { - log.Debugf("current version is equal to latest (%s)", csVersion) - return "master", nil - } - - // if current version is greater than the latest we are in pre-release - if semver.Compare(csVersion, latest) == 1 { - log.Debugf("Your current crowdsec version seems to be a pre-release (%s)", csVersion) - return "master", nil - } - - if csVersion == "" { - log.Warning("Crowdsec version is not set, using master branch for the hub") - return "master", nil - } - - log.Warnf("Crowdsec is not the latest version. "+ - "Current version is '%s' and the latest stable version is '%s'. Please update it!", - csVersion, latest) - log.Warnf("As a result, you will not be able to use parsers/scenarios/collections "+ - "added to Crowdsec Hub after CrowdSec %s", latest) - return csVersion, nil -} - -// SetHubBranch sets the package variable that points to the hub branch. 
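
Because the new IndexNotFoundError carries the URL and branch as struct fields, callers can recover them with errors.As instead of matching on the message. A sketch of caller-side handling; fetchIndex and its values are made up:

package main

import (
	"errors"
	"fmt"

	"github.com/crowdsecurity/crowdsec/pkg/cwhub"
)

// fetchIndex is a hypothetical wrapper that fails with the new error type.
func fetchIndex() error {
	return cwhub.IndexNotFoundError{URL: "https://example.com/.index.json", Branch: "v1.6.0"}
}

func main() {
	err := fetchIndex()

	var notFound cwhub.IndexNotFoundError
	if errors.As(err, &notFound) {
		// The branch travels with the error, so the caller can suggest a fix.
		fmt.Printf("no index on branch %q; falling back to master may help\n", notFound.Branch)
	}
}
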
-func SetHubBranch() error { - // a branch is already set, or specified from the flags - if HubBranch != "" { - return nil - } - - // use the branch corresponding to the crowdsec version - branch, err := chooseHubBranch() - if err != nil { - return err - } - HubBranch = branch - log.Debugf("Using branch '%s' for the hub", HubBranch) - return nil -} - -func InstallItem(csConfig *csconfig.Config, name string, obtype string, force bool, downloadOnly bool) error { - it := GetItem(obtype, name) - if it == nil { - return fmt.Errorf("unable to retrieve item: %s", name) - } - - item := *it - if downloadOnly && item.Downloaded && item.UpToDate { - log.Warningf("%s is already downloaded and up-to-date", item.Name) - if !force { - return nil - } - } - - item, err := DownloadLatest(csConfig.Hub, item, force, true) - if err != nil { - return fmt.Errorf("while downloading %s: %w", item.Name, err) - } - - if err := AddItem(obtype, item); err != nil { - return fmt.Errorf("while adding %s: %w", item.Name, err) - } - - if downloadOnly { - log.Infof("Downloaded %s to %s", item.Name, filepath.Join(csConfig.Hub.HubDir, item.RemotePath)) - return nil - } - - item, err = EnableItem(csConfig.Hub, item) - if err != nil { - return fmt.Errorf("while enabling %s: %w", item.Name, err) - } - - if err := AddItem(obtype, item); err != nil { - return fmt.Errorf("while adding %s: %w", item.Name, err) - } - - log.Infof("Enabled %s", item.Name) - - return nil -} - -// XXX this must return errors instead of log.Fatal -func RemoveMany(csConfig *csconfig.Config, itemType string, name string, all bool, purge bool, forceAction bool) { - var ( - err error - disabled int - ) - - if name != "" { - it := GetItem(itemType, name) - if it == nil { - log.Fatalf("unable to retrieve: %s", name) - } - - item := *it - item, err = DisableItem(csConfig.Hub, item, purge, forceAction) - if err != nil { - log.Fatalf("unable to disable %s : %v", item.Name, err) - } - - if err := AddItem(itemType, item); err != nil { - log.Fatalf("unable to add %s: %v", item.Name, err) - } - return - } - - if !all { - log.Fatal("removing item: no item specified") - } - - // remove all - for _, v := range GetItemMap(itemType) { - if !v.Installed { - continue - } - v, err = DisableItem(csConfig.Hub, v, purge, forceAction) - if err != nil { - log.Fatalf("unable to disable %s : %v", v.Name, err) - } - - if err := AddItem(itemType, v); err != nil { - log.Fatalf("unable to add %s: %v", v.Name, err) - } - disabled++ - } - log.Infof("Disabled %d items", disabled) -} - -func UpgradeConfig(csConfig *csconfig.Config, itemType string, name string, force bool) { - var ( - err error - updated int - found bool - ) - - for _, v := range GetItemMap(itemType) { - if name != "" && name != v.Name { - continue - } - - if !v.Installed { - log.Tracef("skip %s, not installed", v.Name) - continue - } - - if !v.Downloaded { - log.Warningf("%s : not downloaded, please install.", v.Name) - continue - } - - found = true - - if v.UpToDate { - log.Infof("%s : up-to-date", v.Name) - - if err = DownloadDataIfNeeded(csConfig.Hub, v, force); err != nil { - log.Fatalf("%s : download failed : %v", v.Name, err) - } - - if !force { - continue - } - } - - v, err = DownloadLatest(csConfig.Hub, v, force, true) - if err != nil { - log.Fatalf("%s : download failed : %v", v.Name, err) - } - - if !v.UpToDate { - if v.Tainted { - log.Infof("%v %s is tainted, --force to overwrite", emoji.Warning, v.Name) - } else if v.Local { - log.Infof("%v %s is local", emoji.Prohibited, v.Name) - } - } else { - // this is 
used while scripting to know if the hub has been upgraded - // and a configuration reload is required - fmt.Printf("updated %s\n", v.Name) - log.Infof("%v %s : updated", emoji.Package, v.Name) - updated++ - } - - if err := AddItem(itemType, v); err != nil { - log.Fatalf("unable to add %s: %v", v.Name, err) - } - } - - if !found && name == "" { - log.Infof("No %s installed, nothing to upgrade", itemType) - } else if !found { - log.Errorf("Item '%s' not found in hub", name) - } else if updated == 0 && found { - if name == "" { - log.Infof("All %s are already up-to-date", itemType) - } else { - log.Infof("Item '%s' is up-to-date", name) - } - } else if updated != 0 { - log.Infof("Upgraded %d items", updated) - } -} diff --git a/pkg/cwhub/helpers_test.go b/pkg/cwhub/helpers_test.go deleted file mode 100644 index b8a15519da5..00000000000 --- a/pkg/cwhub/helpers_test.go +++ /dev/null @@ -1,161 +0,0 @@ -package cwhub - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -//Download index, install collection. Add scenario to collection (hub-side), update index, upgrade collection -// We expect the new scenario to be installed -func TestUpgradeConfigNewScenarioInCollection(t *testing.T) { - cfg := envSetup(t) - defer envTearDown(cfg) - - // fresh install of collection - getHubIdxOrFail(t) - - require.False(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Downloaded) - require.False(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Installed) - - require.NoError(t, InstallItem(cfg, "crowdsecurity/test_collection", COLLECTIONS, false, false)) - - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Downloaded) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Installed) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].UpToDate) - require.False(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Tainted) - - // This is the scenario that gets added in next version of collection - require.False(t, hubIdx[SCENARIOS]["crowdsecurity/barfoo_scenario"].Downloaded) - require.False(t, hubIdx[SCENARIOS]["crowdsecurity/barfoo_scenario"].Installed) - - assertCollectionDepsInstalled(t, "crowdsecurity/test_collection") - - // collection receives an update. It now adds new scenario "crowdsecurity/barfoo_scenario" - pushUpdateToCollectionInHub() - - if err := UpdateHubIdx(cfg.Hub); err != nil { - t.Fatalf("failed to download index : %s", err) - } - getHubIdxOrFail(t) - - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Downloaded) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Installed) - require.False(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].UpToDate) - require.False(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Tainted) - - UpgradeConfig(cfg, COLLECTIONS, "crowdsecurity/test_collection", false) - assertCollectionDepsInstalled(t, "crowdsecurity/test_collection") - - require.True(t, hubIdx[SCENARIOS]["crowdsecurity/barfoo_scenario"].Downloaded) - require.True(t, hubIdx[SCENARIOS]["crowdsecurity/barfoo_scenario"].Installed) - -} - -// Install a collection, disable a scenario. -// Upgrade should install should not enable/download the disabled scenario. 
-func TestUpgradeConfigInDisabledScenarioShouldNotBeInstalled(t *testing.T) { - cfg := envSetup(t) - defer envTearDown(cfg) - - // fresh install of collection - getHubIdxOrFail(t) - - require.False(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Downloaded) - require.False(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Installed) - require.False(t, hubIdx[SCENARIOS]["crowdsecurity/foobar_scenario"].Installed) - - require.NoError(t, InstallItem(cfg, "crowdsecurity/test_collection", COLLECTIONS, false, false)) - - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Downloaded) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Installed) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].UpToDate) - require.False(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Tainted) - require.True(t, hubIdx[SCENARIOS]["crowdsecurity/foobar_scenario"].Installed) - assertCollectionDepsInstalled(t, "crowdsecurity/test_collection") - - RemoveMany(cfg, SCENARIOS, "crowdsecurity/foobar_scenario", false, false, false) - getHubIdxOrFail(t) - // scenario referenced by collection was deleted hence, collection should be tainted - require.False(t, hubIdx[SCENARIOS]["crowdsecurity/foobar_scenario"].Installed) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Tainted) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Downloaded) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Installed) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].UpToDate) - - if err := UpdateHubIdx(cfg.Hub); err != nil { - t.Fatalf("failed to download index : %s", err) - } - - UpgradeConfig(cfg, COLLECTIONS, "crowdsecurity/test_collection", false) - - getHubIdxOrFail(t) - require.False(t, hubIdx[SCENARIOS]["crowdsecurity/foobar_scenario"].Installed) -} - -func getHubIdxOrFail(t *testing.T) { - if err := GetHubIdx(getTestCfg().Hub); err != nil { - t.Fatalf("failed to load hub index") - } -} - -// Install a collection. Disable a referenced scenario. Publish new version of collection with new scenario -// Upgrade should not enable/download the disabled scenario. -// Upgrade should install and enable the newly added scenario. 
-func TestUpgradeConfigNewScenarioIsInstalledWhenReferencedScenarioIsDisabled(t *testing.T) { - cfg := envSetup(t) - defer envTearDown(cfg) - - // fresh install of collection - getHubIdxOrFail(t) - - require.False(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Downloaded) - require.False(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Installed) - require.False(t, hubIdx[SCENARIOS]["crowdsecurity/foobar_scenario"].Installed) - - require.NoError(t, InstallItem(cfg, "crowdsecurity/test_collection", COLLECTIONS, false, false)) - - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Downloaded) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Installed) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].UpToDate) - require.False(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Tainted) - require.True(t, hubIdx[SCENARIOS]["crowdsecurity/foobar_scenario"].Installed) - assertCollectionDepsInstalled(t, "crowdsecurity/test_collection") - - RemoveMany(cfg, SCENARIOS, "crowdsecurity/foobar_scenario", false, false, false) - getHubIdxOrFail(t) - // scenario referenced by collection was deleted hence, collection should be tainted - require.False(t, hubIdx[SCENARIOS]["crowdsecurity/foobar_scenario"].Installed) - require.True(t, hubIdx[SCENARIOS]["crowdsecurity/foobar_scenario"].Downloaded) // this fails - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Tainted) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Downloaded) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].Installed) - require.True(t, hubIdx[COLLECTIONS]["crowdsecurity/test_collection"].UpToDate) - - // collection receives an update. It now adds new scenario "crowdsecurity/barfoo_scenario" - // we now attempt to upgrade the collection, however it shouldn't install the foobar_scenario - // we just removed. Nor should it install the newly added scenario - pushUpdateToCollectionInHub() - - if err := UpdateHubIdx(cfg.Hub); err != nil { - t.Fatalf("failed to download index : %s", err) - } - require.False(t, hubIdx[SCENARIOS]["crowdsecurity/foobar_scenario"].Installed) - getHubIdxOrFail(t) - - UpgradeConfig(cfg, COLLECTIONS, "crowdsecurity/test_collection", false) - getHubIdxOrFail(t) - require.False(t, hubIdx[SCENARIOS]["crowdsecurity/foobar_scenario"].Installed) - require.True(t, hubIdx[SCENARIOS]["crowdsecurity/barfoo_scenario"].Installed) -} - -func assertCollectionDepsInstalled(t *testing.T, collection string) { - t.Helper() - c := hubIdx[COLLECTIONS][collection] - require.NoError(t, CollecDepsCheck(&c)) -} - -func pushUpdateToCollectionInHub() { - responseByPath["/master/.index.json"] = fileToStringX("./tests/index2.json") - responseByPath["/master/collections/crowdsecurity/test_collection.yaml"] = fileToStringX("./tests/collection_v2.yaml") -} diff --git a/pkg/cwhub/hub.go b/pkg/cwhub/hub.go new file mode 100644 index 00000000000..f74a794a512 --- /dev/null +++ b/pkg/cwhub/hub.go @@ -0,0 +1,281 @@ +package cwhub + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "path" + "strings" + + "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" +) + +// Hub is the main structure for the package. 
+type Hub struct { + items HubItems // Items read from HubDir and InstallDir + pathIndex map[string]*Item + local *csconfig.LocalHubCfg + remote *RemoteHubCfg + logger *logrus.Logger + Warnings []string // Warnings encountered during sync +} + +// GetDataDir returns the data directory, where data sets are installed. +func (h *Hub) GetDataDir() string { + return h.local.InstallDataDir +} + +// NewHub returns a new Hub instance with local and (optionally) remote configuration. +// The hub is not synced automatically. Load() must be called to read the index, sync the local state, +// and check for unmanaged items. +// All download operations (including updateIndex) return ErrNilRemoteHub if the remote configuration is not set. +func NewHub(local *csconfig.LocalHubCfg, remote *RemoteHubCfg, logger *logrus.Logger) (*Hub, error) { + if local == nil { + return nil, errors.New("no hub configuration found") + } + + if logger == nil { + logger = logrus.New() + logger.SetOutput(io.Discard) + } + + hub := &Hub{ + local: local, + remote: remote, + logger: logger, + pathIndex: make(map[string]*Item, 0), + } + + return hub, nil +} + +// Load reads the state of the items on disk. +func (h *Hub) Load() error { + h.logger.Debugf("loading hub idx %s", h.local.HubIndexFile) + + if err := h.parseIndex(); err != nil { + return fmt.Errorf("failed to load hub index: %w", err) + } + + if err := h.localSync(); err != nil { + return fmt.Errorf("failed to sync hub items: %w", err) + } + + return nil +} + +// parseIndex takes the content of an index file and fills the map of associated parsers/scenarios/collections. +func (h *Hub) parseIndex() error { + bidx, err := os.ReadFile(h.local.HubIndexFile) + if err != nil { + return fmt.Errorf("unable to read index file: %w", err) + } + + if err := json.Unmarshal(bidx, &h.items); err != nil { + return fmt.Errorf("failed to parse index: %w", err) + } + + h.logger.Debugf("%d item types in hub index", len(ItemTypes)) + + // Iterate over the different types to complete the struct + for _, itemType := range ItemTypes { + h.logger.Tracef("%s: %d items", itemType, len(h.GetItemMap(itemType))) + + for name, item := range h.GetItemMap(itemType) { + item.hub = h + item.Name = name + + // if the item has no (redundant) author, take it from the json key + if item.Author == "" && strings.Contains(name, "/") { + item.Author = strings.Split(name, "/")[0] + } + + item.Type = itemType + item.FileName = path.Base(item.RemotePath) + + item.logMissingSubItems() + + if item.latestHash() == "" { + h.logger.Errorf("invalid hub item %s: latest version missing from index", item.FQName()) + } + } + } + + return nil +} + +// ItemStats returns total counts of the hub items, including local and tainted. +func (h *Hub) ItemStats() []string { + loaded := "" + local := 0 + tainted := 0 + + for _, itemType := range ItemTypes { + items := h.GetItemsByType(itemType, false) + if len(items) == 0 { + continue + } + + loaded += fmt.Sprintf("%d %s, ", len(items), itemType) + + for _, item := range items { + if item.State.IsLocal() { + local++ + } + + if item.State.Tainted { + tainted++ + } + } + } + + loaded = strings.Trim(loaded, ", ") + if loaded == "" { + loaded = "0 items" + } + + ret := []string{ + "Loaded: " + loaded, + } + + if local > 0 || tainted > 0 { + ret = append(ret, fmt.Sprintf("Unmanaged items: %d local, %d tainted", local, tainted)) + } + + return ret +} + +// Update downloads the latest version of the index and writes it to disk if it changed. 
It cannot be called after Load() +// unless the hub is completely empty. +func (h *Hub) Update(ctx context.Context) error { + if len(h.pathIndex) > 0 { + // if this happens, it's a bug. + return errors.New("cannot update hub after items have been loaded") + } + + downloaded, err := h.remote.fetchIndex(ctx, h.local.HubIndexFile) + if err != nil { + return err + } + + if downloaded { + h.logger.Infof("Wrote index to %s", h.local.HubIndexFile) + } else { + h.logger.Info("hub index is up to date") + } + + return nil +} + +func (h *Hub) addItem(item *Item) { + if h.items[item.Type] == nil { + h.items[item.Type] = make(map[string]*Item) + } + + h.items[item.Type][item.Name] = item + h.pathIndex[item.State.LocalPath] = item +} + +// GetItemMap returns the map of items for a given type. +func (h *Hub) GetItemMap(itemType string) map[string]*Item { + return h.items[itemType] +} + +// GetItem returns an item from hub based on its type and full name (author/name). +func (h *Hub) GetItem(itemType string, itemName string) *Item { + return h.GetItemMap(itemType)[itemName] +} + +// GetItemByPath returns an item from hub based on its (absolute) local path. +func (h *Hub) GetItemByPath(itemPath string) *Item { + return h.pathIndex[itemPath] +} + +// GetItemFQ returns an item from hub based on its type and name (type:author/name). +func (h *Hub) GetItemFQ(itemFQName string) (*Item, error) { + // type and name are separated by a colon + parts := strings.Split(itemFQName, ":") + + if len(parts) != 2 { + return nil, fmt.Errorf("invalid item name %s", itemFQName) + } + + m := h.GetItemMap(parts[0]) + if m == nil { + return nil, fmt.Errorf("invalid item type %s", parts[0]) + } + + i := m[parts[1]] + if i == nil { + return nil, fmt.Errorf("item %s not found", parts[1]) + } + + return i, nil +} + +// GetItemsByType returns a slice of all the items of a given type, installed or not, optionally sorted by case-insensitive name. +// A non-existent type will silently return an empty slice. +func (h *Hub) GetItemsByType(itemType string, sorted bool) []*Item { + items := h.items[itemType] + + ret := make([]*Item, len(items)) + + if sorted { + for idx, name := range maptools.SortedKeysNoCase(items) { + ret[idx] = items[name] + } + + return ret + } + + idx := 0 + for _, item := range items { + ret[idx] = item + idx += 1 + } + + return ret +} + +// GetInstalledByType returns a slice of all the installed items of a given type, optionally sorted by case-insensitive name. +// A non-existent type will silently return an empty slice. +func (h *Hub) GetInstalledByType(itemType string, sorted bool) []*Item { + ret := make([]*Item, 0) + + for _, item := range h.GetItemsByType(itemType, sorted) { + if item.State.Installed { + ret = append(ret, item) + } + } + + return ret +} + +// GetInstalledListForAPI returns a slice of names of all the installed scenarios and appsec-rules. +// The returned list is sorted by type (scenarios first) and case-insensitive name. 
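Together, the new hub.go exposes a small read API around a loaded index. A sketch of a typical consumer, assuming the stock file layout (the paths below are illustrative defaults, not read from a real configuration):

```go
package main

import (
	"fmt"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
	"github.com/crowdsecurity/crowdsec/pkg/cwhub"
)

func main() {
	// Illustrative defaults; real values come from the loaded service configuration.
	local := &csconfig.LocalHubCfg{
		HubIndexFile:   "/etc/crowdsec/hub/.index.json",
		HubDir:         "/etc/crowdsec/hub",
		InstallDir:     "/etc/crowdsec",
		InstallDataDir: "/var/lib/crowdsec/data",
	}

	// nil remote: downloads are disabled; nil logger: output is discarded.
	hub, err := cwhub.NewHub(local, nil, nil)
	if err != nil {
		log.Fatal(err)
	}

	// Read the index and sync it with the on-disk state.
	if err := hub.Load(); err != nil {
		log.Fatal(err)
	}

	for _, item := range hub.GetInstalledByType(cwhub.SCENARIOS, true) {
		fmt.Println(item.FQName(), item.State.Text())
	}
}
```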
+func (h *Hub) GetInstalledListForAPI() []string { + scenarios := h.GetInstalledByType(SCENARIOS, true) + appsecRules := h.GetInstalledByType(APPSEC_RULES, true) + + ret := make([]string, len(scenarios)+len(appsecRules)) + + idx := 0 + for _, item := range scenarios { + ret[idx] = item.Name + idx += 1 + } + + for _, item := range appsecRules { + ret[idx] = item.Name + idx += 1 + } + + return ret +} diff --git a/pkg/cwhub/hub_test.go b/pkg/cwhub/hub_test.go new file mode 100644 index 00000000000..1c2c9ccceca --- /dev/null +++ b/pkg/cwhub/hub_test.go @@ -0,0 +1,91 @@ +package cwhub + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" +) + +func TestInitHubUpdate(t *testing.T) { + hub := envSetup(t) + remote := &RemoteHubCfg{ + URLTemplate: mockURLTemplate, + Branch: "master", + IndexPath: ".index.json", + } + + _, err := NewHub(hub.local, remote, nil) + require.NoError(t, err) + + ctx := context.Background() + + err = hub.Update(ctx) + require.NoError(t, err) + + err = hub.Load() + require.NoError(t, err) +} + +func TestUpdateIndex(t *testing.T) { + // bad url template + fmt.Println("Test 'bad URL'") + + tmpIndex, err := os.CreateTemp("", "index.json") + require.NoError(t, err) + + // close the file to avoid preventing the rename on windows + err = tmpIndex.Close() + require.NoError(t, err) + + t.Cleanup(func() { + os.Remove(tmpIndex.Name()) + }) + + hub := envSetup(t) + + hub.remote = &RemoteHubCfg{ + URLTemplate: "x", + Branch: "", + IndexPath: "", + } + + hub.local.HubIndexFile = tmpIndex.Name() + + ctx := context.Background() + + err = hub.Update(ctx) + cstest.RequireErrorContains(t, err, "failed to build hub index request: invalid URL template 'x'") + + // bad domain + fmt.Println("Test 'bad domain'") + + hub.remote = &RemoteHubCfg{ + URLTemplate: "https://baddomain/crowdsecurity/%s/%s", + Branch: "master", + IndexPath: ".index.json", + } + + err = hub.Update(ctx) + require.NoError(t, err) + // XXX: this is not failing + // cstest.RequireErrorContains(t, err, "failed http request for hub index: Get") + + // bad target path + fmt.Println("Test 'bad target path'") + + hub.remote = &RemoteHubCfg{ + URLTemplate: mockURLTemplate, + Branch: "master", + IndexPath: ".index.json", + } + + hub.local.HubIndexFile = "/does/not/exist/index.json" + + err = hub.Update(ctx) + cstest.RequireErrorContains(t, err, "failed to create temporary download file for /does/not/exist/index.json:") +} diff --git a/pkg/cwhub/install.go b/pkg/cwhub/install.go deleted file mode 100644 index 505c3629760..00000000000 --- a/pkg/cwhub/install.go +++ /dev/null @@ -1,201 +0,0 @@ -package cwhub - -import ( - "fmt" - "os" - "path/filepath" - - log "github.com/sirupsen/logrus" - - "github.com/crowdsecurity/crowdsec/pkg/csconfig" -) - -func purgeItem(hub *csconfig.Hub, target Item) (Item, error) { - var hdir = hub.HubDir - hubpath := hdir + "/" + target.RemotePath - - // disable hub file - if err := os.Remove(hubpath); err != nil { - return target, fmt.Errorf("while removing file: %w", err) - } - - target.Downloaded = false - log.Infof("Removed source file [%s] : %s", target.Name, hubpath) - hubIdx[target.Type][target.Name] = target - return target, nil -} - -//DisableItem to disable an item managed by the hub, removes the symlink if purge is true -func DisableItem(hub *csconfig.Hub, target Item, purge bool, force bool) (Item, error) { - var tdir = hub.ConfigDir - var hdir = hub.HubDir - var err error - - if !target.Installed { - if purge { 
- target, err = purgeItem(hub, target) - if err != nil { - return target, err - } - } - return target, nil - } - - syml, err := filepath.Abs(tdir + "/" + target.Type + "/" + target.Stage + "/" + target.FileName) - if err != nil { - return Item{}, err - } - - if target.Local { - return target, fmt.Errorf("%s isn't managed by hub. Please delete manually", target.Name) - } - - if target.Tainted && !force { - return target, fmt.Errorf("%s is tainted, use '--force' to overwrite", target.Name) - } - - /*for a COLLECTIONS, disable sub-items*/ - if target.Type == COLLECTIONS { - var tmp = [][]string{target.Parsers, target.PostOverflows, target.Scenarios, target.Collections} - for idx, ptr := range tmp { - ptrtype := ItemTypes[idx] - for _, p := range ptr { - if val, ok := hubIdx[ptrtype][p]; ok { - // check if the item doesn't belong to another collection before removing it - toRemove := true - for _, collection := range val.BelongsToCollections { - if collection != target.Name { - toRemove = false - break - } - } - if toRemove { - hubIdx[ptrtype][p], err = DisableItem(hub, val, purge, force) - if err != nil { - return target, fmt.Errorf("while disabling %s: %w", p, err) - } - } else { - log.Infof("%s was not removed because it belongs to another collection", val.Name) - } - } else { - log.Errorf("Referred %s %s in collection %s doesn't exist.", ptrtype, p, target.Name) - } - } - } - } - - stat, err := os.Lstat(syml) - if os.IsNotExist(err) { - if !purge && !force { //we only accept to "delete" non existing items if it's a purge - return target, fmt.Errorf("can't delete %s : %s doesn't exist", target.Name, syml) - } - } else { - //if it's managed by hub, it's a symlink to csconfig.GConfig.hub.HubDir / ... - if stat.Mode()&os.ModeSymlink == 0 { - log.Warningf("%s (%s) isn't a symlink, can't disable", target.Name, syml) - return target, fmt.Errorf("%s isn't managed by hub", target.Name) - } - hubpath, err := os.Readlink(syml) - if err != nil { - return target, fmt.Errorf("while reading symlink: %w", err) - } - absPath, err := filepath.Abs(hdir + "/" + target.RemotePath) - if err != nil { - return target, fmt.Errorf("while abs path: %w", err) - } - if hubpath != absPath { - log.Warningf("%s (%s) isn't a symlink to %s", target.Name, syml, absPath) - return target, fmt.Errorf("%s isn't managed by hub", target.Name) - } - - //remove the symlink - if err = os.Remove(syml); err != nil { - return target, fmt.Errorf("while removing symlink: %w", err) - } - log.Infof("Removed symlink [%s] : %s", target.Name, syml) - } - target.Installed = false - - if purge { - target, err = purgeItem(hub, target) - if err != nil { - return target, err - } - } - hubIdx[target.Type][target.Name] = target - return target, nil -} - -// creates symlink between actual config file at hub.HubDir and hub.ConfigDir -// Handles collections recursively -func EnableItem(hub *csconfig.Hub, target Item) (Item, error) { - var tdir = hub.ConfigDir - var hdir = hub.HubDir - var err error - parent_dir := filepath.Clean(tdir + "/" + target.Type + "/" + target.Stage + "/") - /*create directories if needed*/ - if target.Installed { - if target.Tainted { - return target, fmt.Errorf("%s is tainted, won't enable unless --force", target.Name) - } - if target.Local { - return target, fmt.Errorf("%s is local, won't enable", target.Name) - } - /* if it's a collection, check sub-items even if the collection file itself is up-to-date */ - if target.UpToDate && target.Type != COLLECTIONS { - log.Tracef("%s is installed and up-to-date, skip.", target.Name) 
- return target, nil - } - } - if _, err := os.Stat(parent_dir); os.IsNotExist(err) { - log.Printf("%s doesn't exist, create", parent_dir) - if err := os.MkdirAll(parent_dir, os.ModePerm); err != nil { - return target, fmt.Errorf("while creating directory: %w", err) - } - } - - /*install sub-items if it's a collection*/ - if target.Type == COLLECTIONS { - var tmp = [][]string{target.Parsers, target.PostOverflows, target.Scenarios, target.Collections} - for idx, ptr := range tmp { - ptrtype := ItemTypes[idx] - for _, p := range ptr { - val, ok := hubIdx[ptrtype][p] - if !ok { - return target, fmt.Errorf("required %s %s of %s doesn't exist, abort", ptrtype, p, target.Name) - } - - hubIdx[ptrtype][p], err = EnableItem(hub, val) - if err != nil { - return target, fmt.Errorf("while installing %s: %w", p, err) - } - } - } - } - - // check if file already exists where it should in configdir (eg /etc/crowdsec/collections/) - if _, err := os.Lstat(parent_dir + "/" + target.FileName); !os.IsNotExist(err) { - log.Printf("%s already exists.", parent_dir+"/"+target.FileName) - return target, nil - } - - //tdir+target.RemotePath - srcPath, err := filepath.Abs(hdir + "/" + target.RemotePath) - if err != nil { - return target, fmt.Errorf("while getting source path: %w", err) - } - - dstPath, err := filepath.Abs(parent_dir + "/" + target.FileName) - if err != nil { - return target, fmt.Errorf("while getting destination path: %w", err) - } - - if err = os.Symlink(srcPath, dstPath); err != nil { - return target, fmt.Errorf("while creating symlink from %s to %s: %w", srcPath, dstPath, err) - } - - log.Printf("Enabled %s : %s", target.Type, target.Name) - target.Installed = true - hubIdx[target.Type][target.Name] = target - return target, nil -} diff --git a/pkg/cwhub/item.go b/pkg/cwhub/item.go new file mode 100644 index 00000000000..32d1acf94ff --- /dev/null +++ b/pkg/cwhub/item.go @@ -0,0 +1,454 @@ +package cwhub + +import ( + "encoding/json" + "fmt" + "path/filepath" + "slices" + + "github.com/Masterminds/semver/v3" + + "github.com/crowdsecurity/crowdsec/pkg/emoji" +) + +const ( + // managed item types. + COLLECTIONS = "collections" + PARSERS = "parsers" + POSTOVERFLOWS = "postoverflows" + SCENARIOS = "scenarios" + CONTEXTS = "contexts" + APPSEC_CONFIGS = "appsec-configs" + APPSEC_RULES = "appsec-rules" +) + +const ( + versionUpToDate = iota // the latest version from index is installed + versionUpdateAvailable // not installed, or lower than latest + versionUnknown // local file with no version, or invalid version number + versionFuture // local version is higher than the latest, but is included in the index: should not happen +) + +// The order is important, as it is used to range over sub-items in collections. +var ItemTypes = []string{PARSERS, POSTOVERFLOWS, SCENARIOS, CONTEXTS, APPSEC_CONFIGS, APPSEC_RULES, COLLECTIONS} + +type HubItems map[string]map[string]*Item + +// ItemVersion is used to detect the version of a given item +// by comparing the hash of each version to the local file. +// If the item does not match any known version, it is considered tainted (modified). +type ItemVersion struct { + Digest string `json:"digest,omitempty" yaml:"digest,omitempty"` + Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"` +} + +// ItemState is used to keep the local state (i.e. at runtime) of an item. +// This data is not stored in the index, but is displayed with "cscli ... inspect".
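The per-version digests are what drive the tainted detection described above: the local file is hashed and matched against every known release. A self-contained sketch of that matching step (matchVersion and its digest map are illustrative, mirroring the Versions field):

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"os"
)

// matchVersion returns the release whose recorded digest matches the local
// file, or "" if the file matches none of them (i.e. it has been modified).
// digests mirrors Item.Versions: version number -> sha256 of the released file.
func matchVersion(path string, digests map[string]string) (string, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return "", err
	}

	sum := sha256.Sum256(data)
	local := hex.EncodeToString(sum[:])

	for version, digest := range digests {
		if digest == local {
			return version, nil
		}
	}

	return "", nil // tainted: no known release has this content
}

func main() {
	v, err := matchVersion("sshd-logs.yaml", map[string]string{
		// sha256 of an empty file, for illustration only
		"2.0": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
	})
	fmt.Println(v, err)
}
```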
+type ItemState struct { + LocalPath string `json:"local_path,omitempty" yaml:"local_path,omitempty"` + LocalVersion string `json:"local_version,omitempty" yaml:"local_version,omitempty"` + LocalHash string `json:"local_hash,omitempty" yaml:"local_hash,omitempty"` + Installed bool `json:"installed"` + Downloaded bool `json:"downloaded"` + UpToDate bool `json:"up_to_date"` + Tainted bool `json:"tainted"` + TaintedBy []string `json:"tainted_by,omitempty" yaml:"tainted_by,omitempty"` + BelongsToCollections []string `json:"belongs_to_collections,omitempty" yaml:"belongs_to_collections,omitempty"` +} + +// IsLocal returns true if the item has been created by a user (not downloaded from the hub). +func (s *ItemState) IsLocal() bool { + return s.Installed && !s.Downloaded +} + +// Text returns the status of the item as a string (eg. "enabled,update-available"). +func (s *ItemState) Text() string { + ret := "disabled" + + if s.Installed { + ret = "enabled" + } + + if s.IsLocal() { + ret += ",local" + } + + if s.Tainted { + ret += ",tainted" + } else if !s.UpToDate && !s.IsLocal() { + ret += ",update-available" + } + + return ret +} + +// Emoji returns the status of the item as an emoji (eg. emoji.Warning). +func (s *ItemState) Emoji() string { + switch { + case s.IsLocal(): + return emoji.House + case !s.Installed: + return emoji.Prohibited + case s.Tainted || (!s.UpToDate && !s.IsLocal()): + return emoji.Warning + case s.Installed: + return emoji.CheckMark + default: + return emoji.QuestionMark + } +} + +// Item is created from an index file and enriched with local info. +type Item struct { + hub *Hub // back pointer to the hub, to retrieve other items and call install/remove methods + + State ItemState `json:"-" yaml:"-"` // local state, not stored in the index + + Type string `json:"type,omitempty" yaml:"type,omitempty"` // one of the ItemTypes + Stage string `json:"stage,omitempty" yaml:"stage,omitempty"` // Stage for parser|postoverflow: s00-raw/s01-... + Name string `json:"name,omitempty" yaml:"name,omitempty"` // usually "author/name" + FileName string `json:"file_name,omitempty" yaml:"file_name,omitempty"` // eg. apache2-logs.yaml + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Content string `json:"content,omitempty" yaml:"-"` + Author string `json:"author,omitempty" yaml:"author,omitempty"` + References []string `json:"references,omitempty" yaml:"references,omitempty"` + + RemotePath string `json:"path,omitempty" yaml:"path,omitempty"` // path relative to the base URL eg. /parsers/stage/author/file.yaml + Version string `json:"version,omitempty" yaml:"version,omitempty"` // the last available version + Versions map[string]ItemVersion `json:"versions,omitempty" yaml:"-"` // all the known versions + + // if it's a collection, it can have sub items + Parsers []string `json:"parsers,omitempty" yaml:"parsers,omitempty"` + PostOverflows []string `json:"postoverflows,omitempty" yaml:"postoverflows,omitempty"` + Scenarios []string `json:"scenarios,omitempty" yaml:"scenarios,omitempty"` + Collections []string `json:"collections,omitempty" yaml:"collections,omitempty"` + Contexts []string `json:"contexts,omitempty" yaml:"contexts,omitempty"` + AppsecConfigs []string `json:"appsec-configs,omitempty" yaml:"appsec-configs,omitempty"` + AppsecRules []string `json:"appsec-rules,omitempty" yaml:"appsec-rules,omitempty"` +} + +// installPath returns the location of the symlink to the item in the hub, or the path of the item itself if it's local +// (eg.
/etc/crowdsec/collections/xyz.yaml). +// Raises an error if the path goes outside of the install dir. +func (i *Item) installPath() (string, error) { + p := i.Type + if i.Stage != "" { + p = filepath.Join(p, i.Stage) + } + + return safePath(i.hub.local.InstallDir, filepath.Join(p, i.FileName)) +} + +// downloadPath returns the location of the actual config file in the hub +// (eg. /etc/crowdsec/hub/collections/author/xyz.yaml). +// Raises an error if the path goes outside of the hub dir. +func (i *Item) downloadPath() (string, error) { + ret, err := safePath(i.hub.local.HubDir, i.RemotePath) + if err != nil { + return "", err + } + + return ret, nil +} + +// HasSubItems returns true if items of this type can have sub-items. Currently only collections. +func (i *Item) HasSubItems() bool { + return i.Type == COLLECTIONS +} + +// MarshalJSON is used to prepare the output for "cscli ... inspect -o json". +// It must not use a pointer receiver. +func (i Item) MarshalJSON() ([]byte, error) { + type Alias Item + + return json.Marshal(&struct { + Alias + // we have to repeat the fields here, json will have inline support in v2 + LocalPath string `json:"local_path,omitempty"` + LocalVersion string `json:"local_version,omitempty"` + LocalHash string `json:"local_hash,omitempty"` + Installed bool `json:"installed"` + Downloaded bool `json:"downloaded"` + UpToDate bool `json:"up_to_date"` + Tainted bool `json:"tainted"` + Local bool `json:"local"` + BelongsToCollections []string `json:"belongs_to_collections,omitempty"` + }{ + Alias: Alias(i), + LocalPath: i.State.LocalPath, + LocalVersion: i.State.LocalVersion, + LocalHash: i.State.LocalHash, + Installed: i.State.Installed, + Downloaded: i.State.Downloaded, + UpToDate: i.State.UpToDate, + Tainted: i.State.Tainted, + BelongsToCollections: i.State.BelongsToCollections, + Local: i.State.IsLocal(), + }) +} + +// MarshalYAML is used to prepare the output for "cscli ... inspect -o raw". +// It must not use a pointer receiver. +func (i Item) MarshalYAML() (interface{}, error) { + type Alias Item + + return &struct { + Alias `yaml:",inline"` + State ItemState `yaml:",inline"` + Local bool `yaml:"local"` + }{ + Alias: Alias(i), + State: i.State, + Local: i.State.IsLocal(), + }, nil +} + +// SubItems returns a slice of sub-items, excluding the ones that were not found. 
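Both Marshal methods above rely on the classic alias trick: a local type with the same fields but none of the methods, so marshalling the alias cannot recurse back into MarshalJSON. A stripped-down sketch of the pattern:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type state struct {
	Installed bool
}

type item struct {
	Name  string `json:"name"`
	State state  `json:"-"`
}

// MarshalJSON flattens the nested state into the top-level object. The Alias
// type has item's fields but none of its methods, so marshalling it does not
// re-enter this function.
func (i item) MarshalJSON() ([]byte, error) {
	type Alias item

	return json.Marshal(&struct {
		Alias
		Installed bool `json:"installed"`
	}{
		Alias:     Alias(i),
		Installed: i.State.Installed,
	})
}

func main() {
	b, _ := json.Marshal(item{Name: "author/thing", State: state{Installed: true}})
	fmt.Println(string(b)) // {"name":"author/thing","installed":true}
}
```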
+func (i *Item) SubItems() []*Item { + sub := make([]*Item, 0) + + for _, name := range i.Parsers { + s := i.hub.GetItem(PARSERS, name) + if s == nil { + continue + } + + sub = append(sub, s) + } + + for _, name := range i.PostOverflows { + s := i.hub.GetItem(POSTOVERFLOWS, name) + if s == nil { + continue + } + + sub = append(sub, s) + } + + for _, name := range i.Scenarios { + s := i.hub.GetItem(SCENARIOS, name) + if s == nil { + continue + } + + sub = append(sub, s) + } + + for _, name := range i.Contexts { + s := i.hub.GetItem(CONTEXTS, name) + if s == nil { + continue + } + + sub = append(sub, s) + } + + for _, name := range i.AppsecConfigs { + s := i.hub.GetItem(APPSEC_CONFIGS, name) + if s == nil { + continue + } + + sub = append(sub, s) + } + + for _, name := range i.AppsecRules { + s := i.hub.GetItem(APPSEC_RULES, name) + if s == nil { + continue + } + + sub = append(sub, s) + } + + for _, name := range i.Collections { + s := i.hub.GetItem(COLLECTIONS, name) + if s == nil { + continue + } + + sub = append(sub, s) + } + + return sub +} + +func (i *Item) logMissingSubItems() { + if !i.HasSubItems() { + return + } + + for _, subName := range i.Parsers { + if i.hub.GetItem(PARSERS, subName) == nil { + i.hub.logger.Errorf("can't find %s in %s, required by %s", subName, PARSERS, i.Name) + } + } + + for _, subName := range i.Scenarios { + if i.hub.GetItem(SCENARIOS, subName) == nil { + i.hub.logger.Errorf("can't find %s in %s, required by %s", subName, SCENARIOS, i.Name) + } + } + + for _, subName := range i.PostOverflows { + if i.hub.GetItem(POSTOVERFLOWS, subName) == nil { + i.hub.logger.Errorf("can't find %s in %s, required by %s", subName, POSTOVERFLOWS, i.Name) + } + } + + for _, subName := range i.Contexts { + if i.hub.GetItem(CONTEXTS, subName) == nil { + i.hub.logger.Errorf("can't find %s in %s, required by %s", subName, CONTEXTS, i.Name) + } + } + + for _, subName := range i.AppsecConfigs { + if i.hub.GetItem(APPSEC_CONFIGS, subName) == nil { + i.hub.logger.Errorf("can't find %s in %s, required by %s", subName, APPSEC_CONFIGS, i.Name) + } + } + + for _, subName := range i.AppsecRules { + if i.hub.GetItem(APPSEC_RULES, subName) == nil { + i.hub.logger.Errorf("can't find %s in %s, required by %s", subName, APPSEC_RULES, i.Name) + } + } + + for _, subName := range i.Collections { + if i.hub.GetItem(COLLECTIONS, subName) == nil { + i.hub.logger.Errorf("can't find %s in %s, required by %s", subName, COLLECTIONS, i.Name) + } + } +} + +// Ancestors returns a slice of items (typically collections) that have this item as a direct or indirect dependency. +func (i *Item) Ancestors() []*Item { + ret := make([]*Item, 0) + + for _, parentName := range i.State.BelongsToCollections { + parent := i.hub.GetItem(COLLECTIONS, parentName) + if parent == nil { + continue + } + + ret = append(ret, parent) + } + + return ret +} + +// descendants returns a list of all (direct or indirect) dependencies of the item. 
+func (i *Item) descendants() ([]*Item, error) { + var collectSubItems func(item *Item, visited map[*Item]bool, result *[]*Item) error + + collectSubItems = func(item *Item, visited map[*Item]bool, result *[]*Item) error { + if item == nil { + return nil + } + + if visited[item] { + return nil + } + + visited[item] = true + + for _, subItem := range item.SubItems() { + if subItem == i { + return fmt.Errorf("circular dependency detected: %s depends on %s", item.Name, i.Name) + } + + *result = append(*result, subItem) + + err := collectSubItems(subItem, visited, result) + if err != nil { + return err + } + } + + return nil + } + + ret := []*Item{} + visited := map[*Item]bool{} + + err := collectSubItems(i, visited, &ret) + if err != nil { + return nil, err + } + + return ret, nil +} + +// versionStatus returns the status of the item version compared to the hub version. +// semver requires the 'v' prefix. +func (i *Item) versionStatus() int { + local, err := semver.NewVersion(i.State.LocalVersion) + if err != nil { + return versionUnknown + } + + // hub versions are already validated while syncing, ignore errors + latest, _ := semver.NewVersion(i.Version) + + if local.LessThan(latest) { + return versionUpdateAvailable + } + + if local.Equal(latest) { + return versionUpToDate + } + + return versionFuture +} + +// validPath returns true if the (relative) path is allowed for the item. +// dirName: the directory name (ie. crowdsecurity). +// fileName: the filename (ie. apache2-logs.yaml). +func (i *Item) validPath(dirName, fileName string) bool { + return (dirName+"/"+fileName == i.Name+".yaml") || (dirName+"/"+fileName == i.Name+".yml") +} + +// FQName returns the fully qualified name of the item (ie. parsers:crowdsecurity/apache2-logs). +func (i *Item) FQName() string { + return fmt.Sprintf("%s:%s", i.Type, i.Name) +} + +// addTaint marks the item as tainted, and propagates the taint to the ancestors. +// sub: the sub-item that caused the taint. May be the item itself! +func (i *Item) addTaint(sub *Item) { + i.State.Tainted = true + taintedBy := sub.FQName() + + idx, ok := slices.BinarySearch(i.State.TaintedBy, taintedBy) + if ok { + return + } + + // insert the taintedBy in the slice + + i.State.TaintedBy = append(i.State.TaintedBy, "") + + copy(i.State.TaintedBy[idx+1:], i.State.TaintedBy[idx:]) + + i.State.TaintedBy[idx] = taintedBy + + i.hub.logger.Debugf("%s is tainted by %s", i.Name, taintedBy) + + // propagate the taint to the ancestors + + for _, ancestor := range i.Ancestors() { + ancestor.addTaint(sub) + } +} + +// latestHash() returns the hash of the latest version of the item. +// if it's missing, the index file has been manually modified or got corrupted.
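addTaint keeps TaintedBy sorted and duplicate-free with slices.BinarySearch plus a manual shift. Since Go 1.21 the same idiom can be written with slices.Insert; a sketch under that assumption (insertSorted is an illustrative helper, not part of the package):

```go
package main

import (
	"fmt"
	"slices"
)

// insertSorted inserts s into a sorted slice, keeping it sorted and free of
// duplicates, the same invariant addTaint maintains for TaintedBy.
func insertSorted(xs []string, s string) []string {
	idx, ok := slices.BinarySearch(xs, s)
	if ok {
		return xs // already present
	}

	return slices.Insert(xs, idx, s)
}

func main() {
	tainted := []string{"collections:crowdsecurity/linux"}
	tainted = insertSorted(tainted, "parsers:crowdsecurity/sshd-logs")
	tainted = insertSorted(tainted, "parsers:crowdsecurity/sshd-logs") // no-op
	fmt.Println(tainted)
}
```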
+func (i *Item) latestHash() string { + for k, v := range i.Versions { + if k == i.Version { + return v.Digest + } + } + + return "" +} diff --git a/pkg/cwhub/item_test.go b/pkg/cwhub/item_test.go new file mode 100644 index 00000000000..703bbb5cb90 --- /dev/null +++ b/pkg/cwhub/item_test.go @@ -0,0 +1,71 @@ +package cwhub + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestItemStatus(t *testing.T) { + hub := envSetup(t) + + // get existing map + x := hub.GetItemMap(COLLECTIONS) + require.NotEmpty(t, x) + + // Get item: good and bad + for k := range x { + item := hub.GetItem(COLLECTIONS, k) + require.NotNil(t, item) + + item.State.Installed = true + item.State.UpToDate = false + item.State.Tainted = false + item.State.Downloaded = true + + txt := item.State.Text() + require.Equal(t, "enabled,update-available", txt) + + item.State.Installed = true + item.State.UpToDate = false + item.State.Tainted = false + item.State.Downloaded = false + + txt = item.State.Text() + require.Equal(t, "enabled,local", txt) + } + + stats := hub.ItemStats() + require.Equal(t, []string{ + "Loaded: 2 parsers, 1 scenarios, 3 collections", + "Unmanaged items: 3 local, 0 tainted", + }, stats) +} + +func TestGetters(t *testing.T) { + hub := envSetup(t) + + // get non existing map + empty := hub.GetItemMap("ratata") + require.Nil(t, empty) + + // get existing map + x := hub.GetItemMap(COLLECTIONS) + require.NotEmpty(t, x) + + // Get item: good and bad + for k := range x { + empty := hub.GetItem(COLLECTIONS, k+"nope") + require.Nil(t, empty) + + item := hub.GetItem(COLLECTIONS, k) + require.NotNil(t, item) + + // Add item and get it + item.Name += "nope" + hub.addItem(item) + + newitem := hub.GetItem(COLLECTIONS, item.Name) + require.NotNil(t, newitem) + } +} diff --git a/pkg/cwhub/iteminstall.go b/pkg/cwhub/iteminstall.go new file mode 100644 index 00000000000..912897d0d7e --- /dev/null +++ b/pkg/cwhub/iteminstall.go @@ -0,0 +1,73 @@ +package cwhub + +import ( + "context" + "fmt" +) + +// enable enables the item by creating a symlink to the downloaded content, and also enables sub-items. +func (i *Item) enable() error { + if i.State.Installed { + if i.State.Tainted { + return fmt.Errorf("%s is tainted, won't overwrite unless --force", i.Name) + } + + if i.State.IsLocal() { + return fmt.Errorf("%s is local, won't overwrite", i.Name) + } + + // if it's a collection, check sub-items even if the collection file itself is up-to-date + if i.State.UpToDate && !i.HasSubItems() { + i.hub.logger.Tracef("%s is installed and up-to-date, skip.", i.Name) + return nil + } + } + + for _, sub := range i.SubItems() { + if err := sub.enable(); err != nil { + return fmt.Errorf("while installing %s: %w", sub.Name, err) + } + } + + if err := i.createInstallLink(); err != nil { + return err + } + + i.hub.logger.Infof("Enabled %s: %s", i.Type, i.Name) + i.State.Installed = true + + return nil +} + +// Install installs the item from the hub, downloading it if needed. 
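On the calling side, Install is now a method on the item itself. A hedged sketch of a caller that resolves an item by its fully qualified name (installFQ is a hypothetical helper; the hub must have been built with a remote configuration for the download to succeed):

```go
package example

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/cwhub"
)

// installFQ resolves "type:author/name" and installs the item, pulling in
// any missing sub-items. force=false refuses to overwrite tainted files;
// downloadOnly=false also creates the install symlink.
func installFQ(ctx context.Context, hub *cwhub.Hub, fqName string) error {
	item, err := hub.GetItemFQ(fqName)
	if err != nil {
		return err
	}

	return item.Install(ctx, false, false)
}
```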
+func (i *Item) Install(ctx context.Context, force bool, downloadOnly bool) error { + if downloadOnly && i.State.Downloaded && i.State.UpToDate { + i.hub.logger.Infof("%s is already downloaded and up-to-date", i.Name) + + if !force { + return nil + } + } + + downloaded, err := i.downloadLatest(ctx, force, true) + if err != nil { + return err + } + + if downloadOnly && downloaded { + return nil + } + + if err := i.enable(); err != nil { + return fmt.Errorf("while enabling %s: %w", i.Name, err) + } + + // a check on stdout is used while scripting to know if the hub has been upgraded + // and a configuration reload is required + // TODO: use a better way to communicate this + fmt.Printf("installed %s\n", i.Name) + + i.hub.logger.Infof("Enabled %s", i.Name) + + return nil +} diff --git a/pkg/cwhub/iteminstall_test.go b/pkg/cwhub/iteminstall_test.go new file mode 100644 index 00000000000..5bfc7e8148e --- /dev/null +++ b/pkg/cwhub/iteminstall_test.go @@ -0,0 +1,141 @@ +package cwhub + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testInstall(hub *Hub, t *testing.T, item *Item) { + ctx := context.Background() + + // Install the parser + _, err := item.downloadLatest(ctx, false, false) + require.NoError(t, err, "failed to download %s", item.Name) + + err = hub.localSync() + require.NoError(t, err, "failed to run localSync") + + assert.True(t, item.State.UpToDate, "%s should be up-to-date", item.Name) + assert.False(t, item.State.Installed, "%s should not be installed", item.Name) + assert.False(t, item.State.Tainted, "%s should not be tainted", item.Name) + + err = item.enable() + require.NoError(t, err, "failed to enable %s", item.Name) + + err = hub.localSync() + require.NoError(t, err, "failed to run localSync") + + assert.True(t, item.State.Installed, "%s should be installed", item.Name) +} + +func testTaint(hub *Hub, t *testing.T, item *Item) { + assert.False(t, item.State.Tainted, "%s should not be tainted", item.Name) + + // truncate the file + f, err := os.Create(item.State.LocalPath) + require.NoError(t, err) + err = f.Close() + require.NoError(t, err) + + // Local sync and check status + err = hub.localSync() + require.NoError(t, err, "failed to run localSync") + + assert.True(t, item.State.Tainted, "%s should be tainted", item.Name) +} + +func testUpdate(hub *Hub, t *testing.T, item *Item) { + assert.False(t, item.State.UpToDate, "%s should not be up-to-date", item.Name) + + ctx := context.Background() + + // Update it + check status + _, err := item.downloadLatest(ctx, true, true) + require.NoError(t, err, "failed to update %s", item.Name) + + // Local sync and check status + err = hub.localSync() + require.NoError(t, err, "failed to run localSync") + + assert.True(t, item.State.UpToDate, "%s should be up-to-date", item.Name) + assert.False(t, item.State.Tainted, "%s should not be tainted anymore", item.Name) +} + +func testDisable(hub *Hub, t *testing.T, item *Item) { + assert.True(t, item.State.Installed, "%s should be installed", item.Name) + + // Remove + _, err := item.disable(false, false) + require.NoError(t, err, "failed to disable %s", item.Name) + + // Local sync and check status + err = hub.localSync() + require.NoError(t, err, "failed to run localSync") + require.Empty(t, hub.Warnings) + + assert.False(t, item.State.Tainted, "%s should not be tainted anymore", item.Name) + assert.False(t, item.State.Installed, "%s should not be installed anymore", item.Name) + assert.True(t, 
item.State.Downloaded, "%s should still be downloaded", item.Name) + + // Purge + _, err = item.disable(true, false) + require.NoError(t, err, "failed to purge %s", item.Name) + + // Local sync and check status + err = hub.localSync() + require.NoError(t, err, "failed to run localSync") + require.Empty(t, hub.Warnings) + + assert.False(t, item.State.Installed, "%s should not be installed anymore", item.Name) + assert.False(t, item.State.Downloaded, "%s should not be downloaded", item.Name) +} + +func TestInstallParser(t *testing.T) { + /* + - install a random parser + - check its status + - taint it + - check its status + - force update it + - check its status + - remove it + */ + hub := envSetup(t) + + // map iteration is random by itself + for _, it := range hub.GetItemMap(PARSERS) { + testInstall(hub, t, it) + testTaint(hub, t, it) + testUpdate(hub, t, it) + testDisable(hub, t, it) + + break + } +} + +func TestInstallCollection(t *testing.T) { + /* + - install a random collection + - check its status + - taint it + - check its status + - force update it + - check its status + - remove it + */ + hub := envSetup(t) + + // map iteration is random by itself + for _, it := range hub.GetItemMap(COLLECTIONS) { + testInstall(hub, t, it) + testTaint(hub, t, it) + testUpdate(hub, t, it) + testDisable(hub, t, it) + + break + } +} diff --git a/pkg/cwhub/itemlink.go b/pkg/cwhub/itemlink.go new file mode 100644 index 00000000000..8a78d6805b7 --- /dev/null +++ b/pkg/cwhub/itemlink.go @@ -0,0 +1,78 @@ +package cwhub + +import ( + "fmt" + "os" + "path/filepath" +) + +// createInstallLink creates a symlink between the actual config file at hub.HubDir and hub.InstallDir. +func (i *Item) createInstallLink() error { + dest, err := i.installPath() + if err != nil { + return err + } + + destDir := filepath.Dir(dest) + if err = os.MkdirAll(destDir, os.ModePerm); err != nil { + return fmt.Errorf("while creating %s: %w", destDir, err) + } + + if _, err = os.Lstat(dest); !os.IsNotExist(err) { + i.hub.logger.Infof("%s already exists.", dest) + return nil + } + + src, err := i.downloadPath() + if err != nil { + return err + } + + if err = os.Symlink(src, dest); err != nil { + return fmt.Errorf("while creating symlink from %s to %s: %w", src, dest, err) + } + + return nil +} + +// removeInstallLink removes the symlink to the downloaded content. +func (i *Item) removeInstallLink() error { + syml, err := i.installPath() + if err != nil { + return err + } + + stat, err := os.Lstat(syml) + if err != nil { + return err + } + + // if it's managed by hub, it's a symlink to csconfig.GConfig.hub.HubDir / ...
+ if stat.Mode()&os.ModeSymlink == 0 { + i.hub.logger.Warningf("%s (%s) isn't a symlink, can't disable", i.Name, syml) + return fmt.Errorf("%s isn't managed by hub", i.Name) + } + + hubpath, err := os.Readlink(syml) + if err != nil { + return fmt.Errorf("while reading symlink: %w", err) + } + + src, err := i.downloadPath() + if err != nil { + return err + } + + if hubpath != src { + i.hub.logger.Warningf("%s (%s) isn't a symlink to %s", i.Name, syml, src) + return fmt.Errorf("%s isn't managed by hub", i.Name) + } + + if err := os.Remove(syml); err != nil { + return fmt.Errorf("while removing symlink: %w", err) + } + + i.hub.logger.Infof("Removed symlink [%s]: %s", i.Name, syml) + + return nil +} diff --git a/pkg/cwhub/itemremove.go b/pkg/cwhub/itemremove.go new file mode 100644 index 00000000000..eca0c856237 --- /dev/null +++ b/pkg/cwhub/itemremove.go @@ -0,0 +1,138 @@ +package cwhub + +import ( + "fmt" + "os" + "slices" +) + +// purge removes the actual config file that was downloaded. +func (i *Item) purge() (bool, error) { + if !i.State.Downloaded { + i.hub.logger.Debugf("removing %s: not downloaded -- no need to remove", i.Name) + return false, nil + } + + src, err := i.downloadPath() + if err != nil { + return false, err + } + + if err := os.Remove(src); err != nil { + if os.IsNotExist(err) { + i.hub.logger.Debugf("%s doesn't exist, no need to remove", src) + return false, nil + } + + return false, fmt.Errorf("while removing file: %w", err) + } + + i.State.Downloaded = false + i.hub.logger.Infof("Removed source file [%s]: %s", i.Name, src) + + return true, nil +} + +// disable removes the install link, and optionally the downloaded content. +func (i *Item) disable(purge bool, force bool) (bool, error) { + didRemove := true + + err := i.removeInstallLink() + if os.IsNotExist(err) { + if !purge && !force { + link, _ := i.installPath() + return false, fmt.Errorf("link %s does not exist (override with --force or --purge)", link) + } + + didRemove = false + } else if err != nil { + return false, err + } + + i.State.Installed = false + didPurge := false + + if purge { + if didPurge, err = i.purge(); err != nil { + return didRemove, err + } + } + + ret := didRemove || didPurge + + return ret, nil +} + +// Remove disables the item, optionally removing the downloaded content. +func (i *Item) Remove(purge bool, force bool) (bool, error) { + if i.State.IsLocal() { + i.hub.logger.Warningf("%s is a local item, please delete manually", i.Name) + return false, nil + } + + if i.State.Tainted && !force { + return false, fmt.Errorf("%s is tainted, use '--force' to remove", i.Name) + } + + if !i.State.Installed && !purge { + i.hub.logger.Infof("removing %s: not installed -- no need to remove", i.Name) + return false, nil + } + + removed := false + + descendants, err := i.descendants() + if err != nil { + return false, err + } + + ancestors := i.Ancestors() + + for _, sub := range i.SubItems() { + if !sub.State.Installed { + continue + } + + // if the sub depends on a collection that is not a direct or indirect dependency + // of the current item, it is not removed + for _, subParent := range sub.Ancestors() { + if !purge && !subParent.State.Installed { + continue + } + + // the ancestor that would block the removal of the sub item is also an ancestor + // of the item we are removing, so we don't want false warnings + // (e.g. 
crowdsecurity/sshd-logs was not removed because it also belongs to crowdsecurity/linux, + // while we are removing crowdsecurity/sshd) + if slices.Contains(ancestors, subParent) { + continue + } + + // the sub-item belongs to the item we are removing, but we already knew that + if subParent == i { + continue + } + + if !slices.Contains(descendants, subParent) { + i.hub.logger.Infof("%s was not removed because it also belongs to %s", sub.Name, subParent.Name) + continue + } + } + + subRemoved, err := sub.Remove(purge, force) + if err != nil { + return false, fmt.Errorf("unable to disable %s: %w", i.Name, err) + } + + removed = removed || subRemoved + } + + didDisable, err := i.disable(purge, force) + if err != nil { + return false, fmt.Errorf("while removing %s: %w", i.Name, err) + } + + removed = removed || didDisable + + return removed, nil +} diff --git a/pkg/cwhub/itemupgrade.go b/pkg/cwhub/itemupgrade.go new file mode 100644 index 00000000000..105e5ebec31 --- /dev/null +++ b/pkg/cwhub/itemupgrade.go @@ -0,0 +1,254 @@ +package cwhub + +// Install, upgrade and remove items from the hub to the local configuration + +import ( + "context" + "crypto" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/downloader" + + "github.com/crowdsecurity/crowdsec/pkg/emoji" +) + +// Upgrade downloads and applies the last version of the item from the hub. +func (i *Item) Upgrade(ctx context.Context, force bool) (bool, error) { + if i.State.IsLocal() { + i.hub.logger.Infof("not upgrading %s: local item", i.Name) + return false, nil + } + + if !i.State.Downloaded { + return false, fmt.Errorf("can't upgrade %s: not installed", i.Name) + } + + if !i.State.Installed { + return false, fmt.Errorf("can't upgrade %s: downloaded but not installed", i.Name) + } + + if i.State.UpToDate { + i.hub.logger.Infof("%s: up-to-date", i.Name) + + if err := i.DownloadDataIfNeeded(ctx, force); err != nil { + return false, fmt.Errorf("%s: download failed: %w", i.Name, err) + } + + if !force { + // no upgrade needed + return false, nil + } + } + + if _, err := i.downloadLatest(ctx, force, true); err != nil { + return false, fmt.Errorf("%s: download failed: %w", i.Name, err) + } + + if !i.State.UpToDate { + if i.State.Tainted { + i.hub.logger.Warningf("%v %s is tainted, --force to overwrite", emoji.Warning, i.Name) + } + + return false, nil + } + + // a check on stdout is used while scripting to know if the hub has been upgraded + // and a configuration reload is required + // TODO: use a better way to communicate this + fmt.Printf("updated %s\n", i.Name) + i.hub.logger.Infof("%v %s: updated", emoji.Package, i.Name) + + return true, nil +} + +// downloadLatest downloads the latest version of the item to the hub directory. 
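downloadLatest below is the engine behind Upgrade. From the caller's perspective, upgrading a whole item type and deciding whether a configuration reload is needed looks roughly like this (upgradeAll is a hypothetical helper, not part of the package):

```go
package example

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/cwhub"
)

// upgradeAll upgrades every installed item of one type and reports whether
// anything changed, i.e. whether the agent needs a configuration reload.
// Errors abort at the first failing item.
func upgradeAll(ctx context.Context, hub *cwhub.Hub, itemType string) (bool, error) {
	changed := false

	for _, item := range hub.GetInstalledByType(itemType, true) {
		didUpdate, err := item.Upgrade(ctx, false)
		if err != nil {
			return changed, err
		}

		changed = changed || didUpdate
	}

	return changed, nil
}
```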
+func (i *Item) downloadLatest(ctx context.Context, overwrite bool, updateOnly bool) (bool, error) { + i.hub.logger.Debugf("Downloading %s %s", i.Type, i.Name) + + for _, sub := range i.SubItems() { + if !sub.State.Installed && updateOnly && sub.State.Downloaded { + i.hub.logger.Debugf("skipping upgrade of %s: not installed", i.Name) + continue + } + + i.hub.logger.Debugf("Download %s sub-item: %s %s (%t -> %t)", i.Name, sub.Type, sub.Name, i.State.Installed, updateOnly) + + // recurse as it's a collection + if sub.HasSubItems() { + i.hub.logger.Tracef("collection, recurse") + + if _, err := sub.downloadLatest(ctx, overwrite, updateOnly); err != nil { + return false, err + } + } + + downloaded := sub.State.Downloaded + + if _, err := sub.download(ctx, overwrite); err != nil { + return false, err + } + + // We need to enable an item when it has been added to a collection since latest release of the collection. + // We check if sub.Downloaded is false because maybe the item has been disabled by the user. + if !sub.State.Installed && !downloaded { + if err := sub.enable(); err != nil { + return false, fmt.Errorf("enabling '%s': %w", sub.Name, err) + } + } + } + + if !i.State.Installed && updateOnly && i.State.Downloaded && !overwrite { + i.hub.logger.Debugf("skipping upgrade of %s: not installed", i.Name) + return false, nil + } + + return i.download(ctx, overwrite) +} + +// FetchContentTo downloads the last version of the item's YAML file to the specified path. +func (i *Item) FetchContentTo(ctx context.Context, destPath string) (bool, string, error) { + wantHash := i.latestHash() + if wantHash == "" { + return false, "", errors.New("latest hash missing from index. The index file is invalid, please run 'cscli hub update' and try again") + } + + // Use the embedded content if available + if i.Content != "" { + // the content was historically base64 encoded + content, err := base64.StdEncoding.DecodeString(i.Content) + if err != nil { + content = []byte(i.Content) + } + + dir := filepath.Dir(destPath) + + if err := os.MkdirAll(dir, 0o755); err != nil { + return false, "", fmt.Errorf("while creating %s: %w", dir, err) + } + + // check sha256 + hash := crypto.SHA256.New() + if _, err := hash.Write(content); err != nil { + return false, "", fmt.Errorf("while hashing %s: %w", i.Name, err) + } + + gotHash := hex.EncodeToString(hash.Sum(nil)) + if gotHash != wantHash { + return false, "", fmt.Errorf("hash mismatch: expected %s, got %s. The index file is invalid, please run 'cscli hub update' and try again", wantHash, gotHash) + } + + if err := os.WriteFile(destPath, content, 0o600); err != nil { + return false, "", fmt.Errorf("while writing %s: %w", destPath, err) + } + + i.hub.logger.Debugf("Wrote %s content from .index.json to %s", i.Name, destPath) + + return true, fmt.Sprintf("(embedded in %s)", i.hub.local.HubIndexFile), nil + } + + url, err := i.hub.remote.urlTo(i.RemotePath) + if err != nil { + return false, "", fmt.Errorf("failed to build request: %w", err) + } + + d := downloader. + New(). + WithHTTPClient(hubClient). + ToFile(destPath). + WithETagFn(downloader.SHA256). + WithMakeDirs(true). + WithLogger(logrus.WithField("url", url)). + CompareContent(). + VerifyHash("sha256", wantHash) + + // TODO: recommend hub update if hash does not match + + downloaded, err := d.Download(ctx, url) + if err != nil { + return false, "", err + } + + return downloaded, url, nil +} + +// download downloads the item from the hub and writes it to the hub directory. 
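FetchContentTo above has two paths: content embedded in the index, or an HTTP download verified by the go-cs-lib downloader. The embedded path decodes base64 (falling back to raw bytes for older indexes) and checks the sha256 against the digest recorded for the latest version. A standalone sketch of just that step (decodeAndVerify is an illustrative name):

```go
package example

import (
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"fmt"
)

// decodeAndVerify mirrors the embedded-content path of FetchContentTo:
// the index may carry the item body base64-encoded, or raw in older
// indexes, and the result must hash to the digest recorded for the
// latest version.
func decodeAndVerify(embedded, wantHash string) ([]byte, error) {
	content, err := base64.StdEncoding.DecodeString(embedded)
	if err != nil {
		content = []byte(embedded) // historical: content stored unencoded
	}

	sum := sha256.Sum256(content)
	if got := hex.EncodeToString(sum[:]); got != wantHash {
		return nil, fmt.Errorf("hash mismatch: expected %s, got %s", wantHash, got)
	}

	return content, nil
}
```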
+func (i *Item) download(ctx context.Context, overwrite bool) (bool, error) { + // ensure that target file is within target dir + finalPath, err := i.downloadPath() + if err != nil { + return false, err + } + + if i.State.IsLocal() { + i.hub.logger.Warningf("%s is local, can't download", i.Name) + return false, nil + } + + // if user didn't --force, don't overwrite local, tainted, up-to-date files + if !overwrite { + if i.State.Tainted { + i.hub.logger.Debugf("%s: tainted, not updated", i.Name) + return false, nil + } + + if i.State.UpToDate { + // We still have to check if data files are present + i.hub.logger.Debugf("%s: up-to-date, not updated", i.Name) + } + } + + downloaded, _, err := i.FetchContentTo(ctx, finalPath) + if err != nil { + return false, err + } + + if downloaded { + i.hub.logger.Infof("Downloaded %s", i.Name) + } + + i.State.Downloaded = true + i.State.Tainted = false + i.State.UpToDate = true + + // read content to get the list of data files + reader, err := os.Open(finalPath) + if err != nil { + return false, fmt.Errorf("while opening %s: %w", finalPath, err) + } + + defer reader.Close() + + if err = downloadDataSet(ctx, i.hub.local.InstallDataDir, overwrite, reader, i.hub.logger); err != nil { + return false, fmt.Errorf("while downloading data for %s: %w", i.FileName, err) + } + + return true, nil +} + +// DownloadDataIfNeeded downloads the data set for the item. +func (i *Item) DownloadDataIfNeeded(ctx context.Context, force bool) error { + itemFilePath, err := i.installPath() + if err != nil { + return err + } + + itemFile, err := os.Open(itemFilePath) + if err != nil { + return fmt.Errorf("while opening %s: %w", itemFilePath, err) + } + + defer itemFile.Close() + + if err = downloadDataSet(ctx, i.hub.local.InstallDataDir, force, itemFile, i.hub.logger); err != nil { + return fmt.Errorf("while downloading data for %s: %w", itemFilePath, err) + } + + return nil +} diff --git a/pkg/cwhub/itemupgrade_test.go b/pkg/cwhub/itemupgrade_test.go new file mode 100644 index 00000000000..5f9e4d1944e --- /dev/null +++ b/pkg/cwhub/itemupgrade_test.go @@ -0,0 +1,223 @@ +package cwhub + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" +) + +// Download index, install collection. Add scenario to collection (hub-side), update index, upgrade collection. +// We expect the new scenario to be installed. +func TestUpgradeItemNewScenarioInCollection(t *testing.T) { + hub := envSetup(t) + item := hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection") + + // fresh install of collection + require.False(t, item.State.Downloaded) + require.False(t, item.State.Installed) + + ctx := context.Background() + + require.NoError(t, item.Install(ctx, false, false)) + + require.True(t, item.State.Downloaded) + require.True(t, item.State.Installed) + require.True(t, item.State.UpToDate) + require.False(t, item.State.Tainted) + + // This is the scenario that gets added in next version of collection + require.Nil(t, hub.GetItem(SCENARIOS, "crowdsecurity/barfoo_scenario")) + + assertCollectionDepsInstalled(t, hub, "crowdsecurity/test_collection") + + // collection receives an update. 
It now adds new scenario "crowdsecurity/barfoo_scenario"
+	pushUpdateToCollectionInHub()
+
+	remote := &RemoteHubCfg{
+		URLTemplate: mockURLTemplate,
+		Branch:      "master",
+		IndexPath:   ".index.json",
+	}
+
+	hub, err := NewHub(hub.local, remote, nil)
+	require.NoError(t, err)
+
+	err = hub.Update(ctx)
+	require.NoError(t, err)
+
+	err = hub.Load()
+	require.NoError(t, err)
+
+	hub = getHubOrFail(t, hub.local, remote)
+
+	item = hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection")
+
+	require.True(t, item.State.Downloaded)
+	require.True(t, item.State.Installed)
+	require.False(t, item.State.UpToDate)
+	require.False(t, item.State.Tainted)
+
+	didUpdate, err := item.Upgrade(ctx, false)
+	require.NoError(t, err)
+	require.True(t, didUpdate)
+	assertCollectionDepsInstalled(t, hub, "crowdsecurity/test_collection")
+
+	require.True(t, hub.GetItem(SCENARIOS, "crowdsecurity/barfoo_scenario").State.Downloaded)
+	require.True(t, hub.GetItem(SCENARIOS, "crowdsecurity/barfoo_scenario").State.Installed)
+}
+
+// Install a collection, disable a scenario.
+// Upgrade should not enable/download the disabled scenario.
+func TestUpgradeItemInDisabledScenarioShouldNotBeInstalled(t *testing.T) {
+	hub := envSetup(t)
+	item := hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection")
+
+	// fresh install of collection
+	require.False(t, item.State.Downloaded)
+	require.False(t, item.State.Installed)
+	require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed)
+
+	ctx := context.Background()
+
+	require.NoError(t, item.Install(ctx, false, false))
+
+	require.True(t, item.State.Downloaded)
+	require.True(t, item.State.Installed)
+	require.True(t, item.State.UpToDate)
+	require.False(t, item.State.Tainted)
+	require.True(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed)
+	assertCollectionDepsInstalled(t, hub, "crowdsecurity/test_collection")
+
+	item = hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario")
+	didRemove, err := item.Remove(false, false)
+	require.NoError(t, err)
+	require.True(t, didRemove)
+
+	remote := &RemoteHubCfg{
+		URLTemplate: mockURLTemplate,
+		Branch:      "master",
+		IndexPath:   ".index.json",
+	}
+
+	hub = getHubOrFail(t, hub.local, remote)
+	// the scenario referenced by the collection was deleted, hence the collection should be tainted
+	require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed)
+
+	require.True(t, hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection").State.Tainted)
+	require.True(t, hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection").State.Downloaded)
+	require.True(t, hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection").State.Installed)
+	require.True(t, hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection").State.UpToDate)
+
+	hub, err = NewHub(hub.local, remote, nil)
+	require.NoError(t, err)
+
+	err = hub.Update(ctx)
+	require.NoError(t, err)
+
+	err = hub.Load()
+	require.NoError(t, err)
+
+	item = hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection")
+	didUpdate, err := item.Upgrade(ctx, false)
+	require.NoError(t, err)
+	require.False(t, didUpdate)
+
+	hub = getHubOrFail(t, hub.local, remote)
+	require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed)
+}
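These tests repeat the same NewHub / Update / Load refresh cycle. A sketch of a helper that bundles the three steps; updateHubOrFail is hypothetical and only assumes the functions already used above:

// updateHubOrFail downloads a fresh index, reloads the hub state and fails
// the test on any error. Compare with getHubOrFail below, which skips Update.
func updateHubOrFail(t *testing.T, ctx context.Context, local *csconfig.LocalHubCfg, remote *RemoteHubCfg) *Hub {
	t.Helper()

	hub, err := NewHub(local, remote, nil)
	require.NoError(t, err)

	require.NoError(t, hub.Update(ctx))
	require.NoError(t, hub.Load())

	return hub
}

+
+// getHubOrFail refreshes the hub state (load index, sync) and returns the singleton, or fails the test.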
+func getHubOrFail(t *testing.T, local *csconfig.LocalHubCfg, remote *RemoteHubCfg) *Hub { + hub, err := NewHub(local, remote, nil) + require.NoError(t, err) + + err = hub.Load() + require.NoError(t, err) + + return hub +} + +// Install a collection. Disable a referenced scenario. Publish new version of collection with new scenario +// Upgrade should not enable/download the disabled scenario. +// Upgrade should install and enable the newly added scenario. +func TestUpgradeItemNewScenarioIsInstalledWhenReferencedScenarioIsDisabled(t *testing.T) { + hub := envSetup(t) + item := hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection") + + // fresh install of collection + require.False(t, item.State.Downloaded) + require.False(t, item.State.Installed) + require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed) + + ctx := context.Background() + + require.NoError(t, item.Install(ctx, false, false)) + + require.True(t, item.State.Downloaded) + require.True(t, item.State.Installed) + require.True(t, item.State.UpToDate) + require.False(t, item.State.Tainted) + require.True(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed) + assertCollectionDepsInstalled(t, hub, "crowdsecurity/test_collection") + + item = hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario") + didRemove, err := item.Remove(false, false) + require.NoError(t, err) + require.True(t, didRemove) + + remote := &RemoteHubCfg{ + URLTemplate: mockURLTemplate, + Branch: "master", + IndexPath: ".index.json", + } + + hub = getHubOrFail(t, hub.local, remote) + // scenario referenced by collection was deleted hence, collection should be tainted + require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed) + require.True(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Downloaded) // this fails + require.True(t, hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection").State.Tainted) + require.True(t, hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection").State.Downloaded) + require.True(t, hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection").State.Installed) + require.True(t, hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection").State.UpToDate) + + // collection receives an update. It now adds new scenario "crowdsecurity/barfoo_scenario" + // we now attempt to upgrade the collection, however it shouldn't install the foobar_scenario + // we just removed. 
It should, however, install the newly added scenario
+	pushUpdateToCollectionInHub()
+
+	hub, err = NewHub(hub.local, remote, nil)
+	require.NoError(t, err)
+
+	err = hub.Update(ctx)
+	require.NoError(t, err)
+
+	err = hub.Load()
+	require.NoError(t, err)
+
+	require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed)
+	hub = getHubOrFail(t, hub.local, remote)
+
+	item = hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection")
+	didUpdate, err := item.Upgrade(ctx, false)
+	require.NoError(t, err)
+	require.True(t, didUpdate)
+
+	hub = getHubOrFail(t, hub.local, remote)
+	require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed)
+	require.True(t, hub.GetItem(SCENARIOS, "crowdsecurity/barfoo_scenario").State.Installed)
+}
+
+func assertCollectionDepsInstalled(t *testing.T, hub *Hub, collection string) {
+	t.Helper()
+
+	c := hub.GetItem(COLLECTIONS, collection)
+	require.Empty(t, c.checkSubItemVersions())
+}
+
+func pushUpdateToCollectionInHub() {
+	responseByPath["/crowdsecurity/master/.index.json"] = fileToStringX("./testdata/index2.json")
+	responseByPath["/crowdsecurity/master/collections/crowdsecurity/test_collection.yaml"] = fileToStringX("./testdata/collection_v2.yaml")
+}
diff --git a/pkg/cwhub/loader.go b/pkg/cwhub/loader.go
deleted file mode 100644
index 1b2b8c83212..00000000000
--- a/pkg/cwhub/loader.go
+++ /dev/null
@@ -1,428 +0,0 @@
-package cwhub
-
-import (
-	"encoding/json"
-	"errors"
-	"fmt"
-	"os"
-	"path/filepath"
-	"sort"
-	"strings"
-
-	log "github.com/sirupsen/logrus"
-	"golang.org/x/mod/semver"
-
-	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
-)
-
-/*the walk/parser_visit function can't receive extra args*/
-var hubdir, installdir string
-
-func parser_visit(path string, f os.DirEntry, err error) error {
-
-	var target Item
-	var local bool
-	var hubpath string
-	var inhub bool
-	var fname string
-	var ftype string
-	var fauthor string
-	var stage string
-
-	if err != nil {
-		log.Debugf("while syncing hub dir: %s", err)
-		// there is a path error, we ignore the file
-		return nil
-	}
-
-	path, err = filepath.Abs(path)
-	if err != nil {
-		return err
-	}
-	//we only care about files
-	if f == nil || f.IsDir() {
-		return nil
-	}
-	//we only care about yaml files
-	if !strings.HasSuffix(f.Name(), ".yaml") && !strings.HasSuffix(f.Name(), ".yml") {
-		return nil
-	}
-
-	subs := strings.Split(path, string(os.PathSeparator))
-
-	log.Tracef("path:%s, hubdir:%s, installdir:%s", path, hubdir, installdir)
-	log.Tracef("subs:%v", subs)
-	/*we're in hub (~/.hub/hub/)*/
-	if strings.HasPrefix(path, hubdir) {
-		log.Tracef("in hub dir")
-		inhub = true
-		//.../hub/parsers/s00-raw/crowdsec/skip-pretag.yaml
-		//.../hub/scenarios/crowdsec/ssh_bf.yaml
-		//.../hub/profiles/crowdsec/linux.yaml
-		if len(subs) < 4 {
-			log.Fatalf("path is too short : %s (%d)", path, len(subs))
-		}
-		fname = subs[len(subs)-1]
-		fauthor = subs[len(subs)-2]
-		stage = subs[len(subs)-3]
-		ftype = subs[len(subs)-4]
-	} else if strings.HasPrefix(path, installdir) { /*we're in install /etc/crowdsec//...
*/ - log.Tracef("in install dir") - if len(subs) < 3 { - log.Fatalf("path is too short : %s (%d)", path, len(subs)) - } - ///.../config/parser/stage/file.yaml - ///.../config/postoverflow/stage/file.yaml - ///.../config/scenarios/scenar.yaml - ///.../config/collections/linux.yaml //file is empty - fname = subs[len(subs)-1] - stage = subs[len(subs)-2] - ftype = subs[len(subs)-3] - fauthor = "" - } else { - return fmt.Errorf("file '%s' is not from hub '%s' nor from the configuration directory '%s'", path, hubdir, installdir) - } - - log.Tracef("stage:%s ftype:%s", stage, ftype) - //log.Printf("%s -> name:%s stage:%s", path, fname, stage) - if stage == SCENARIOS { - ftype = SCENARIOS - stage = "" - } else if stage == COLLECTIONS { - ftype = COLLECTIONS - stage = "" - } else if ftype != PARSERS && ftype != PARSERS_OVFLW /*its a PARSER / PARSER_OVFLW with a stage */ { - return fmt.Errorf("unknown configuration type for file '%s'", path) - } - - log.Tracef("CORRECTED [%s] by [%s] in stage [%s] of type [%s]", fname, fauthor, stage, ftype) - - /* - we can encounter 'collections' in the form of a symlink : - /etc/crowdsec/.../collections/linux.yaml -> ~/.hub/hub/collections/.../linux.yaml - when the collection is installed, both files are created - */ - //non symlinks are local user files or hub files - if f.Type()&os.ModeSymlink == 0 { - local = true - log.Tracef("%s isn't a symlink", path) - } else { - hubpath, err = os.Readlink(path) - if err != nil { - return fmt.Errorf("unable to read symlink of %s", path) - } - //the symlink target doesn't exist, user might have removed ~/.hub/hub/...yaml without deleting /etc/crowdsec/....yaml - _, err := os.Lstat(hubpath) - if os.IsNotExist(err) { - log.Infof("%s is a symlink to %s that doesn't exist, deleting symlink", path, hubpath) - //remove the symlink - if err = os.Remove(path); err != nil { - return fmt.Errorf("failed to unlink %s: %+v", path, err) - } - return nil - } - log.Tracef("%s points to %s", path, hubpath) - } - - //if it's not a symlink and not in hub, it's a local file, don't bother - if local && !inhub { - log.Tracef("%s is a local file, skip", path) - skippedLocal++ - // log.Printf("local scenario, skip.") - target.Name = fname - target.Stage = stage - target.Installed = true - target.Type = ftype - target.Local = true - target.LocalPath = path - target.UpToDate = true - _, target.FileName = filepath.Split(path) - - hubIdx[ftype][fname] = target - return nil - } - //try to find which configuration item it is - log.Tracef("check [%s] of %s", fname, ftype) - - match := false - for k, v := range hubIdx[ftype] { - log.Tracef("check [%s] vs [%s] : %s", fname, v.RemotePath, ftype+"/"+stage+"/"+fname+".yaml") - if fname != v.FileName { - log.Tracef("%s != %s (filename)", fname, v.FileName) - continue - } - //wrong stage - if v.Stage != stage { - continue - } - /*if we are walking hub dir, just mark present files as downloaded*/ - if inhub { - //wrong author - if fauthor != v.Author { - continue - } - //wrong file - if CheckName(v.Name, fauthor, fname) { - continue - } - - if path == hubdir+"/"+v.RemotePath { - log.Tracef("marking %s as downloaded", v.Name) - v.Downloaded = true - } - } else if CheckSuffix(hubpath, v.RemotePath) { - //wrong file - /////.yaml - continue - } - sha, err := getSHA256(path) - if err != nil { - log.Fatalf("Failed to get sha of %s : %v", path, err) - } - //let's reverse sort the versions to deal with hash collisions (#154) - versions := make([]string, 0, len(v.Versions)) - for k := range v.Versions { - versions = 
append(versions, k) - } - sort.Sort(sort.Reverse(sort.StringSlice(versions))) - - for _, version := range versions { - val := v.Versions[version] - if sha != val.Digest { - //log.Printf("matching filenames, wrong hash %s != %s -- %s", sha, val.Digest, spew.Sdump(v)) - continue - } - /*we got an exact match, update struct*/ - if !inhub { - log.Tracef("found exact match for %s, version is %s, latest is %s", v.Name, version, v.Version) - v.LocalPath = path - v.LocalVersion = version - v.Tainted = false - v.Downloaded = true - /*if we're walking the hub, present file doesn't means installed file*/ - v.Installed = true - v.LocalHash = sha - _, target.FileName = filepath.Split(path) - } else { - v.Downloaded = true - v.LocalHash = sha - } - if version == v.Version { - log.Tracef("%s is up-to-date", v.Name) - v.UpToDate = true - } - match = true - break - } - if !match { - log.Tracef("got tainted match for %s : %s", v.Name, path) - skippedTainted += 1 - //the file and the stage is right, but the hash is wrong, it has been tainted by user - if !inhub { - v.LocalPath = path - v.Installed = true - } - v.UpToDate = false - v.LocalVersion = "?" - v.Tainted = true - v.LocalHash = sha - _, target.FileName = filepath.Split(path) - - } - //update the entry if appropriate - // if _, ok := hubIdx[ftype][k]; !ok || !inhub || v.D { - // fmt.Printf("Updating %s", k) - // hubIdx[ftype][k] = v - // } else if !inhub { - - // } else if - hubIdx[ftype][k] = v - return nil - } - log.Infof("Ignoring file %s of type %s", path, ftype) - return nil -} - -func CollecDepsCheck(v *Item) error { - - if GetVersionStatus(v) != 0 { //not up-to-date - log.Debugf("%s dependencies not checked : not up-to-date", v.Name) - return nil - } - - /*if it's a collection, ensure all the items are installed, or tag it as tainted*/ - if v.Type == COLLECTIONS { - log.Tracef("checking submembers of %s installed:%t", v.Name, v.Installed) - var tmp = [][]string{v.Parsers, v.PostOverflows, v.Scenarios, v.Collections} - for idx, ptr := range tmp { - ptrtype := ItemTypes[idx] - for _, p := range ptr { - val, ok := hubIdx[ptrtype][p] - if !ok { - log.Fatalf("Referred %s %s in collection %s doesn't exist.", ptrtype, p, v.Name) - } - log.Tracef("check %s installed:%t", val.Name, val.Installed) - if !v.Installed { - continue - } - if val.Type == COLLECTIONS { - log.Tracef("collec, recurse.") - if err := CollecDepsCheck(&val); err != nil { - if val.Tainted { - v.Tainted = true - } - return fmt.Errorf("sub collection %s is broken : %s", val.Name, err) - } - hubIdx[ptrtype][p] = val - } - - //propagate the state of sub-items to set - if val.Tainted { - v.Tainted = true - return fmt.Errorf("tainted %s %s, tainted.", ptrtype, p) - } - if !val.Installed && v.Installed { - v.Tainted = true - return fmt.Errorf("missing %s %s, tainted.", ptrtype, p) - } - if !val.UpToDate { - v.UpToDate = false - return fmt.Errorf("outdated %s %s", ptrtype, p) - } - skip := false - for idx := range val.BelongsToCollections { - if val.BelongsToCollections[idx] == v.Name { - skip = true - } - } - if !skip { - val.BelongsToCollections = append(val.BelongsToCollections, v.Name) - } - hubIdx[ptrtype][p] = val - log.Tracef("checking for %s - tainted:%t uptodate:%t", p, v.Tainted, v.UpToDate) - } - } - } - return nil -} - -func SyncDir(hub *csconfig.Hub, dir string) (error, []string) { - hubdir = hub.HubDir - installdir = hub.ConfigDir - warnings := []string{} - - /*For each, scan PARSERS, PARSERS_OVFLW, SCENARIOS and COLLECTIONS last*/ - for _, scan := range ItemTypes { - cpath, err 
:= filepath.Abs(fmt.Sprintf("%s/%s", dir, scan)) - if err != nil { - log.Errorf("failed %s : %s", cpath, err) - } - err = filepath.WalkDir(cpath, parser_visit) - if err != nil { - return err, warnings - } - - } - - for k, v := range hubIdx[COLLECTIONS] { - if v.Installed { - versStat := GetVersionStatus(&v) - if versStat == 0 { //latest - if err := CollecDepsCheck(&v); err != nil { - warnings = append(warnings, fmt.Sprintf("dependency of %s : %s", v.Name, err)) - hubIdx[COLLECTIONS][k] = v - } - } else if versStat == 1 { //not up-to-date - warnings = append(warnings, fmt.Sprintf("update for collection %s available (currently:%s, latest:%s)", v.Name, v.LocalVersion, v.Version)) - } else { //version is higher than the highest available from hub? - warnings = append(warnings, fmt.Sprintf("collection %s is in the future (currently:%s, latest:%s)", v.Name, v.LocalVersion, v.Version)) - } - log.Debugf("installed (%s) - status:%d | installed:%s | latest : %s | full : %+v", v.Name, semver.Compare("v"+v.Version, "v"+v.LocalVersion), v.LocalVersion, v.Version, v.Versions) - } - } - return nil, warnings -} - -/* Updates the infos from HubInit() with the local state */ -func LocalSync(hub *csconfig.Hub) (error, []string) { - skippedLocal = 0 - skippedTainted = 0 - - err, warnings := SyncDir(hub, hub.ConfigDir) - if err != nil { - return fmt.Errorf("failed to scan %s : %s", hub.ConfigDir, err), warnings - } - err, _ = SyncDir(hub, hub.HubDir) - if err != nil { - return fmt.Errorf("failed to scan %s : %s", hub.HubDir, err), warnings - } - return nil, warnings -} - -func GetHubIdx(hub *csconfig.Hub) error { - if hub == nil { - return fmt.Errorf("no configuration found for hub") - } - log.Debugf("loading hub idx %s", hub.HubIndexFile) - bidx, err := os.ReadFile(hub.HubIndexFile) - if err != nil { - return fmt.Errorf("unable to read index file: %w", err) - } - ret, err := LoadPkgIndex(bidx) - if err != nil { - if !errors.Is(err, ReferenceMissingError) { - log.Fatalf("Unable to load existing index : %v.", err) - } - return err - } - hubIdx = ret - err, _ = LocalSync(hub) - if err != nil { - log.Fatalf("Failed to sync Hub index with local deployment : %v", err) - } - return nil -} - -/*LoadPkgIndex loads a local .index.json file and returns the map of parsers/scenarios/collections associated*/ -func LoadPkgIndex(buff []byte) (map[string]map[string]Item, error) { - var err error - var RawIndex map[string]map[string]Item - var missingItems []string - - if err = json.Unmarshal(buff, &RawIndex); err != nil { - return nil, fmt.Errorf("failed to unmarshal index : %v", err) - } - - log.Debugf("%d item types in hub index", len(ItemTypes)) - /*Iterate over the different types to complete struct */ - for _, itemType := range ItemTypes { - /*complete struct*/ - log.Tracef("%d item", len(RawIndex[itemType])) - for idx, item := range RawIndex[itemType] { - item.Name = idx - item.Type = itemType - x := strings.Split(item.RemotePath, "/") - item.FileName = x[len(x)-1] - RawIndex[itemType][idx] = item - /*if it's a collection, check its sub-items are present*/ - //XX should be done later - if itemType == COLLECTIONS { - var tmp = [][]string{item.Parsers, item.PostOverflows, item.Scenarios, item.Collections} - for idx, ptr := range tmp { - ptrtype := ItemTypes[idx] - for _, p := range ptr { - if _, ok := RawIndex[ptrtype][p]; !ok { - log.Errorf("Referred %s %s in collection %s doesn't exist.", ptrtype, p, item.Name) - missingItems = append(missingItems, p) - } - } - } - } - } - } - if len(missingItems) > 0 { - return 
RawIndex, fmt.Errorf("%q : %w", missingItems, ReferenceMissingError) - } - - return RawIndex, nil -} diff --git a/pkg/cwhub/path_separator_windows.go b/pkg/cwhub/path_separator_windows.go deleted file mode 100644 index 42f61aa16f0..00000000000 --- a/pkg/cwhub/path_separator_windows.go +++ /dev/null @@ -1,23 +0,0 @@ -package cwhub - -import ( - "path/filepath" - "strings" -) - -func CheckSuffix(hubpath string, remotePath string) bool { - newPath := filepath.ToSlash(hubpath) - if !strings.HasSuffix(newPath, remotePath) { - return true - } else { - return false - } -} - -func CheckName(vname string, fauthor string, fname string) bool { - if vname+".yaml" != fauthor+"/"+fname && vname+".yml" != fauthor+"/"+fname { - return true - } else { - return false - } -} diff --git a/pkg/cwhub/pathseparator.go b/pkg/cwhub/pathseparator.go deleted file mode 100644 index 0340697ee6e..00000000000 --- a/pkg/cwhub/pathseparator.go +++ /dev/null @@ -1,24 +0,0 @@ -//go:build linux || freebsd || netbsd || openbsd || solaris || !windows -// +build linux freebsd netbsd openbsd solaris !windows - -package cwhub - -import "strings" - -const PathSeparator = "/" - -func CheckSuffix(hubpath string, remotePath string) bool { - if !strings.HasSuffix(hubpath, remotePath) { - return true - } else { - return false - } -} - -func CheckName(vname string, fauthor string, fname string) bool { - if vname+".yaml" != fauthor+"/"+fname && vname+".yml" != fauthor+"/"+fname { - return true - } else { - return false - } -} diff --git a/pkg/cwhub/pathseparator_unix.go b/pkg/cwhub/pathseparator_unix.go new file mode 100644 index 00000000000..9420dc11ebe --- /dev/null +++ b/pkg/cwhub/pathseparator_unix.go @@ -0,0 +1,9 @@ +//go:build unix + +package cwhub + +import "strings" + +func hasPathSuffix(hubpath string, remotePath string) bool { + return strings.HasSuffix(hubpath, remotePath) +} diff --git a/pkg/cwhub/pathseparator_windows.go b/pkg/cwhub/pathseparator_windows.go new file mode 100644 index 00000000000..a6d1be3f8d1 --- /dev/null +++ b/pkg/cwhub/pathseparator_windows.go @@ -0,0 +1,11 @@ +package cwhub + +import ( + "path/filepath" + "strings" +) + +func hasPathSuffix(hubpath string, remotePath string) bool { + newPath := filepath.ToSlash(hubpath) + return strings.HasSuffix(newPath, remotePath) +} diff --git a/pkg/cwhub/relativepath.go b/pkg/cwhub/relativepath.go new file mode 100644 index 00000000000..bcd4c576840 --- /dev/null +++ b/pkg/cwhub/relativepath.go @@ -0,0 +1,28 @@ +package cwhub + +import ( + "path/filepath" + "strings" +) + +// relativePathComponents returns the list of path components after baseDir. +// If path is not inside baseDir, it returns an empty slice. +func relativePathComponents(path string, baseDir string) []string { + absPath, err := filepath.Abs(path) + if err != nil { + return []string{} + } + + absBaseDir, err := filepath.Abs(baseDir) + if err != nil { + return []string{} + } + + // is path inside baseDir? + relPath, err := filepath.Rel(absBaseDir, absPath) + if err != nil || strings.HasPrefix(relPath, "..") || relPath == "." 
{
+		return []string{}
+	}
+
+	return strings.Split(relPath, string(filepath.Separator))
+}
diff --git a/pkg/cwhub/relativepath_test.go b/pkg/cwhub/relativepath_test.go
new file mode 100644
index 00000000000..11eba566064
--- /dev/null
+++ b/pkg/cwhub/relativepath_test.go
@@ -0,0 +1,72 @@
+package cwhub
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRelativePathComponents(t *testing.T) {
+	tests := []struct {
+		name     string
+		path     string
+		baseDir  string
+		expected []string
+	}{
+		{
+			name:     "Path within baseDir",
+			path:     "/home/user/project/src/file.go",
+			baseDir:  "/home/user/project",
+			expected: []string{"src", "file.go"},
+		},
+		{
+			name:     "Path is baseDir",
+			path:     "/home/user/project",
+			baseDir:  "/home/user/project",
+			expected: []string{},
+		},
+		{
+			name:     "Path outside baseDir",
+			path:     "/home/user/otherproject/src/file.go",
+			baseDir:  "/home/user/project",
+			expected: []string{},
+		},
+		{
+			name:     "Path is subdirectory of baseDir",
+			path:     "/home/user/project/src/",
+			baseDir:  "/home/user/project",
+			expected: []string{"src"},
+		},
+		{
+			name:     "Relative paths",
+			path:     "project/src/file.go",
+			baseDir:  "project",
+			expected: []string{"src", "file.go"},
+		},
+		{
+			name:     "BaseDir with trailing slash",
+			path:     "/home/user/project/src/file.go",
+			baseDir:  "/home/user/project/",
+			expected: []string{"src", "file.go"},
+		},
+		{
+			name:     "Empty baseDir",
+			path:     "/home/user/project/src/file.go",
+			baseDir:  "",
+			expected: []string{},
+		},
+		{
+			name:     "Empty path",
+			path:     "",
+			baseDir:  "/home/user/project",
+			expected: []string{},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := relativePathComponents(tt.path, tt.baseDir)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
diff --git a/pkg/cwhub/remote.go b/pkg/cwhub/remote.go
new file mode 100644
index 00000000000..8d2dc2dbb94
--- /dev/null
+++ b/pkg/cwhub/remote.go
@@ -0,0 +1,84 @@
+package cwhub
+
+import (
+	"context"
+	"fmt"
+	"net/url"
+
+	"github.com/sirupsen/logrus"
+
+	"github.com/crowdsecurity/go-cs-lib/downloader"
+)
+
+// RemoteHubCfg is used to retrieve index and items from the remote hub.
+type RemoteHubCfg struct {
+	Branch           string
+	URLTemplate      string
+	IndexPath        string
+	EmbedItemContent bool
+}
+
+// urlTo builds the URL to download a file from the remote hub.
+func (r *RemoteHubCfg) urlTo(remotePath string) (string, error) {
+	if r == nil {
+		return "", ErrNilRemoteHub
+	}
+
+	// the template must contain two string placeholders
+	if fmt.Sprintf(r.URLTemplate, "%s", "%s") != r.URLTemplate {
+		return "", fmt.Errorf("invalid URL template '%s'", r.URLTemplate)
+	}
+
+	return fmt.Sprintf(r.URLTemplate, r.Branch, remotePath), nil
+}
+
+// addURLParam adds a query parameter to the URL if it's not already present.
+func addURLParam(rawURL string, param string, value string) (string, error) {
+	parsedURL, err := url.Parse(rawURL)
+	if err != nil {
+		return "", fmt.Errorf("failed to parse URL: %w", err)
+	}
+
+	query := parsedURL.Query()
+
+	if _, exists := query[param]; !exists {
+		query.Add(param, value)
+	}
+
+	parsedURL.RawQuery = query.Encode()
+
+	return parsedURL.String(), nil
+}
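For illustration, addURLParam appends the parameter only when it is absent; an existing value, even a different one, is left untouched (the URL is made up):

u, _ := addURLParam("https://example.com/master/.index.json", "with_content", "true")
// u == "https://example.com/master/.index.json?with_content=true"

u, _ = addURLParam("https://example.com/master/.index.json?with_content=false", "with_content", "true")
// u keeps "with_content=false": the existing value wins

+
+// fetchIndex downloads the index from the hub to the given path, and reports whether a new version was downloaded.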
+func (r *RemoteHubCfg) fetchIndex(ctx context.Context, destPath string) (bool, error) {
+	if r == nil {
+		return false, ErrNilRemoteHub
+	}
+
+	url, err := r.urlTo(r.IndexPath)
+	if err != nil {
+		return false, fmt.Errorf("failed to build hub index request: %w", err)
+	}
+
+	if r.EmbedItemContent {
+		url, err = addURLParam(url, "with_content", "true")
+		if err != nil {
+			return false, fmt.Errorf("failed to add 'with_content' parameter to URL: %w", err)
+		}
+	}
+
+	downloaded, err := downloader.
+		New().
+		WithHTTPClient(hubClient).
+		ToFile(destPath).
+		WithETagFn(downloader.SHA256).
+		CompareContent().
+		WithLogger(logrus.WithField("url", url)).
+		Download(ctx, url)
+	if err != nil {
+		return false, err
+	}
+
+	return downloaded, nil
+}
diff --git a/pkg/cwhub/sync.go b/pkg/cwhub/sync.go
new file mode 100644
index 00000000000..c82822e64ef
--- /dev/null
+++ b/pkg/cwhub/sync.go
@@ -0,0 +1,577 @@
+package cwhub
+
+import (
+	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
+	"slices"
+	"sort"
+	"strings"
+
+	"github.com/Masterminds/semver/v3"
+	"github.com/sirupsen/logrus"
+	"gopkg.in/yaml.v3"
+
+	"github.com/crowdsecurity/go-cs-lib/downloader"
+)
+
+func isYAMLFileName(path string) bool {
+	return strings.HasSuffix(path, ".yaml") || strings.HasSuffix(path, ".yml")
+}
+
+// resolveSymlink returns the ultimate target path of a symlink;
+// it returns an error if the symlink is dangling or too many levels are followed
+func resolveSymlink(path string) (string, error) {
+	const maxSymlinks = 10 // Prevent infinite loops
+	for range maxSymlinks {
+		fi, err := os.Lstat(path)
+		if err != nil {
+			return "", err // dangling link
+		}
+
+		if fi.Mode()&os.ModeSymlink == 0 {
+			// found the target
+			return path, nil
+		}
+
+		target, err := os.Readlink(path)
+		if err != nil {
+			return "", err
+		}
+
+		// a relative target is resolved against the directory of the link itself
+		if !filepath.IsAbs(target) {
+			target = filepath.Join(filepath.Dir(path), target)
+		}
+
+		path = target
+	}
+
+	return "", errors.New("too many levels of symbolic links")
+}
+
+// isPathInside checks if a path is inside the given directory
+// it can return false negatives if the filesystem is case insensitive
+func isPathInside(path, dir string) (bool, error) {
+	absFilePath, err := filepath.Abs(path)
+	if err != nil {
+		return false, err
+	}
+
+	absDir, err := filepath.Abs(dir)
+	if err != nil {
+		return false, err
+	}
+
+	// require a separator boundary, so /etc/crowdsec-custom is not
+	// mistaken for a child of /etc/crowdsec
+	return absFilePath == absDir || strings.HasPrefix(absFilePath, absDir+string(filepath.Separator)), nil
+}
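itemVisit (further down in this file) chains these two helpers: follow the link, then ask whether the target lands under the hub directory. A hedged fragment with made-up paths, as it might appear inside such a visitor:

// a file under the install dir: is it a symlink into the hub dir?
target, err := resolveSymlink("/etc/crowdsec/scenarios/my-custom.yaml")
if err != nil {
	return err // dangling link or symlink loop
}

inHub, err := isPathInside(target, "/etc/crowdsec/hub")
if err != nil {
	return err
}

// inHub == false: the file is a local (custom) item

+
+// information used to create a new Item, from a file path.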
+type itemFileInfo struct {
+	fname   string
+	stage   string
+	ftype   string
+	fauthor string
+	inhub   bool
+}
+
+func (h *Hub) getItemFileInfo(path string, logger *logrus.Logger) (*itemFileInfo, error) {
+	var ret *itemFileInfo
+
+	hubDir := h.local.HubDir
+	installDir := h.local.InstallDir
+
+	subsHub := relativePathComponents(path, hubDir)
+	subsInstall := relativePathComponents(path, installDir)
+
+	switch {
+	case len(subsHub) > 0:
+		logger.Tracef("in hub dir")
+
+		// .../hub/parsers/s00-raw/crowdsecurity/skip-pretag.yaml
+		// .../hub/scenarios/crowdsecurity/ssh_bf.yaml
+		// .../hub/profiles/crowdsecurity/linux.yaml
+		if len(subsHub) < 3 {
+			return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subsHub))
+		}
+
+		ftype := subsHub[0]
+		if !slices.Contains(ItemTypes, ftype) {
+			// this doesn't really happen anymore, because we only scan the {hubtype} directories
+			return nil, fmt.Errorf("unknown configuration type '%s'", ftype)
+		}
+
+		stage := ""
+		fauthor := subsHub[1]
+		fname := subsHub[2]
+
+		if ftype == PARSERS || ftype == POSTOVERFLOWS {
+			// these types have an extra stage level, so re-check the length
+			if len(subsHub) < 4 {
+				return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subsHub))
+			}
+
+			stage = subsHub[1]
+			fauthor = subsHub[2]
+			fname = subsHub[3]
+		}
+
+		ret = &itemFileInfo{
+			inhub:   true,
+			ftype:   ftype,
+			stage:   stage,
+			fauthor: fauthor,
+			fname:   fname,
+		}
+
+	case len(subsInstall) > 0:
+		logger.Tracef("in install dir")
+
+		// .../config/parser/stage/file.yaml
+		// .../config/postoverflow/stage/file.yaml
+		// .../config/scenarios/scenar.yaml
+		// .../config/collections/linux.yaml //file is empty
+
+		if len(subsInstall) < 2 {
+			return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subsInstall))
+		}
+
+		// this can be in any number of subdirs, we join them to compose the item name
+
+		ftype := subsInstall[0]
+		stage := ""
+		fname := strings.Join(subsInstall[1:], "/")
+
+		if ftype == PARSERS || ftype == POSTOVERFLOWS {
+			stage = subsInstall[1]
+			fname = strings.Join(subsInstall[2:], "/")
+		}
+
+		ret = &itemFileInfo{
+			inhub:   false,
+			ftype:   ftype,
+			stage:   stage,
+			fauthor: "",
+			fname:   fname,
+		}
+	default:
+		return nil, fmt.Errorf("file '%s' is not from hub '%s' nor from the configuration directory '%s'", path, hubDir, installDir)
+	}
+
+	logger.Tracef("CORRECTED [%s] by [%s] in stage [%s] of type [%s]", ret.fname, ret.fauthor, ret.stage, ret.ftype)
+
+	return ret, nil
+}
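A worked example of the decomposition, assuming a hub directory of /etc/crowdsec/hub and a logger in scope (the path follows the hub layout in the comments above):

info, err := h.getItemFileInfo("/etc/crowdsec/hub/parsers/s00-raw/crowdsecurity/skip-pretag.yaml", logger)
// err == nil
// *info == itemFileInfo{inhub: true, ftype: "parsers", stage: "s00-raw",
//                       fauthor: "crowdsecurity", fname: "skip-pretag.yaml"}

+
+// sortedVersions returns the input data, sorted in reverse order (new, old) by semver.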
+func sortedVersions(raw []string) ([]string, error) { + vs := make([]*semver.Version, len(raw)) + + for idx, r := range raw { + v, err := semver.NewVersion(r) + if err != nil { + return nil, fmt.Errorf("%s: %w", r, err) + } + + vs[idx] = v + } + + sort.Sort(sort.Reverse(semver.Collection(vs))) + + ret := make([]string, len(vs)) + for idx, v := range vs { + ret[idx] = v.Original() + } + + return ret, nil +} + +func newLocalItem(h *Hub, path string, info *itemFileInfo) (*Item, error) { + type localItemName struct { + Name string `yaml:"name"` + } + + _, fileName := filepath.Split(path) + + item := &Item{ + hub: h, + Name: info.fname, + Stage: info.stage, + Type: info.ftype, + FileName: fileName, + State: ItemState{ + LocalPath: path, + Installed: true, + UpToDate: true, + }, + } + + // try to read the name from the file + itemName := localItemName{} + + itemContent, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read %s: %w", path, err) + } + + err = yaml.Unmarshal(itemContent, &itemName) + if err != nil { + return nil, fmt.Errorf("failed to parse %s: %w", path, err) + } + + if itemName.Name != "" { + item.Name = itemName.Name + } + + return item, nil +} + +func (h *Hub) itemVisit(path string, f os.DirEntry, err error) error { + if err != nil { + h.logger.Debugf("while syncing hub dir: %s", err) + // there is a path error, we ignore the file + return nil + } + + // only happens if the current working directory was removed (!) + path, err = filepath.Abs(path) + if err != nil { + return err + } + + // permission errors, files removed while reading, etc. + if f == nil { + return nil + } + + if f.IsDir() { + // if a directory starts with a dot, we don't traverse it + // - single dot prefix is hidden by unix convention + // - double dot prefix is used by k8s to mount config maps + if strings.HasPrefix(f.Name(), ".") { + h.logger.Tracef("skipping hidden directory %s", path) + return filepath.SkipDir + } + + // keep traversing + return nil + } + + // we only care about YAML files + if !isYAMLFileName(f.Name()) { + return nil + } + + info, err := h.getItemFileInfo(path, h.logger) + if err != nil { + h.logger.Warningf("Ignoring file %s: %s", path, err) + return nil + } + + // follow the link to see if it falls in the hub directory + // if it's not a link, target == path + target, err := resolveSymlink(path) + if err != nil { + // target does not exist, the user might have removed the file + // or switched to a hub branch without it; or symlink loop + h.logger.Warningf("Ignoring file %s: %s", path, err) + return nil + } + + targetInHub, err := isPathInside(target, h.local.HubDir) + if err != nil { + h.logger.Warningf("Ignoring file %s: %s", path, err) + return nil + } + + // local (custom) item if the file or link target is not inside the hub dir + if !targetInHub { + h.logger.Tracef("%s is a local file, skip", path) + + item, err := newLocalItem(h, path, info) + if err != nil { + return err + } + + h.addItem(item) + + return nil + } + + hubpath := target + + // try to find which configuration item it is + h.logger.Tracef("check [%s] of %s", info.fname, info.ftype) + + for _, item := range h.GetItemMap(info.ftype) { + if info.fname != item.FileName { + continue + } + + if item.Stage != info.stage { + continue + } + + // if we are walking hub dir, just mark present files as downloaded + if info.inhub { + // wrong author + if info.fauthor != item.Author { + continue + } + + // not the item we're looking for + if !item.validPath(info.fauthor, info.fname) { + continue + } 
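+
+			// the file name, stage and author all match this index entry;
+			// mark it as downloaded only if it sits at the expected download path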
+ + src, err := item.downloadPath() + if err != nil { + return err + } + + if path == src { + h.logger.Tracef("marking %s as downloaded", item.Name) + item.State.Downloaded = true + } + } else if !hasPathSuffix(hubpath, item.RemotePath) { + // wrong file + // ///.yaml + continue + } + + err := item.setVersionState(path, info.inhub) + if err != nil { + return err + } + + h.pathIndex[path] = item + + return nil + } + + h.logger.Infof("Ignoring file %s of type %s", path, info.ftype) + + return nil +} + +// checkSubItemVersions checks for the presence, taint and version state of sub-items. +func (i *Item) checkSubItemVersions() []string { + warn := make([]string, 0) + + if !i.HasSubItems() { + return warn + } + + if i.versionStatus() != versionUpToDate { + i.hub.logger.Debugf("%s dependencies not checked: not up-to-date", i.Name) + return warn + } + + // ensure all the sub-items are installed, or tag the parent as tainted + i.hub.logger.Tracef("checking submembers of %s installed:%t", i.Name, i.State.Installed) + + for _, sub := range i.SubItems() { + i.hub.logger.Tracef("check %s installed:%t", sub.Name, sub.State.Installed) + + if !i.State.Installed { + continue + } + + if w := sub.checkSubItemVersions(); len(w) > 0 { + if sub.State.Tainted { + i.addTaint(sub) + warn = append(warn, fmt.Sprintf("%s is tainted by %s", i.Name, sub.FQName())) + } + + warn = append(warn, w...) + + continue + } + + if sub.State.Tainted { + i.addTaint(sub) + warn = append(warn, fmt.Sprintf("%s is tainted by %s", i.Name, sub.FQName())) + + continue + } + + if !sub.State.Installed && i.State.Installed { + i.addTaint(sub) + warn = append(warn, fmt.Sprintf("%s is tainted by missing %s", i.Name, sub.FQName())) + + continue + } + + if !sub.State.UpToDate { + i.State.UpToDate = false + warn = append(warn, fmt.Sprintf("%s is tainted by outdated %s", i.Name, sub.FQName())) + + continue + } + + i.hub.logger.Tracef("checking for %s - tainted:%t uptodate:%t", sub.Name, i.State.Tainted, i.State.UpToDate) + } + + return warn +} + +// syncDir scans a directory for items, and updates the Hub state accordingly. +func (h *Hub) syncDir(dir string) error { + // For each, scan PARSERS, POSTOVERFLOWS... and COLLECTIONS last + for _, scan := range ItemTypes { + // cpath: top-level item directory, either downloaded or installed items. + // i.e. /etc/crowdsec/parsers, /etc/crowdsec/hub/parsers, ... + cpath, err := filepath.Abs(fmt.Sprintf("%s/%s", dir, scan)) + if err != nil { + h.logger.Errorf("failed %s: %s", cpath, err) + continue + } + + // explicit check for non existing directory, avoid spamming log.Debug + if _, err = os.Stat(cpath); os.IsNotExist(err) { + h.logger.Tracef("directory %s doesn't exist, skipping", cpath) + continue + } + + if err = filepath.WalkDir(cpath, h.itemVisit); err != nil { + return err + } + } + + return nil +} + +// insert a string in a sorted slice, case insensitive, and return the new slice. +func insertInOrderNoCase(sl []string, value string) []string { + i := sort.Search(len(sl), func(i int) bool { + return strings.ToLower(sl[i]) >= strings.ToLower(value) + }) + + return append(sl[:i], append([]string{value}, sl[i:]...)...) +} + +func removeDuplicates(sl []string) []string { + seen := make(map[string]struct{}, len(sl)) + j := 0 + + for _, v := range sl { + if _, ok := seen[v]; ok { + continue + } + + seen[v] = struct{}{} + sl[j] = v + j++ + } + + return sl[:j] +} + +// localSync updates the hub state with downloaded, installed and local items. 
+func (h *Hub) localSync() error { + err := h.syncDir(h.local.InstallDir) + if err != nil { + return fmt.Errorf("failed to scan %s: %w", h.local.InstallDir, err) + } + + if err = h.syncDir(h.local.HubDir); err != nil { + return fmt.Errorf("failed to scan %s: %w", h.local.HubDir, err) + } + + warnings := make([]string, 0) + + for _, item := range h.GetItemMap(COLLECTIONS) { + // check for cyclic dependencies + subs, err := item.descendants() + if err != nil { + return err + } + + // populate the sub- and sub-sub-items with the collections they belong to + for _, sub := range subs { + sub.State.BelongsToCollections = insertInOrderNoCase(sub.State.BelongsToCollections, item.Name) + } + + if !item.State.Installed { + continue + } + + vs := item.versionStatus() + switch vs { + case versionUpToDate: // latest + if w := item.checkSubItemVersions(); len(w) > 0 { + warnings = append(warnings, w...) + } + case versionUpdateAvailable: // not up-to-date + warnings = append(warnings, fmt.Sprintf("update for collection %s available (currently:%s, latest:%s)", item.Name, item.State.LocalVersion, item.Version)) + case versionFuture: + warnings = append(warnings, fmt.Sprintf("collection %s is in the future (currently:%s, latest:%s)", item.Name, item.State.LocalVersion, item.Version)) + case versionUnknown: + if !item.State.IsLocal() { + warnings = append(warnings, fmt.Sprintf("collection %s is tainted by local changes (latest:%s)", item.Name, item.Version)) + } + } + + h.logger.Debugf("installed (%s) - status: %d | installed: %s | latest: %s | full: %+v", item.Name, vs, item.State.LocalVersion, item.Version, item.Versions) + } + + h.Warnings = removeDuplicates(warnings) + + return nil +} + +func (i *Item) setVersionState(path string, inhub bool) error { + var err error + + i.State.LocalHash, err = downloader.SHA256(path) + if err != nil { + return fmt.Errorf("failed to get sha256 of %s: %w", path, err) + } + + // let's reverse sort the versions to deal with hash collisions (#154) + versions := make([]string, 0, len(i.Versions)) + for k := range i.Versions { + versions = append(versions, k) + } + + versions, err = sortedVersions(versions) + if err != nil { + return fmt.Errorf("while syncing %s %s: %w", i.Type, i.FileName, err) + } + + i.State.LocalVersion = "?" + + for _, version := range versions { + if i.Versions[version].Digest == i.State.LocalHash { + i.State.LocalVersion = version + break + } + } + + if i.State.LocalVersion == "?" 
{
+		i.hub.logger.Tracef("got tainted match for %s: %s", i.Name, path)
+
+		if !inhub {
+			i.State.LocalPath = path
+			i.State.Installed = true
+		}
+
+		i.State.UpToDate = false
+		i.addTaint(i)
+
+		return nil
+	}
+
+	// we got an exact match, update struct
+
+	i.State.Downloaded = true
+
+	if !inhub {
+		i.hub.logger.Tracef("found exact match for %s, version is %s, latest is %s", i.Name, i.State.LocalVersion, i.Version)
+		i.State.LocalPath = path
+		i.State.Tainted = false
+		// if we're walking the hub, a present file doesn't mean an installed file
+		i.State.Installed = true
+	}
+
+	if i.State.LocalVersion == i.Version {
+		i.hub.logger.Tracef("%s is up-to-date", i.Name)
+		i.State.UpToDate = true
+	}
+
+	return nil
+}
diff --git a/pkg/cwhub/tests/collection_v1.yaml b/pkg/cwhub/testdata/collection_v1.yaml
similarity index 100%
rename from pkg/cwhub/tests/collection_v1.yaml
rename to pkg/cwhub/testdata/collection_v1.yaml
diff --git a/pkg/cwhub/tests/collection_v2.yaml b/pkg/cwhub/testdata/collection_v2.yaml
similarity index 100%
rename from pkg/cwhub/tests/collection_v2.yaml
rename to pkg/cwhub/testdata/collection_v2.yaml
diff --git a/pkg/cwhub/tests/foobar_parser.yaml b/pkg/cwhub/testdata/foobar_parser.yaml
similarity index 100%
rename from pkg/cwhub/tests/foobar_parser.yaml
rename to pkg/cwhub/testdata/foobar_parser.yaml
diff --git a/pkg/cwhub/tests/index1.json b/pkg/cwhub/testdata/index1.json
similarity index 93%
rename from pkg/cwhub/tests/index1.json
rename to pkg/cwhub/testdata/index1.json
index a7e6ef6153b..59548bda379 100644
--- a/pkg/cwhub/tests/index1.json
+++ b/pkg/cwhub/testdata/index1.json
@@ -10,7 +10,6 @@
       }
     },
     "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=",
-    "content": "bG9uZyBkZXNjcmlwdGlvbgo=",
     "description": "foobar collection : foobar",
     "author": "crowdsecurity",
     "labels": null,
@@ -34,7 +33,6 @@
       }
     },
"long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "foobar collection : foobar", "author": "crowdsecurity", "labels": null, @@ -78,7 +75,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "A foobar parser", "author": "crowdsecurity", "labels": null @@ -94,7 +90,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "A foobar parser", "author": "crowdsecurity", "labels": null @@ -112,7 +107,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "a foobar scenario", "author": "crowdsecurity", "labels": { @@ -132,7 +126,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "a foobar scenario", "author": "crowdsecurity", "labels": { @@ -143,4 +136,4 @@ } } } -} \ No newline at end of file +} diff --git a/pkg/cwversion/component/component.go b/pkg/cwversion/component/component.go new file mode 100644 index 00000000000..4036b63cf00 --- /dev/null +++ b/pkg/cwversion/component/component.go @@ -0,0 +1,34 @@ +package component + +// Package component provides functionality for managing the registration of +// optional, compile-time components in the system. This is meant as a space +// saving measure, separate from feature flags (package pkg/fflag) which are +// only enabled/disabled at runtime. + +// Built is a map of all the known components, and whether they are built-in or not. +// This is populated as soon as possible by the respective init() functions +var Built = map[string]bool { + "datasource_appsec": false, + "datasource_cloudwatch": false, + "datasource_docker": false, + "datasource_file": false, + "datasource_journalctl": false, + "datasource_k8s-audit": false, + "datasource_kafka": false, + "datasource_kinesis": false, + "datasource_loki": false, + "datasource_s3": false, + "datasource_syslog": false, + "datasource_wineventlog":false, + "cscli_setup": false, +} + +func Register(name string) { + if _, ok := Built[name]; !ok { + // having a list of the disabled components is essential + // to debug users' issues + panic("cannot register unknown compile-time component: " + name) + } + + Built[name] = true +} diff --git a/pkg/cwversion/constraint/constraint.go b/pkg/cwversion/constraint/constraint.go new file mode 100644 index 00000000000..67593f9ebbc --- /dev/null +++ b/pkg/cwversion/constraint/constraint.go @@ -0,0 +1,32 @@ +package constraint + +import ( + "fmt" + + goversion "github.com/hashicorp/go-version" +) + +const ( + Parser = ">= 1.0, <= 3.0" + Scenario = ">= 1.0, <= 3.0" + API = "v1" + Acquis = ">= 1.0, < 2.0" +) + +func Satisfies(strvers string, constraint string) (bool, error) { + vers, err := goversion.NewVersion(strvers) + if err != nil { + return false, fmt.Errorf("failed to parse '%s': %w", strvers, err) + } + + constraints, err := goversion.NewConstraint(constraint) + if err != nil { + return false, fmt.Errorf("failed to parse constraint '%s'", constraint) + } + + if !constraints.Check(vers) { + return false, nil + } + + return true, nil +} diff --git a/pkg/cwversion/version.go b/pkg/cwversion/version.go index aeac6f2f22c..2cb7de13e18 100644 --- a/pkg/cwversion/version.go +++ b/pkg/cwversion/version.go @@ -1,93 +1,66 @@ package cwversion import ( - "encoding/json" "fmt" - "log" - "net/http" - "runtime" "strings" - goversion "github.com/hashicorp/go-version" - - 
"github.com/crowdsecurity/go-cs-lib/pkg/version" + "github.com/crowdsecurity/go-cs-lib/maptools" + "github.com/crowdsecurity/go-cs-lib/version" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/component" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/constraint" ) var ( - Codename string // = "SoumSoum" - System = runtime.GOOS // = "linux" - Constraint_parser = ">= 1.0, <= 2.0" - Constraint_scenario = ">= 1.0, < 3.0" - Constraint_api = "v1" - Constraint_acquis = ">= 1.0, < 2.0" - Libre2 = "WebAssembly" + Codename string // = "SoumSoum" + Libre2 = "WebAssembly" ) -func ShowStr() string { - ret := "" - ret += fmt.Sprintf("version: %s-%s\n", version.Version, version.Tag) +func FullString() string { + dsBuilt := map[string]struct{}{} + dsExcluded := map[string]struct{}{} + + for ds, built := range component.Built { + if built { + dsBuilt[ds] = struct{}{} + continue + } + + dsExcluded[ds] = struct{}{} + } + + ret := fmt.Sprintf("version: %s\n", version.String()) ret += fmt.Sprintf("Codename: %s\n", Codename) ret += fmt.Sprintf("BuildDate: %s\n", version.BuildDate) ret += fmt.Sprintf("GoVersion: %s\n", version.GoVersion) - ret += fmt.Sprintf("Platform: %s\n", System) - return ret -} - -func Show() { - log.Printf("version: %s-%s", version.Version, version.Tag) - log.Printf("Codename: %s", Codename) - log.Printf("BuildDate: %s", version.BuildDate) - log.Printf("GoVersion: %s", version.GoVersion) - log.Printf("Platform: %s\n", System) - log.Printf("libre2: %s\n", Libre2) - log.Printf("Constraint_parser: %s", Constraint_parser) - log.Printf("Constraint_scenario: %s", Constraint_scenario) - log.Printf("Constraint_api: %s", Constraint_api) - log.Printf("Constraint_acquis: %s", Constraint_acquis) -} - -func VersionStr() string { - return fmt.Sprintf("%s-%s-%s", version.Version, System, version.Tag) -} + ret += fmt.Sprintf("Platform: %s\n", version.System) + ret += fmt.Sprintf("libre2: %s\n", Libre2) + ret += fmt.Sprintf("User-Agent: %s\n", useragent.Default()) + ret += fmt.Sprintf("Constraint_parser: %s\n", constraint.Parser) + ret += fmt.Sprintf("Constraint_scenario: %s\n", constraint.Scenario) + ret += fmt.Sprintf("Constraint_api: %s\n", constraint.API) + ret += fmt.Sprintf("Constraint_acquis: %s\n", constraint.Acquis) -func VersionStrip() string { - version := strings.Split(version.Version, "~") - version = strings.Split(version[0], "-") - return version[0] -} + built := "(none)" -func Satisfies(strvers string, constraint string) (bool, error) { - vers, err := goversion.NewVersion(strvers) - if err != nil { - return false, fmt.Errorf("failed to parse '%s' : %v", strvers, err) - } - constraints, err := goversion.NewConstraint(constraint) - if err != nil { - return false, fmt.Errorf("failed to parse constraint '%s'", constraint) + if len(dsBuilt) > 0 { + built = strings.Join(maptools.SortedKeys(dsBuilt), ", ") } - if !constraints.Check(vers) { - return false, nil - } - return true, nil -} -// Latest return latest crowdsec version based on github -func Latest() (string, error) { - latest := make(map[string]interface{}) + ret += fmt.Sprintf("Built-in optional components: %s\n", built) - resp, err := http.Get("https://version.crowdsec.net/latest") - if err != nil { - return "", err + if len(dsExcluded) > 0 { + ret += fmt.Sprintf("Excluded components: %s\n", strings.Join(maptools.SortedKeys(dsExcluded), ", ")) } - defer resp.Body.Close() - err = json.NewDecoder(resp.Body).Decode(&latest) - if err != nil { - return "", err - } - if _, ok 
:= latest["name"]; !ok {
-		return "", fmt.Errorf("unable to find latest release name from github api: %+v", latest)
-	}
+	return ret
+}
+
+// VersionStrip removes the tag from the version string; it is used to match a hub branch
+func VersionStrip() string {
+	ret := strings.Split(version.Version, "~")
+	ret = strings.Split(ret[0], "-")
 
-	return latest["name"].(string), nil
+	return ret[0]
 }
diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go
index afdf51688d3..ede9c89fe9a 100644
--- a/pkg/database/alerts.go
+++ b/pkg/database/alerts.go
@@ -9,17 +9,16 @@ import (
 	"strings"
 	"time"
 
-	"github.com/davecgh/go-spew/spew"
+	"github.com/mattn/go-sqlite3"
 	"github.com/pkg/errors"
 	log "github.com/sirupsen/logrus"
 
-	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
+	"github.com/crowdsecurity/go-cs-lib/slicetools"
+
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
-	"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
-	"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/meta"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
@@ -27,144 +26,62 @@ import (
 )
 
 const (
-	paginationSize   = 100 // used to queryAlert to avoid 'too many SQL variable'
-	defaultLimit     = 100 // default limit of element to returns when query alerts
-	bulkSize         = 50  // bulk size when create alerts
-	decisionBulkSize = 50
+	paginationSize      = 100 // used to queryAlert to avoid 'too many SQL variables'
+	defaultLimit        = 100 // default number of elements to return when querying alerts
+	alertCreateBulkSize = 50  // bulk size when creating alerts
+	maxLockRetries      = 10  // how many times to retry a bulk operation when sqlite3.ErrBusy is encountered
 )
 
-func formatAlertCN(source models.Source) string {
-	cn := source.Cn
-
-	if source.AsNumber != "" {
-		cn += "/" + source.AsNumber
-	}
-
-	return cn
-}
-
-func formatAlertSource(alert *models.Alert) string {
-	if alert.Source == nil {
-		return "empty source"
-	}
-
-	if *alert.Source.Scope == types.Ip {
-		ret := "ip " + *alert.Source.Value
-		cn := formatAlertCN(*alert.Source)
-		if cn != "" {
-			ret += " (" + cn + ")"
-		}
-		return ret
-	}
-
-	if *alert.Source.Scope == types.Range {
-		ret := "range " + *alert.Source.Value
-		cn := formatAlertCN(*alert.Source)
-		if cn != "" {
-			ret += " (" + cn + ")"
-		}
-		return ret
-	}
-
-	return *alert.Source.Scope + " " + *alert.Source.Value
-}
-
-func formatAlertAsString(machineId string, alert *models.Alert) []string {
-	src := formatAlertSource(alert)
-
-	/**/
-	msg := ""
-	if alert.Scenario != nil && *alert.Scenario != "" {
-		msg = *alert.Scenario
-	} else if alert.Message != nil && *alert.Message != "" {
-		msg = *alert.Message
-	} else {
-		msg = "empty scenario"
-	}
-
-	reason := fmt.Sprintf("%s by %s", msg, src)
-
-	if len(alert.Decisions) == 0 {
-		return []string{fmt.Sprintf("(%s) alert : %s", machineId, reason)}
-	}
-
-	var retStr []string
-
-	for i, decisionItem := range alert.Decisions {
-		decision := ""
-		if alert.Simulated != nil && *alert.Simulated {
-			decision = "(simulated alert)"
-		} else if decisionItem.Simulated != nil && *decisionItem.Simulated {
-			decision = "(simulated decision)"
-		}
-		if log.GetLevel() >= log.DebugLevel {
-			/*spew is expensive*/
-			log.Debugf("%s", spew.Sdump(decisionItem))
-		}
-		if len(alert.Decisions) > 1 {
-			reason =
fmt.Sprintf("%s for %d/%d decisions", msg, i+1, len(alert.Decisions)) - } - machineIdOrigin := "" - if machineId == "" { - machineIdOrigin = *decisionItem.Origin - } else { - machineIdOrigin = fmt.Sprintf("%s/%s", machineId, *decisionItem.Origin) - } - - decision += fmt.Sprintf("%s %s on %s %s", *decisionItem.Duration, - *decisionItem.Type, *decisionItem.Scope, *decisionItem.Value) - retStr = append(retStr, - fmt.Sprintf("(%s) %s : %s", machineIdOrigin, reason, decision)) - } - return retStr -} - // CreateOrUpdateAlert is specific to PAPI : It checks if alert already exists, otherwise inserts it // if alert already exists, it checks it associated decisions already exists // if some associated decisions are missing (ie. previous insert ended up in error) it inserts them -func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) (string, error) { - +func (c *Client) CreateOrUpdateAlert(ctx context.Context, machineID string, alertItem *models.Alert) (string, error) { if alertItem.UUID == "" { - return "", fmt.Errorf("alert UUID is empty") + return "", errors.New("alert UUID is empty") } - alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(c.CTX) + alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(ctx) if err != nil && !ent.IsNotFound(err) { return "", fmt.Errorf("unable to query alerts for uuid %s: %w", alertItem.UUID, err) } - //alert wasn't found, insert it (expected hotpath) + // alert wasn't found, insert it (expected hotpath) if ent.IsNotFound(err) || len(alerts) == 0 { - ret, err := c.CreateAlert(machineID, []*models.Alert{alertItem}) + alertIDs, err := c.CreateAlert(ctx, machineID, []*models.Alert{alertItem}) if err != nil { return "", fmt.Errorf("unable to create alert: %w", err) } - return ret[0], nil + + return alertIDs[0], nil } - //this should never happen + // this should never happen if len(alerts) > 1 { return "", fmt.Errorf("multiple alerts found for uuid %s", alertItem.UUID) } log.Infof("Alert %s already exists, checking associated decisions", alertItem.UUID) - //alert is found, check for any missing decisions - missingUuids := []string{} - newUuids := []string{} - for _, decItem := range alertItem.Decisions { - newUuids = append(newUuids, decItem.UUID) + + // alert is found, check for any missing decisions + + newUuids := make([]string, len(alertItem.Decisions)) + for i, decItem := range alertItem.Decisions { + newUuids[i] = decItem.UUID } foundAlert := alerts[0] - foundUuids := []string{} - for _, decItem := range foundAlert.Edges.Decisions { - foundUuids = append(foundUuids, decItem.UUID) + foundUuids := make([]string, len(foundAlert.Edges.Decisions)) + + for i, decItem := range foundAlert.Edges.Decisions { + foundUuids[i] = decItem.UUID } sort.Strings(foundUuids) sort.Strings(newUuids) + missingUuids := []string{} + for idx, uuid := range newUuids { if len(foundUuids) < idx+1 || uuid != foundUuids[idx] { log.Warningf("Decision with uuid %s not found in alert %s", uuid, foundAlert.UUID) @@ -177,9 +94,10 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) return "", nil } - //add any and all missing decisions based on their uuids - //prepare missing decisions + // add any and all missing decisions based on their uuids + // prepare missing decisions missingDecisions := []*models.Decision{} + for _, uuid := range missingUuids { for _, newDecision := range alertItem.Decisions { if newDecision.UUID == uuid { @@ -188,37 +106,43 @@ func (c *Client) 
CreateOrUpdateAlert(machineID string, alertItem *models.Alert) } } - //add missing decisions + // add missing decisions log.Debugf("Adding %d missing decisions to alert %s", len(missingDecisions), foundAlert.UUID) - decisions := make([]*ent.Decision, 0) - decisionBulk := make([]*ent.DecisionCreate, 0, decisionBulkSize) + decisionBuilders := []*ent.DecisionCreate{} - for i, decisionItem := range missingDecisions { - var start_ip, start_sfx, end_ip, end_sfx int64 - var sz int + for _, decisionItem := range missingDecisions { + var ( + start_ip, start_sfx, end_ip, end_sfx int64 + sz int + ) /*if the scope is IP or Range, convert the value to integers */ if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" { sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value) if err != nil { - return "", errors.Wrapf(InvalidIPOrRange, "invalid addr/range %s : %s", *decisionItem.Value, err) + log.Errorf("invalid addr/range '%s': %s", *decisionItem.Value, err) + continue } } + decisionDuration, err := time.ParseDuration(*decisionItem.Duration) if err != nil { log.Warningf("invalid duration %s for decision %s", *decisionItem.Duration, decisionItem.UUID) continue } - //use the created_at from the alert instead + + // use the created_at from the alert instead alertTime, err := time.Parse(time.RFC3339, alertItem.CreatedAt) if err != nil { log.Errorf("unable to parse alert time %s : %s", alertItem.CreatedAt, err) + alertTime = time.Now() } + decisionUntil := alertTime.UTC().Add(decisionDuration) - decisionCreate := c.Ent.Decision.Create(). + decisionBuilder := c.Ent.Decision.Create(). SetUntil(decisionUntil). SetScenario(*decisionItem.Scenario). SetType(*decisionItem.Type). @@ -233,80 +157,58 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) SetSimulated(*alertItem.Simulated). SetUUID(decisionItem.UUID) - decisionBulk = append(decisionBulk, decisionCreate) - if len(decisionBulk) == decisionBulkSize { - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX) - if err != nil { - return "", errors.Wrapf(BulkError, "creating alert decisions: %s", err) + decisionBuilders = append(decisionBuilders, decisionBuilder) + } - } - decisions = append(decisions, decisionsCreateRet...) - if len(missingDecisions)-i <= decisionBulkSize { - decisionBulk = make([]*ent.DecisionCreate, 0, (len(missingDecisions) - i)) - } else { - decisionBulk = make([]*ent.DecisionCreate, 0, decisionBulkSize) - } + decisions := []*ent.Decision{} + + builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) + + for _, builderChunk := range builderChunks { + decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(ctx) + if err != nil { + return "", fmt.Errorf("creating alert decisions: %w", err) } - } - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX) - if err != nil { - return "", errors.Wrapf(BulkError, "creating alert decisions: %s", err) - } - decisions = append(decisions, decisionsCreateRet...) - //now that we bulk created missing decisions, let's update the alert - err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisions...).Exec(c.CTX) - if err != nil { - return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err) + + decisions = append(decisions, decisionsCreateRet...) 
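// each builder chunk is saved independently, so a failure mid-loop can leave earlier chunks committed; that partial-insert case is exactly what the UUID reconciliation above repairs on the next PAPI replay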
} - return "", nil + // now that we bulk created missing decisions, let's update the alert -} + decisionChunks := slicetools.Chunks(decisions, c.decisionBulkSize) -func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]string, error) { - pageStart := 0 - pageEnd := bulkSize - ret := []string{} - for { - if pageEnd >= len(alertList) { - results, err := c.CreateAlertBulk(machineID, alertList[pageStart:]) - if err != nil { - return []string{}, fmt.Errorf("unable to create alerts: %s", err) - } - ret = append(ret, results...) - break - } - results, err := c.CreateAlertBulk(machineID, alertList[pageStart:pageEnd]) + for _, decisionChunk := range decisionChunks { + err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(ctx) if err != nil { - return []string{}, fmt.Errorf("unable to create alerts: %s", err) + return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err) } - ret = append(ret, results...) - pageStart += bulkSize - pageEnd += bulkSize } - return ret, nil + + return "", nil } // UpdateCommunityBlocklist is called to update either the community blocklist (or other lists the user subscribed to) // it takes care of creating the new alert with the associated decisions, and it will as well deleted the "older" overlapping decisions: // 1st pull, you get decisions [1,2,3]. it inserts [1,2,3] // 2nd pull, you get decisions [1,2,3,4]. it inserts [1,2,3,4] and will try to delete [1,2,3,4] with a different alert ID and same origin -func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, int, error) { - var err error - +func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models.Alert) (int, int, int, error) { if alertItem == nil { - return 0, 0, 0, fmt.Errorf("nil alert") + return 0, 0, 0, errors.New("nil alert") } + if alertItem.StartAt == nil { - return 0, 0, 0, fmt.Errorf("nil start_at") + return 0, 0, 0, errors.New("nil start_at") } + startAtTime, err := time.Parse(time.RFC3339, *alertItem.StartAt) if err != nil { return 0, 0, 0, errors.Wrapf(ParseTimeFail, "start_at field time '%s': %s", *alertItem.StartAt, err) } + if alertItem.StopAt == nil { - return 0, 0, 0, fmt.Errorf("nil stop_at") + return 0, 0, 0, errors.New("nil stop_at") } + stopAtTime, err := time.Parse(time.RFC3339, *alertItem.StopAt) if err != nil { return 0, 0, 0, errors.Wrapf(ParseTimeFail, "stop_at field time '%s': %s", *alertItem.StopAt, err) @@ -315,6 +217,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in ts, err := time.Parse(time.RFC3339, *alertItem.StopAt) if err != nil { c.Log.Errorf("While parsing StartAt of item %s : %s", *alertItem.StopAt, err) + ts = time.Now().UTC() } @@ -338,9 +241,10 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in SetLeakSpeed(*alertItem.Leakspeed). SetSimulated(*alertItem.Simulated). SetScenarioVersion(*alertItem.ScenarioVersion). - SetScenarioHash(*alertItem.ScenarioHash) + SetScenarioHash(*alertItem.ScenarioHash). 
+ SetRemediation(true) // it's from CAPI, we always have decisions - alertRef, err := alertB.Save(c.CTX) + alertRef, err := alertB.Save(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating alert : %s", err) } @@ -349,13 +253,13 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in return alertRef.ID, 0, 0, nil } - txClient, err := c.Ent.Tx(c.CTX) + txClient, err := c.Ent.Tx(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err) } - decisionBulk := make([]*ent.DecisionCreate, 0, decisionBulkSize) - valueList := make([]string, 0, decisionBulkSize) + DecOrigin := CapiMachineID + if *alertItem.Decisions[0].Origin == CapiMachineID || *alertItem.Decisions[0].Origin == CapiListsMachineID { DecOrigin = *alertItem.Decisions[0].Origin } else { @@ -365,25 +269,35 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in deleted := 0 inserted := 0 - for i, decisionItem := range alertItem.Decisions { - var start_ip, start_sfx, end_ip, end_sfx int64 - var sz int + decisionBuilders := make([]*ent.DecisionCreate, 0, len(alertItem.Decisions)) + valueList := make([]string, 0, len(alertItem.Decisions)) + + for _, decisionItem := range alertItem.Decisions { + var ( + start_ip, start_sfx, end_ip, end_sfx int64 + sz int + ) + if decisionItem.Duration == nil { log.Warning("nil duration in community decision") continue } + duration, err := time.ParseDuration(*decisionItem.Duration) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { log.Errorf("rollback error: %s", rollbackErr) } + return 0, 0, 0, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err) } + if decisionItem.Scope == nil { log.Warning("nil scope in community decision") continue } + /*if the scope is IP or Range, convert the value to integers */ if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" { sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value) @@ -392,11 +306,13 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in if rollbackErr != nil { log.Errorf("rollback error: %s", rollbackErr) } + return 0, 0, 0, errors.Wrapf(InvalidIPOrRange, "invalid addr/range %s : %s", *decisionItem.Value, err) } } + /*bulk insert some new decisions*/ - decisionBulk = append(decisionBulk, c.Ent.Decision.Create(). + decisionBuilder := c.Ent.Decision.Create(). SetUntil(ts.Add(duration)). SetScenario(*decisionItem.Scenario). SetType(*decisionItem.Type). @@ -409,171 +325,181 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in SetScope(*decisionItem.Scope). SetOrigin(*decisionItem.Origin). SetSimulated(*alertItem.Simulated). 
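// the owner edge ties each new decision to alertRef; the delete below removes decisions that share a value but have a different owner, i.e. leftovers from earlier pulls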
- SetOwner(alertRef)) + SetOwner(alertRef) + + decisionBuilders = append(decisionBuilders, decisionBuilder) /*for bulk delete of duplicate decisions*/ if decisionItem.Value == nil { log.Warning("nil value in community decision") continue } - valueList = append(valueList, *decisionItem.Value) - if len(decisionBulk) == decisionBulkSize { - - insertedDecisions, err := txClient.Decision.CreateBulk(decisionBulk...).Save(c.CTX) - if err != nil { - rollbackErr := txClient.Rollback() - if rollbackErr != nil { - log.Errorf("rollback error: %s", rollbackErr) - } - return 0, 0, 0, errors.Wrapf(BulkError, "bulk creating decisions : %s", err) - } - inserted += len(insertedDecisions) - - /*Deleting older decisions from capi*/ - deletedDecisions, err := txClient.Decision.Delete(). - Where(decision.And( - decision.OriginEQ(DecOrigin), - decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))), - decision.ValueIn(valueList...), - )).Exec(c.CTX) - if err != nil { - rollbackErr := txClient.Rollback() - if rollbackErr != nil { - log.Errorf("rollback error: %s", rollbackErr) - } - return 0, 0, 0, fmt.Errorf("while deleting older community blocklist decisions: %w", err) - } - deleted += deletedDecisions + valueList = append(valueList, *decisionItem.Value) + } - if len(alertItem.Decisions)-i <= decisionBulkSize { - decisionBulk = make([]*ent.DecisionCreate, 0, (len(alertItem.Decisions) - i)) - valueList = make([]string, 0, (len(alertItem.Decisions) - i)) - } else { - decisionBulk = make([]*ent.DecisionCreate, 0, decisionBulkSize) - valueList = make([]string, 0, decisionBulkSize) - } - } + deleteChunks := slicetools.Chunks(valueList, c.decisionBulkSize) - } - log.Debugf("deleted %d decisions for %s vs %s", deleted, DecOrigin, *alertItem.Decisions[0].Origin) - insertedDecisions, err := txClient.Decision.CreateBulk(decisionBulk...).Save(c.CTX) - if err != nil { - return 0, 0, 0, errors.Wrapf(BulkError, "creating alert decisions: %s", err) - } - inserted += len(insertedDecisions) - /*Deleting older decisions from capi*/ - if len(valueList) > 0 { + for _, deleteChunk := range deleteChunks { + // Deleting older decisions from capi deletedDecisions, err := txClient.Decision.Delete(). 
Where(decision.And( decision.OriginEQ(DecOrigin), decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))), - decision.ValueIn(valueList...), - )).Exec(c.CTX) + decision.ValueIn(deleteChunk...), + )).Exec(ctx) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { log.Errorf("rollback error: %s", rollbackErr) } + return 0, 0, 0, fmt.Errorf("while deleting older community blocklist decisions: %w", err) } + deleted += deletedDecisions } + + builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) + + for _, builderChunk := range builderChunks { + insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(ctx) + if err != nil { + rollbackErr := txClient.Rollback() + if rollbackErr != nil { + log.Errorf("rollback error: %s", rollbackErr) + } + + return 0, 0, 0, fmt.Errorf("while bulk creating decisions: %w", err) + } + + inserted += len(insertedDecisions) + } + + log.Debugf("deleted %d decisions for %s vs %s", deleted, DecOrigin, *alertItem.Decisions[0].Origin) + err = txClient.Commit() if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { log.Errorf("rollback error: %s", rollbackErr) } - return 0, 0, 0, errors.Wrapf(BulkError, "error committing transaction : %s", err) + + return 0, 0, 0, fmt.Errorf("error committing transaction: %w", err) } return alertRef.ID, inserted, deleted, nil } -func chunkDecisions(decisions []*ent.Decision, chunkSize int) [][]*ent.Decision { - var ret [][]*ent.Decision - var chunk []*ent.Decision - - for _, d := range decisions { - chunk = append(chunk, d) - if len(chunk) == chunkSize { - ret = append(ret, chunk) - chunk = nil - } - } - if len(chunk) > 0 { - ret = append(ret, chunk) - } - return ret -} +func (c *Client) createDecisionChunk(ctx context.Context, simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { + decisionCreate := []*ent.DecisionCreate{} -func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([]string, error) { - ret := []string{} - bulkSize := 20 - var owner *ent.Machine - var err error + for _, decisionItem := range decisions { + var ( + start_ip, start_sfx, end_ip, end_sfx int64 + sz int + ) - if machineId != "" { - owner, err = c.QueryMachineByID(machineId) + duration, err := time.ParseDuration(*decisionItem.Duration) if err != nil { - if errors.Cause(err) != UserNotExists { - return []string{}, errors.Wrapf(QueryFail, "machine '%s': %s", machineId, err) + return nil, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err) + } + + /*if the scope is IP or Range, convert the value to integers */ + if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" { + sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value) + if err != nil { + log.Errorf("invalid addr/range '%s': %s", *decisionItem.Value, err) + continue } - c.Log.Debugf("CreateAlertBulk: Machine Id %s doesn't exist", machineId) - owner = nil } - } else { - owner = nil + + newDecision := c.Ent.Decision.Create(). + SetUntil(stopAtTime.Add(duration)). + SetScenario(*decisionItem.Scenario). + SetType(*decisionItem.Type). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(sz)). + SetValue(*decisionItem.Value). + SetScope(*decisionItem.Scope). + SetOrigin(*decisionItem.Origin). + SetSimulated(simulated). 
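// the upstream decision UUID is preserved: CreateOrUpdateAlert matches on it to detect decisions lost to a previous partial insert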
+ SetUUID(decisionItem.UUID) + + decisionCreate = append(decisionCreate, newDecision) } - c.Log.Debugf("writing %d items", len(alertList)) - bulk := make([]*ent.AlertCreate, 0, bulkSize) - alertDecisions := make([][]*ent.Decision, 0, bulkSize) - for i, alertItem := range alertList { - var decisions []*ent.Decision - var metas []*ent.Meta - var events []*ent.Event + if len(decisionCreate) == 0 { + return nil, nil + } + + ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(ctx) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (c *Client) createAlertChunk(ctx context.Context, machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { + alertBuilders := []*ent.AlertCreate{} + alertDecisions := [][]*ent.Decision{} + + for _, alertItem := range alerts { + var ( + metas []*ent.Meta + events []*ent.Event + ) startAtTime, err := time.Parse(time.RFC3339, *alertItem.StartAt) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse startAtTime '%s', defaulting to now: %s", *alertItem.StartAt, err) + c.Log.Errorf("creating alert: Failed to parse startAtTime '%s', defaulting to now: %s", *alertItem.StartAt, err) + startAtTime = time.Now().UTC() } stopAtTime, err := time.Parse(time.RFC3339, *alertItem.StopAt) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse stopAtTime '%s', defaulting to now: %s", *alertItem.StopAt, err) + c.Log.Errorf("creating alert: Failed to parse stopAtTime '%s', defaulting to now: %s", *alertItem.StopAt, err) + stopAtTime = time.Now().UTC() } + /*display proper alert in logs*/ - for _, disp := range formatAlertAsString(machineId, alertItem) { + for _, disp := range alertItem.FormatAsStrings(machineID, log.StandardLogger()) { c.Log.Info(disp) } - //let's track when we strip or drop data, notify outside of loop to avoid spam + // let's track when we strip or drop data, notify outside of loop to avoid spam stripped := false dropped := false if len(alertItem.Events) > 0 { eventBulk := make([]*ent.EventCreate, len(alertItem.Events)) + for i, eventItem := range alertItem.Events { ts, err := time.Parse(time.RFC3339, *eventItem.Timestamp) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse event timestamp '%s', defaulting to now: %s", *eventItem.Timestamp, err) + c.Log.Errorf("creating alert: Failed to parse event timestamp '%s', defaulting to now: %s", *eventItem.Timestamp, err) + ts = time.Now().UTC() } + marshallMetas, err := json.Marshal(eventItem.Meta) if err != nil { return nil, errors.Wrapf(MarshalFail, "event meta '%v' : %s", eventItem.Meta, err) } - //the serialized field is too big, let's try to progressively strip it + // the serialized field is too big, let's try to progressively strip it if event.SerializedValidator(string(marshallMetas)) != nil { stripped = true valid := false stripSize := 2048 + for !valid && stripSize > 0 { for _, serializedItem := range eventItem.Meta { if len(serializedItem.Value) > stripSize*2 { @@ -585,32 +511,36 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ if err != nil { return nil, errors.Wrapf(MarshalFail, "event meta '%v' : %s", eventItem.Meta, err) } + if event.SerializedValidator(string(marshallMetas)) == nil { valid = true } + stripSize /= 2 } - //nothing worked, drop it + // nothing worked, drop it if !valid { dropped = true stripped = false marshallMetas = []byte("") } - } eventBulk[i] = c.Ent.Event.Create(). SetTime(ts). 
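// marshallMetas may have been progressively stripped, or emptied entirely, by the validator loop above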
SetSerialized(string(marshallMetas)) } + if stripped { - c.Log.Warningf("stripped 'serialized' field (machine %s / scenario %s)", machineId, *alertItem.Scenario) + c.Log.Warningf("stripped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario) } + if dropped { - c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineId, *alertItem.Scenario) + c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario) } - events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(c.CTX) + + events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(ctx) if err != nil { return nil, errors.Wrapf(BulkError, "creating alert events: %s", err) } @@ -618,75 +548,58 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ if len(alertItem.Meta) > 0 { metaBulk := make([]*ent.MetaCreate, len(alertItem.Meta)) + for i, metaItem := range alertItem.Meta { - metaBulk[i] = c.Ent.Meta.Create(). - SetKey(metaItem.Key). - SetValue(metaItem.Value) - } - metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(c.CTX) - if err != nil { - return nil, errors.Wrapf(BulkError, "creating alert meta: %s", err) - } - } + key := metaItem.Key + value := metaItem.Value - decisions = make([]*ent.Decision, 0) - if len(alertItem.Decisions) > 0 { - decisionBulk := make([]*ent.DecisionCreate, 0, decisionBulkSize) - for i, decisionItem := range alertItem.Decisions { - var start_ip, start_sfx, end_ip, end_sfx int64 - var sz int + if len(metaItem.Value) > 4095 { + c.Log.Warningf("truncated meta %s: value too long", metaItem.Key) - duration, err := time.ParseDuration(*decisionItem.Duration) - if err != nil { - return nil, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err) + value = value[:4095] } - /*if the scope is IP or Range, convert the value to integers */ - if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" { - sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value) - if err != nil { - return nil, fmt.Errorf("%s: %w", *decisionItem.Value, InvalidIPOrRange) - } + if len(metaItem.Key) > 255 { + c.Log.Warningf("truncated meta %s: key too long", metaItem.Key) + + key = key[:255] } - decisionCreate := c.Ent.Decision.Create(). - SetUntil(stopAtTime.Add(duration)). - SetScenario(*decisionItem.Scenario). - SetType(*decisionItem.Type). - SetStartIP(start_ip). - SetStartSuffix(start_sfx). - SetEndIP(end_ip). - SetEndSuffix(end_sfx). - SetIPSize(int64(sz)). - SetValue(*decisionItem.Value). - SetScope(*decisionItem.Scope). - SetOrigin(*decisionItem.Origin). - SetSimulated(*alertItem.Simulated). - SetUUID(decisionItem.UUID) - - decisionBulk = append(decisionBulk, decisionCreate) - if len(decisionBulk) == decisionBulkSize { - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX) - if err != nil { - return nil, errors.Wrapf(BulkError, "creating alert decisions: %s", err) + metaBulk[i] = c.Ent.Meta.Create(). + SetKey(key). + SetValue(value) + } - } - decisions = append(decisions, decisionsCreateRet...) 
- if len(alertItem.Decisions)-i <= decisionBulkSize { - decisionBulk = make([]*ent.DecisionCreate, 0, (len(alertItem.Decisions) - i)) - } else { - decisionBulk = make([]*ent.DecisionCreate, 0, decisionBulkSize) - } - } + metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(ctx) + if err != nil { + c.Log.Warningf("error creating alert meta: %s", err) } - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX) + } + + decisions := []*ent.Decision{} + + decisionChunks := slicetools.Chunks(alertItem.Decisions, c.decisionBulkSize) + for _, decisionChunk := range decisionChunks { + decisionRet, err := c.createDecisionChunk(ctx, *alertItem.Simulated, stopAtTime, decisionChunk) if err != nil { - return nil, errors.Wrapf(BulkError, "creating alert decisions: %s", err) + return nil, fmt.Errorf("creating alert decisions: %w", err) } - decisions = append(decisions, decisionsCreateRet...) + + decisions = append(decisions, decisionRet...) + } + + discarded := len(alertItem.Decisions) - len(decisions) + if discarded > 0 { + c.Log.Warningf("discarded %d decisions for %s", discarded, alertItem.UUID) } - alertB := c.Ent.Alert. + // if all decisions were discarded, discard the alert too + if discarded > 0 && len(decisions) == 0 { + c.Log.Warningf("dropping alert %s with invalid decisions", alertItem.UUID) + continue + } + + alertBuilder := c.Ent.Alert. Create(). SetScenario(*alertItem.Scenario). SetMessage(*alertItem.Message). @@ -707,55 +620,63 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ SetSimulated(*alertItem.Simulated). SetScenarioVersion(*alertItem.ScenarioVersion). SetScenarioHash(*alertItem.ScenarioHash). + SetRemediation(alertItem.Remediation). SetUUID(alertItem.UUID). AddEvents(events...). AddMetas(metas...) 
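// decisions were created above in createDecisionChunk; they are linked to the alert only after the bulk alert insert, chunk by chunk, keeping each UPDATE under the SQL variable limit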
if owner != nil { - alertB.SetOwner(owner) + alertBuilder.SetOwner(owner) } - bulk = append(bulk, alertB) + + alertBuilders = append(alertBuilders, alertBuilder) alertDecisions = append(alertDecisions, decisions) + } - if len(bulk) == bulkSize { - alerts, err := c.Ent.Alert.CreateBulk(bulk...).Save(c.CTX) - if err != nil { - return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err) - } - for alertIndex, a := range alerts { - ret = append(ret, strconv.Itoa(a.ID)) - d := alertDecisions[alertIndex] - decisionsChunk := chunkDecisions(d, bulkSize) - for _, d2 := range decisionsChunk { - _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX) - if err != nil { - return nil, fmt.Errorf("error while updating decisions: %s", err) - } - } - } - if len(alertList)-i <= bulkSize { - bulk = make([]*ent.AlertCreate, 0, (len(alertList) - i)) - alertDecisions = make([][]*ent.Decision, 0, (len(alertList) - i)) - } else { - bulk = make([]*ent.AlertCreate, 0, bulkSize) - alertDecisions = make([][]*ent.Decision, 0, bulkSize) - } - } + if len(alertBuilders) == 0 { + log.Warningf("no alerts to create, discarded?") + return nil, nil } - alerts, err := c.Ent.Alert.CreateBulk(bulk...).Save(c.CTX) + alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(ctx) if err != nil { - return nil, errors.Wrapf(BulkError, "leftovers creating alert : %s", err) + return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err) } - for alertIndex, a := range alerts { - ret = append(ret, strconv.Itoa(a.ID)) - d := alertDecisions[alertIndex] - decisionsChunk := chunkDecisions(d, bulkSize) + ret := make([]string, len(alertsCreateBulk)) + for i, a := range alertsCreateBulk { + ret[i] = strconv.Itoa(a.ID) + + d := alertDecisions[i] + decisionsChunk := slicetools.Chunks(d, c.decisionBulkSize) + for _, d2 := range decisionsChunk { - _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX) - if err != nil { - return nil, fmt.Errorf("error while updating decisions: %s", err) + retry := 0 + + for retry < maxLockRetries { + // so much for the happy path... 
but sqlite3 errors work differently + _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(ctx) + if err == nil { + break + } + + if sqliteErr, ok := err.(sqlite3.Error); ok { + if sqliteErr.Code == sqlite3.ErrBusy { + // sqlite3.Error{ + // Code: 5, + // ExtendedCode: 5, + // SystemErrno: 0, + // err: "database is locked", + // } + retry++ + log.Warningf("while updating decisions, sqlite3.ErrBusy: %s, retry %d of %d", err, retry, maxLockRetries) + time.Sleep(1 * time.Second) + + continue + } + } + + return nil, fmt.Errorf("error while updating decisions: %w", err) } } } @@ -763,26 +684,220 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ return ret, nil } -func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, error) { - predicates := make([]predicate.Alert, 0) - var err error - var start_ip, start_sfx, end_ip, end_sfx int64 - var hasActiveDecision bool - var ip_sz int - var contains bool = true - /*if contains is true, return bans that *contains* the given value (value is the inner) - else, return bans that are *contained* by the given value (value is the outer)*/ +func (c *Client) CreateAlert(ctx context.Context, machineID string, alertList []*models.Alert) ([]string, error) { + var ( + owner *ent.Machine + err error + ) + + if machineID != "" { + owner, err = c.QueryMachineByID(ctx, machineID) + if err != nil { + if !errors.Is(err, UserNotExists) { + return nil, fmt.Errorf("machine '%s': %w", machineID, err) + } - /*the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */ - if v, ok := filter["simulated"]; ok { - if v[0] == "false" { - predicates = append(predicates, alert.SimulatedEQ(false)) + c.Log.Debugf("creating alert: machine %s doesn't exist", machineID) + + owner = nil } } + c.Log.Debugf("writing %d items", len(alertList)) + + alertChunks := slicetools.Chunks(alertList, alertCreateBulkSize) + alertIDs := []string{} + + for _, alertChunk := range alertChunks { + ids, err := c.createAlertChunk(ctx, machineID, owner, alertChunk) + if err != nil { + return nil, fmt.Errorf("machine '%s': %w", machineID, err) + } + + alertIDs = append(alertIDs, ids...) 
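// alertList is written in chunks of alertCreateBulkSize; IDs accumulate across chunks in insertion order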
+ } + + if owner != nil { + err = owner.Update().SetLastPush(time.Now().UTC()).Exec(ctx) + if err != nil { + return nil, fmt.Errorf("machine '%s': %w", machineID, err) + } + } + + return alertIDs, nil +} + +func handleSimulatedFilter(filter map[string][]string, predicates *[]predicate.Alert) { + /* the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */ + if v, ok := filter["simulated"]; ok && v[0] == "false" { + *predicates = append(*predicates, alert.SimulatedEQ(false)) + } +} + +func handleOriginFilter(filter map[string][]string, predicates *[]predicate.Alert) { if _, ok := filter["origin"]; ok { filter["include_capi"] = []string{"true"} } +} + +func handleScopeFilter(scope string, predicates *[]predicate.Alert) { + if strings.ToLower(scope) == "ip" { + scope = types.Ip + } else if strings.ToLower(scope) == "range" { + scope = types.Range + } + + *predicates = append(*predicates, alert.SourceScopeEQ(scope)) +} + +func handleTimeFilters(param, value string, predicates *[]predicate.Alert) error { + duration, err := ParseDuration(value) + if err != nil { + return fmt.Errorf("while parsing duration: %w", err) + } + + timePoint := time.Now().UTC().Add(-duration) + if timePoint.IsZero() { + return fmt.Errorf("empty time now() - %s", timePoint.String()) + } + + switch param { + case "since": + *predicates = append(*predicates, alert.StartedAtGTE(timePoint)) + case "created_before": + *predicates = append(*predicates, alert.CreatedAtLTE(timePoint)) + case "until": + *predicates = append(*predicates, alert.StartedAtLTE(timePoint)) + } + + return nil +} + +func handleIPv4Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) { + if contains { // decision contains {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + alert.HasDecisionsWith(decision.StartIPLTE(start_ip)), + alert.HasDecisionsWith(decision.EndIPGTE(end_ip)), + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + )) + } else { // decision is contained within {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + alert.HasDecisionsWith(decision.StartIPGTE(start_ip)), + alert.HasDecisionsWith(decision.EndIPLTE(end_ip)), + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + )) + } +} + +func handleIPv6Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) { + if contains { // decision contains {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + // matching addr size + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + alert.Or( + // decision.start_ip < query.start_ip + alert.HasDecisionsWith(decision.StartIPLT(start_ip)), + alert.And( + // decision.start_ip == query.start_ip + alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), + // decision.start_suffix <= query.start_suffix + alert.HasDecisionsWith(decision.StartSuffixLTE(start_sfx)), + ), + ), + alert.Or( + // decision.end_ip > query.end_ip + alert.HasDecisionsWith(decision.EndIPGT(end_ip)), + alert.And( + // decision.end_ip == query.end_ip + alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), + // decision.end_suffix >= query.end_suffix + alert.HasDecisionsWith(decision.EndSuffixGTE(end_sfx)), + ), + ), + )) + } else { // decision is contained within {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + // matching addr size + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + alert.Or( + // decision.start_ip > 
query.start_ip + alert.HasDecisionsWith(decision.StartIPGT(start_ip)), + alert.And( + // decision.start_ip == query.start_ip + alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), + // decision.start_suffix >= query.start_suffix + alert.HasDecisionsWith(decision.StartSuffixGTE(start_sfx)), + ), + ), + alert.Or( + // decision.end_ip < query.end_ip + alert.HasDecisionsWith(decision.EndIPLT(end_ip)), + alert.And( + // decision.end_ip == query.end_ip + alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), + // decision.end_suffix <= query.end_suffix + alert.HasDecisionsWith(decision.EndSuffixLTE(end_sfx)), + ), + ), + )) + } +} + +func handleIPPredicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) error { + if ip_sz == 4 { + handleIPv4Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates) + } else if ip_sz == 16 { + handleIPv6Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates) + } else if ip_sz != 0 { + return errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) + } + + return nil +} + +func handleIncludeCapiFilter(value string, predicates *[]predicate.Alert) error { + if value == "false" { + *predicates = append(*predicates, alert.And( + // do not show alerts with active decisions having origin CAPI or lists + alert.And( + alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.CAPIOrigin))), + alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.ListOrigin))), + ), + alert.Not( + alert.And( + // do not show neither alerts with no decisions if the Source Scope is lists: or CAPI + alert.Not(alert.HasDecisions()), + alert.Or( + alert.SourceScopeHasPrefix(types.ListOrigin+":"), + alert.SourceScopeEQ(types.CommunityBlocklistPullSourceScope), + ), + ), + ), + )) + } else if value != "true" { + log.Errorf("invalid bool '%s' for include_capi", value) + } + + return nil +} + +func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, error) { + predicates := make([]predicate.Alert, 0) + + var ( + err error + start_ip, start_sfx, end_ip, end_sfx int64 + hasActiveDecision bool + ip_sz int + ) + + contains := true + + /*if contains is true, return bans that *contains* the given value (value is the inner) + else, return bans that are *contained* by the given value (value is the outer)*/ + + handleSimulatedFilter(filter, &predicates) + handleOriginFilter(filter, &predicates) for param, value := range filter { switch param { @@ -792,13 +907,7 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err) } case "scope": - var scope string = value[0] - if strings.ToLower(scope) == "ip" { - scope = types.Ip - } else if strings.ToLower(scope) == "range" { - scope = types.Range - } - predicates = append(predicates, alert.SourceScopeEQ(scope)) + handleScopeFilter(value[0], &predicates) case "value": predicates = append(predicates, alert.SourceValueEQ(value[0])) case "scenario": @@ -808,54 +917,23 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e if err != nil { return nil, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err) } - case "since": - duration, err := types.ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) + case "since", "created_before", "until": + if err := handleTimeFilters(param, value[0], &predicates); err != nil { + return nil, err } - since := 
time.Now().UTC().Add(-duration) - if since.IsZero() { - return nil, fmt.Errorf("Empty time now() - %s", since.String()) - } - predicates = append(predicates, alert.StartedAtGTE(since)) - case "created_before": - duration, err := types.ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) - } - since := time.Now().UTC().Add(-duration) - if since.IsZero() { - return nil, fmt.Errorf("empty time now() - %s", since.String()) - } - predicates = append(predicates, alert.CreatedAtLTE(since)) - case "until": - duration, err := types.ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) - } - until := time.Now().UTC().Add(-duration) - if until.IsZero() { - return nil, fmt.Errorf("empty time now() - %s", until.String()) - } - predicates = append(predicates, alert.StartedAtLTE(until)) case "decision_type": predicates = append(predicates, alert.HasDecisionsWith(decision.TypeEQ(value[0]))) case "origin": predicates = append(predicates, alert.HasDecisionsWith(decision.OriginEQ(value[0]))) - case "include_capi": //allows to exclude one or more specific origins - if value[0] == "false" { - predicates = append(predicates, alert.HasDecisionsWith( - decision.Or(decision.OriginEQ(types.CrowdSecOrigin), - decision.OriginEQ(types.CscliOrigin), - decision.OriginEQ(types.ConsoleOrigin), - decision.OriginEQ(types.CscliImportOrigin)))) - } else if value[0] != "true" { - log.Errorf("Invalid bool '%s' for include_capi", value[0]) + case "include_capi": // allows to exclude one or more specific origins + if err = handleIncludeCapiFilter(value[0], &predicates); err != nil { + return nil, err } case "has_active_decision": if hasActiveDecision, err = strconv.ParseBool(value[0]); err != nil { return nil, errors.Wrapf(ParseType, "'%s' is not a boolean: %s", value[0], err) } + if hasActiveDecision { predicates = append(predicates, alert.HasDecisionsWith(decision.UntilGTE(time.Now().UTC()))) } else { @@ -874,103 +952,36 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e } } - if ip_sz == 4 { - if contains { /*decision contains {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - alert.HasDecisionsWith(decision.StartIPLTE(start_ip)), - alert.HasDecisionsWith(decision.EndIPGTE(end_ip)), - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - )) - } else { /*decision is contained within {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - alert.HasDecisionsWith(decision.StartIPGTE(start_ip)), - alert.HasDecisionsWith(decision.EndIPLTE(end_ip)), - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - )) - } - } else if ip_sz == 16 { - - if contains { /*decision contains {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - //matching addr size - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - alert.Or( - //decision.start_ip < query.start_ip - alert.HasDecisionsWith(decision.StartIPLT(start_ip)), - alert.And( - //decision.start_ip == query.start_ip - alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), - //decision.start_suffix <= query.start_suffix - alert.HasDecisionsWith(decision.StartSuffixLTE(start_sfx)), - )), - alert.Or( - //decision.end_ip > query.end_ip - alert.HasDecisionsWith(decision.EndIPGT(end_ip)), - alert.And( - //decision.end_ip == query.end_ip - alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), - //decision.end_suffix >= query.end_suffix - alert.HasDecisionsWith(decision.EndSuffixGTE(end_sfx)), - ), - ), - )) - } else { 
/*decision is contained within {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - //matching addr size - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - alert.Or( - //decision.start_ip > query.start_ip - alert.HasDecisionsWith(decision.StartIPGT(start_ip)), - alert.And( - //decision.start_ip == query.start_ip - alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), - //decision.start_suffix >= query.start_suffix - alert.HasDecisionsWith(decision.StartSuffixGTE(start_sfx)), - )), - alert.Or( - //decision.end_ip < query.end_ip - alert.HasDecisionsWith(decision.EndIPLT(end_ip)), - alert.And( - //decision.end_ip == query.end_ip - alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), - //decision.end_suffix <= query.end_suffix - alert.HasDecisionsWith(decision.EndSuffixLTE(end_sfx)), - ), - ), - )) - } - } else if ip_sz != 0 { - return nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) + if err := handleIPPredicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, &predicates); err != nil { + return nil, err } + return predicates, nil } + func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]string) (*ent.AlertQuery, error) { preds, err := AlertPredicatesFromFilter(filter) if err != nil { return nil, err } + return alerts.Where(preds...), nil } -func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string]int, error) { - +func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string][]string) (map[string]int, error) { var res []struct { Scenario string Count int } - ctx := context.Background() - query := c.Ent.Alert.Query() query, err := BuildAlertRequestFromFilter(query, filters) - if err != nil { return nil, fmt.Errorf("failed to build alert request: %w", err) } err = query.GroupBy(alert.FieldScenario).Aggregate(ent.Count()).Scan(ctx, &res) - if err != nil { return nil, fmt.Errorf("failed to count alerts per scenario: %w", err) } @@ -984,13 +995,13 @@ func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string return counts, nil } -func (c *Client) TotalAlerts() (int, error) { - return c.Ent.Alert.Query().Count(c.CTX) +func (c *Client) TotalAlerts(ctx context.Context) (int, error) { + return c.Ent.Alert.Query().Count(ctx) } -func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, error) { - +func (c *Client) QueryAlertWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Alert, error) { sort := "DESC" // we sort by desc by default + if val, ok := filter["sort"]; ok { if val[0] != "ASC" && val[0] != "DESC" { c.Log.Errorf("invalid 'sort' parameter: %s", val) @@ -998,40 +1009,46 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, sort = val[0] } } + limit := defaultLimit + if val, ok := filter["limit"]; ok { limitConv, err := strconv.Atoi(val[0]) if err != nil { - return []*ent.Alert{}, errors.Wrapf(QueryFail, "bad limit in parameters: %s", val) + return nil, errors.Wrapf(QueryFail, "bad limit in parameters: %s", val) } - limit = limitConv + limit = limitConv } + offset := 0 ret := make([]*ent.Alert, 0) + for { alerts := c.Ent.Alert.Query() + alerts, err := BuildAlertRequestFromFilter(alerts, filter) if err != nil { - return []*ent.Alert{}, err + return nil, err } - //only if with_decisions is present and set to false, we exclude this + // only if with_decisions is present and set to false, we exclude this if val, ok := filter["with_decisions"]; ok && val[0] == "false" { c.Log.Debugf("skipping decisions") } 
else { alerts = alerts. WithDecisions() } + alerts = alerts. WithEvents(). WithMetas(). WithOwner() if limit == 0 { - limit, err = alerts.Count(c.CTX) + limit, err = alerts.Count(ctx) if err != nil { - return []*ent.Alert{}, fmt.Errorf("unable to count nb alerts: %s", err) + return nil, fmt.Errorf("unable to count nb alerts: %w", err) } } @@ -1041,60 +1058,64 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, alerts = alerts.Order(ent.Desc(alert.FieldCreatedAt), ent.Desc(alert.FieldID)) } - result, err := alerts.Limit(paginationSize).Offset(offset).All(c.CTX) + result, err := alerts.Limit(paginationSize).Offset(offset).All(ctx) if err != nil { - return []*ent.Alert{}, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err) + return nil, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err) } + if diff := limit - len(ret); diff < paginationSize { if len(result) < diff { ret = append(ret, result...) c.Log.Debugf("Pagination done, %d < %d", len(result), diff) + break } - ret = append(ret, result[0:diff]...) + ret = append(ret, result[0:diff]...) } else { ret = append(ret, result...) } + if len(ret) == limit || len(ret) == 0 || len(ret) < paginationSize { c.Log.Debugf("Pagination done len(ret) = %d", len(ret)) break } + offset += paginationSize } return ret, nil } -func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { +func (c *Client) DeleteAlertGraphBatch(ctx context.Context, alertItems []*ent.Alert) (int, error) { idList := make([]int, 0) for _, alert := range alertItems { idList = append(idList, alert.ID) } _, err := c.Ent.Event.Delete(). - Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch events") } _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch meta") } _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch decisions") } deleted, err := c.Ent.Alert.Delete(). - Where(alert.IDIn(idList...)).Exec(c.CTX) + Where(alert.IDIn(idList...)).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return deleted, errors.Wrapf(DeleteFail, "alert graph delete batch") @@ -1105,10 +1126,10 @@ func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { return deleted, nil } -func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { +func (c *Client) DeleteAlertGraph(ctx context.Context, alertItem *ent.Alert) error { // delete the associated events _, err := c.Ent.Event.Delete(). 
- Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "event with alert ID '%d'", alertItem.ID) @@ -1116,7 +1137,7 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated meta _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "meta with alert ID '%d'", alertItem.ID) @@ -1124,14 +1145,14 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated decisions _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "decision with alert ID '%d'", alertItem.ID) } // delete the alert - err = c.Ent.Alert.DeleteOne(alertItem).Exec(c.CTX) + err = c.Ent.Alert.DeleteOne(alertItem).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "alert with ID '%d'", alertItem.ID) @@ -1140,204 +1161,37 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { return nil } -func (c *Client) DeleteAlertByID(id int) error { - alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(c.CTX) +func (c *Client) DeleteAlertByID(ctx context.Context, id int) error { + alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(ctx) if err != nil { return err } - return c.DeleteAlertGraph(alertItem) + return c.DeleteAlertGraph(ctx, alertItem) } -func (c *Client) DeleteAlertWithFilter(filter map[string][]string) (int, error) { +func (c *Client) DeleteAlertWithFilter(ctx context.Context, filter map[string][]string) (int, error) { preds, err := AlertPredicatesFromFilter(filter) if err != nil { return 0, err } - return c.Ent.Alert.Delete().Where(preds...).Exec(c.CTX) -} - -func (c *Client) FlushOrphans() { - /* While it has only been linked to some very corner-case bug : https://github.com/crowdsecurity/crowdsec/issues/778 */ - /* We want to take care of orphaned events for which the parent alert/decision has been deleted */ - - events_count, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(c.CTX) - if err != nil { - c.Log.Warningf("error while deleting orphan events : %s", err) - return - } - if events_count > 0 { - c.Log.Infof("%d deleted orphan events", events_count) - } - - events_count, err = c.Ent.Decision.Delete().Where( - decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(c.CTX) - - if err != nil { - c.Log.Warningf("error while deleting orphan decisions : %s", err) - return - } - if events_count > 0 { - c.Log.Infof("%d deleted orphan decisions", events_count) - } -} - -func (c *Client) FlushAgentsAndBouncers(agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { - log.Debug("starting FlushAgentsAndBouncers") - if bouncersCfg != nil { - if bouncersCfg.ApiDuration != nil { - log.Debug("trying to delete old bouncers from api") - deletionCount, err := c.Ent.Bouncer.Delete().Where( - bouncer.LastPullLTE(time.Now().UTC().Add(-*bouncersCfg.ApiDuration)), - ).Where( - bouncer.AuthTypeEQ(types.ApiKeyAuthType), - ).Exec(c.CTX) - if err != nil { - 
c.Log.Errorf("while auto-deleting expired bouncers (api key) : %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount) - } - } - if bouncersCfg.CertDuration != nil { - log.Debug("trying to delete old bouncers from cert") - - deletionCount, err := c.Ent.Bouncer.Delete().Where( - bouncer.LastPullLTE(time.Now().UTC().Add(-*bouncersCfg.CertDuration)), - ).Where( - bouncer.AuthTypeEQ(types.TlsAuthType), - ).Exec(c.CTX) - if err != nil { - c.Log.Errorf("while auto-deleting expired bouncers (api key) : %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount) - } - } - } - if agentsCfg != nil { - if agentsCfg.CertDuration != nil { - log.Debug("trying to delete old agents from cert") - - deletionCount, err := c.Ent.Machine.Delete().Where( - machine.LastHeartbeatLTE(time.Now().UTC().Add(-*agentsCfg.CertDuration)), - ).Where( - machine.Not(machine.HasAlerts()), - ).Where( - machine.AuthTypeEQ(types.TlsAuthType), - ).Exec(c.CTX) - log.Debugf("deleted %d entries", deletionCount) - if err != nil { - c.Log.Errorf("while auto-deleting expired machine (cert) : %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired machine (cert auth)", deletionCount) - } - } - if agentsCfg.LoginPasswordDuration != nil { - log.Debug("trying to delete old agents from password") - - deletionCount, err := c.Ent.Machine.Delete().Where( - machine.LastHeartbeatLTE(time.Now().UTC().Add(-*agentsCfg.LoginPasswordDuration)), - ).Where( - machine.Not(machine.HasAlerts()), - ).Where( - machine.AuthTypeEQ(types.PasswordAuthType), - ).Exec(c.CTX) - log.Debugf("deleted %d entries", deletionCount) - if err != nil { - c.Log.Errorf("while auto-deleting expired machine (password) : %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired machine (password auth)", deletionCount) - } - } - } - return nil + return c.Ent.Alert.Delete().Where(preds...).Exec(ctx) } -func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { - var deletedByAge int - var deletedByNbItem int - var totalAlerts int - var err error - - if !c.CanFlush { - c.Log.Debug("a list is being imported, flushing later") - return nil - } - - c.Log.Debug("Flushing orphan alerts") - c.FlushOrphans() - c.Log.Debug("Done flushing orphan alerts") - totalAlerts, err = c.TotalAlerts() - if err != nil { - c.Log.Warningf("FlushAlerts (max items count) : %s", err) - return fmt.Errorf("unable to get alerts count: %w", err) - } - c.Log.Debugf("FlushAlerts (Total alerts): %d", totalAlerts) - if MaxAge != "" { - filter := map[string][]string{ - "created_before": {MaxAge}, - } - nbDeleted, err := c.DeleteAlertWithFilter(filter) - if err != nil { - c.Log.Warningf("FlushAlerts (max age) : %s", err) - return fmt.Errorf("unable to flush alerts with filter until=%s: %w", MaxAge, err) - } - c.Log.Debugf("FlushAlerts (deleted max age alerts): %d", nbDeleted) - deletedByAge = nbDeleted - } - if MaxItems > 0 { - //We get the highest id for the alerts - //We subtract MaxItems to avoid deleting alerts that are not old enough - //This gives us the oldest alert that we want to keep - //We then delete all the alerts with an id lower than this one - //We can do this because the id is auto-increment, and the database won't reuse the same id twice - lastAlert, err := c.QueryAlertWithFilter(map[string][]string{ - "sort": {"DESC"}, - "limit": {"1"}, - //we do not care about fetching the edges, we just want the id - "with_decisions": {"false"}, - }) - 
c.Log.Debugf("FlushAlerts (last alert): %+v", lastAlert) - if err != nil { - c.Log.Errorf("FlushAlerts: could not get last alert: %s", err) - return fmt.Errorf("could not get last alert: %w", err) - } - - if len(lastAlert) != 0 { - maxid := lastAlert[0].ID - MaxItems - - c.Log.Debugf("FlushAlerts (max id): %d", maxid) - - if maxid > 0 { - //This may lead to orphan alerts (at least on MySQL), but the next time the flush job will run, they will be deleted - deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(c.CTX) - - if err != nil { - c.Log.Errorf("FlushAlerts: Could not delete alerts : %s", err) - return fmt.Errorf("could not delete alerts: %w", err) - } - } - } - } - if deletedByNbItem > 0 { - c.Log.Infof("flushed %d/%d alerts because max number of alerts has been reached (%d max)", deletedByNbItem, totalAlerts, MaxItems) - } - if deletedByAge > 0 { - c.Log.Infof("flushed %d/%d alerts because they were created %s ago or more", deletedByAge, totalAlerts, MaxAge) - } - return nil -} - -func (c *Client) GetAlertByID(alertID int) (*ent.Alert, error) { - alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(c.CTX) +func (c *Client) GetAlertByID(ctx context.Context, alertID int) (*ent.Alert, error) { + alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(ctx) if err != nil { /*record not found, 404*/ if ent.IsNotFound(err) { log.Warningf("GetAlertByID (not found): %s", err) return &ent.Alert{}, ItemNotFound } + c.Log.Warningf("GetAlertByID : %s", err) + return &ent.Alert{}, QueryFail } + return alert, nil } diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index 98bfd45873b..04ef830ae72 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -1,96 +1,160 @@ package database import ( + "context" "fmt" + "strings" "time" "github.com/pkg/errors" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" + "github.com/crowdsecurity/crowdsec/pkg/models" ) -func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(c.CTX) +type BouncerNotFoundError struct { + BouncerName string +} + +func (e *BouncerNotFoundError) Error() string { + return fmt.Sprintf("'%s' does not exist", e.BouncerName) +} + +func (c *Client) BouncerUpdateBaseMetrics(ctx context.Context, bouncerName string, bouncerType string, baseMetrics models.BaseMetrics) error { + os := baseMetrics.Os + features := strings.Join(baseMetrics.FeatureFlags, ",") + + _, err := c.Ent.Bouncer. + Update(). + Where(bouncer.NameEQ(bouncerName)). + SetNillableVersion(baseMetrics.Version). + SetOsname(*os.Name). + SetOsversion(*os.Version). + SetFeatureflags(features). + SetType(bouncerType). 
+ Save(ctx) if err != nil { - return &ent.Bouncer{}, errors.Wrapf(QueryFail, "select bouncer: %s", err) + return fmt.Errorf("unable to update base bouncer metrics in database: %w", err) + } + + return nil +} + +func (c *Client) SelectBouncer(ctx context.Context, apiKeyHash string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(ctx) + if err != nil { + return nil, err } return result, nil } -func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX) +func (c *Client) SelectBouncerByName(ctx context.Context, bouncerName string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(ctx) if err != nil { - return &ent.Bouncer{}, errors.Wrapf(QueryFail, "select bouncer: %s", err) + return nil, err } return result, nil } -func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().All(c.CTX) +func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().All(ctx) if err != nil { - return []*ent.Bouncer{}, errors.Wrapf(QueryFail, "listing bouncer: %s", err) + return nil, errors.Wrapf(QueryFail, "listing bouncers: %s", err) } + return result, nil } -func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { +func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { bouncer, err := c.Ent.Bouncer. Create(). SetName(name). SetAPIKey(apiKey). SetRevoked(false). SetAuthType(authType). - Save(c.CTX) + Save(ctx) if err != nil { if ent.IsConstraintError(err) { return nil, fmt.Errorf("bouncer %s already exists", name) } - return nil, fmt.Errorf("unable to create bouncer: %s", err) + + return nil, fmt.Errorf("unable to create bouncer: %w", err) } + return bouncer, nil } -func (c *Client) DeleteBouncer(name string) error { +func (c *Client) DeleteBouncer(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Bouncer. Delete(). Where(bouncer.NameEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } if nbDeleted == 0 { - return fmt.Errorf("bouncer doesn't exist") + return &BouncerNotFoundError{BouncerName: name} } return nil } -func (c *Client) UpdateBouncerLastPull(lastPull time.Time, ID int) error { - _, err := c.Ent.Bouncer.UpdateOneID(ID). +func (c *Client) BulkDeleteBouncers(ctx context.Context, bouncers []*ent.Bouncer) (int, error) { + ids := make([]int, len(bouncers)) + for i, b := range bouncers { + ids[i] = b.ID + } + + nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(ctx) + if err != nil { + return nbDeleted, fmt.Errorf("unable to delete bouncers: %w", err) + } + + return nbDeleted, nil +} + +func (c *Client) UpdateBouncerLastPull(ctx context.Context, lastPull time.Time, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id). SetLastPull(lastPull). 
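// last_pull is nullable: QueryBouncersInactiveSince below falls back to created_at for bouncers that have never pulled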
- Save(c.CTX) + Save(ctx) if err != nil { - return fmt.Errorf("unable to update machine last pull in database: %s", err) + return fmt.Errorf("unable to update machine last pull in database: %w", err) } + return nil } -func (c *Client) UpdateBouncerIP(ipAddr string, ID int) error { - _, err := c.Ent.Bouncer.UpdateOneID(ID).SetIPAddress(ipAddr).Save(c.CTX) +func (c *Client) UpdateBouncerIP(ctx context.Context, ipAddr string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(ctx) if err != nil { - return fmt.Errorf("unable to update bouncer ip address in database: %s", err) + return fmt.Errorf("unable to update bouncer ip address in database: %w", err) } + return nil } -func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, ID int) error { - _, err := c.Ent.Bouncer.UpdateOneID(ID).SetVersion(version).SetType(bType).Save(c.CTX) +func (c *Client) UpdateBouncerTypeAndVersion(ctx context.Context, bType string, version string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(ctx) if err != nil { - return fmt.Errorf("unable to update bouncer type and version in database: %s", err) + return fmt.Errorf("unable to update bouncer type and version in database: %w", err) } + return nil } + +func (c *Client) QueryBouncersInactiveSince(ctx context.Context, t time.Time) ([]*ent.Bouncer, error) { + return c.Ent.Bouncer.Query().Where( + // poor man's coalesce + bouncer.Or( + bouncer.LastPullLT(t), + bouncer.And( + bouncer.LastPullIsNil(), + bouncer.CreatedAtLT(t), + ), + ), + ).All(ctx) +} diff --git a/pkg/database/config.go b/pkg/database/config.go index 8c3578ad596..89ccb1e1b28 100644 --- a/pkg/database/config.go +++ b/pkg/database/config.go @@ -1,17 +1,20 @@ package database import ( + "context" + "github.com/pkg/errors" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" ) -func (c *Client) GetConfigItem(key string) (*string, error) { - result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(c.CTX) +func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) { + result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx) if err != nil && ent.IsNotFound(err) { return nil, nil } + if err != nil { return nil, errors.Wrapf(QueryFail, "select config item: %s", err) } @@ -19,16 +22,16 @@ func (c *Client) GetConfigItem(key string) (*string, error) { return &result.Value, nil } -func (c *Client) SetConfigItem(key string, value string) error { - - nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(c.CTX) - if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { //not found, create - err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(c.CTX) +func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error { + nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx) + if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { // not found, create + err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx) if err != nil { return errors.Wrapf(QueryFail, "insert config item: %s", err) } } else if err != nil { return errors.Wrapf(QueryFail, "update config item: %s", err) } + return nil } diff --git a/pkg/database/database.go b/pkg/database/database.go index 46de4e73a73..bb41dd3b645 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -3,31 
+3,29 @@ package database import ( "context" "database/sql" + "errors" "fmt" "os" - "time" entsql "entgo.io/ent/dialect/sql" - "github.com/go-co-op/gocron" + // load database backends _ "github.com/go-sql-driver/mysql" _ "github.com/jackc/pgx/v4/stdlib" _ "github.com/mattn/go-sqlite3" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/types" ) type Client struct { - Ent *ent.Client - CTX context.Context - Log *log.Logger - CanFlush bool - Type string - WalMode *bool + Ent *ent.Client + Log *log.Logger + CanFlush bool + Type string + WalMode *bool + decisionBulkSize int } func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig.DatabaseCfg) (*entsql.Driver, error) { @@ -35,132 +33,82 @@ func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig. if err != nil { return nil, err } - if config.MaxOpenConns == nil { - log.Warningf("MaxOpenConns is 0, defaulting to %d", csconfig.DEFAULT_MAX_OPEN_CONNS) - config.MaxOpenConns = ptr.Of(csconfig.DEFAULT_MAX_OPEN_CONNS) + + if config.MaxOpenConns == 0 { + config.MaxOpenConns = csconfig.DEFAULT_MAX_OPEN_CONNS } - db.SetMaxOpenConns(*config.MaxOpenConns) + + db.SetMaxOpenConns(config.MaxOpenConns) drv := entsql.OpenDB(dbdialect, db) + return drv, nil } -func NewClient(config *csconfig.DatabaseCfg) (*Client, error) { +func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, error) { var client *ent.Client - var err error + if config == nil { - return &Client{}, fmt.Errorf("DB config is empty") + return nil, errors.New("DB config is empty") } /*The logger that will be used by db operations*/ clog := log.New() if err := types.ConfigureLogger(clog); err != nil { return nil, fmt.Errorf("while configuring db logger: %w", err) } + if config.LogLevel != nil { clog.SetLevel(*config.LogLevel) } - entLogger := clog.WithField("context", "ent") + entLogger := clog.WithField("context", "ent") entOpt := ent.Log(entLogger.Debug) + typ, dia, err := config.ConnectionDialect() if err != nil { - return &Client{}, err //unsupported database caught here + return nil, err // unsupported database caught here } + if config.Type == "sqlite" { /*if it's the first startup, we want to touch and chmod file*/ - if _, err := os.Stat(config.DbPath); os.IsNotExist(err) { - f, err := os.OpenFile(config.DbPath, os.O_CREATE|os.O_RDWR, 0600) + if _, err = os.Stat(config.DbPath); os.IsNotExist(err) { + f, err := os.OpenFile(config.DbPath, os.O_CREATE|os.O_RDWR, 0o600) if err != nil { - return &Client{}, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) + return nil, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) } + if err := f.Close(); err != nil { - return &Client{}, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) + return nil, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) } } - //Always try to set permissions to simplify a bit the code for windows (as the permissions set by OpenFile will be garbage) - if err := setFilePerm(config.DbPath, 0640); err != nil { - return &Client{}, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err) + // Always try to set permissions to simplify a bit the code for windows (as the permissions set by OpenFile will be garbage) + if err = setFilePerm(config.DbPath, 0o640); err != nil { + return 
nil, fmt.Errorf("unable to set perms on %s: %w", config.DbPath, err) } } + drv, err := getEntDriver(typ, dia, config.ConnectionString(), config) if err != nil { - return &Client{}, fmt.Errorf("failed opening connection to %s: %v", config.Type, err) + return nil, fmt.Errorf("failed opening connection to %s: %w", config.Type, err) } + client = ent.NewClient(ent.Driver(drv), entOpt) + if config.LogLevel != nil && *config.LogLevel >= log.DebugLevel { clog.Debugf("Enabling request debug") - client = client.Debug() - } - if err = client.Schema.Create(context.Background()); err != nil { - return nil, fmt.Errorf("failed creating schema resources: %v", err) - } - return &Client{Ent: client, CTX: context.Background(), Log: clog, CanFlush: true, Type: config.Type, WalMode: config.UseWal}, nil -} -func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { - maxItems := 0 - maxAge := "" - if config.MaxItems != nil && *config.MaxItems <= 0 { - return nil, fmt.Errorf("max_items can't be zero or negative number") - } - if config.MaxItems != nil { - maxItems = *config.MaxItems - } - if config.MaxAge != nil && *config.MaxAge != "" { - maxAge = *config.MaxAge + client = client.Debug() } - // Init & Start cronjob every minute for alerts - scheduler := gocron.NewScheduler(time.UTC) - job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems) - if err != nil { - return nil, fmt.Errorf("while starting FlushAlerts scheduler: %w", err) - } - job.SingletonMode() - // Init & Start cronjob every hour for bouncers/agents - if config.AgentsGC != nil { - if config.AgentsGC.Cert != nil { - duration, err := types.ParseDuration(*config.AgentsGC.Cert) - if err != nil { - return nil, fmt.Errorf("while parsing agents cert auto-delete duration: %w", err) - } - config.AgentsGC.CertDuration = &duration - } - if config.AgentsGC.LoginPassword != nil { - duration, err := types.ParseDuration(*config.AgentsGC.LoginPassword) - if err != nil { - return nil, fmt.Errorf("while parsing agents login/password auto-delete duration: %w", err) - } - config.AgentsGC.LoginPasswordDuration = &duration - } - if config.AgentsGC.Api != nil { - log.Warning("agents auto-delete for API auth is not supported (use cert or login_password)") - } - } - if config.BouncersGC != nil { - if config.BouncersGC.Cert != nil { - duration, err := types.ParseDuration(*config.BouncersGC.Cert) - if err != nil { - return nil, fmt.Errorf("while parsing bouncers cert auto-delete duration: %w", err) - } - config.BouncersGC.CertDuration = &duration - } - if config.BouncersGC.Api != nil { - duration, err := types.ParseDuration(*config.BouncersGC.Api) - if err != nil { - return nil, fmt.Errorf("while parsing bouncers api auto-delete duration: %w", err) - } - config.BouncersGC.ApiDuration = &duration - } - if config.BouncersGC.LoginPassword != nil { - log.Warning("bouncers auto-delete for login/password auth is not supported (use cert or api)") - } - } - baJob, err := scheduler.Every(1).Minute().Do(c.FlushAgentsAndBouncers, config.AgentsGC, config.BouncersGC) - if err != nil { - return nil, fmt.Errorf("while starting FlushAgentsAndBouncers scheduler: %w", err) + if err = client.Schema.Create(ctx); err != nil { + return nil, fmt.Errorf("failed creating schema resources: %w", err) } - baJob.SingletonMode() - scheduler.StartAsync() - return scheduler, nil + return &Client{ + Ent: client, + Log: clog, + CanFlush: true, + Type: config.Type, + WalMode: config.UseWal, + decisionBulkSize: config.DecisionBulkSize, + }, nil } diff 
--git a/pkg/database/decisions.go b/pkg/database/decisions.go index cf4b9c966c1..7522a272799 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "strconv" "strings" @@ -9,12 +10,16 @@ import ( "entgo.io/ent/dialect/sql" "github.com/pkg/errors" + "github.com/crowdsecurity/go-cs-lib/slicetools" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" "github.com/crowdsecurity/crowdsec/pkg/types" ) +const decisionDeleteBulkSize = 256 // scientifically proven to be the best value for bulk delete + type DecisionsByScenario struct { Scenario string Count int @@ -23,11 +28,10 @@ type DecisionsByScenario struct { } func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string][]string) (*ent.DecisionQuery, error) { - var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ @@ -36,6 +40,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] if v[0] == "false" { query = query.Where(decision.SimulatedEQ(false)) } + delete(filter, "simulated") } else { query = query.Where(decision.SimulatedEQ(false)) @@ -48,7 +53,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] if err != nil { return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err) } - case "scopes": + case "scopes", "scope": // Swagger mentions both of them, let's just support both to make sure we don't break anything scopes := strings.Split(value[0], ",") for i, scope := range scopes { switch strings.ToLower(scope) { @@ -62,6 +67,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] scopes[i] = types.AS } } + query = query.Where(decision.ScopeIn(scopes...)) case "value": query = query.Where(decision.ValueEQ(value[0])) @@ -106,23 +112,25 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] query = query.Where(decision.IDGT(id)) } } + query, err = applyStartIpEndIpFilter(query, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) if err != nil { return nil, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err) } + return query, nil } -func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { + +func (c *Client) QueryAllDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) - //Allow a bouncer to ask for non-deduplicated results + // Allow a bouncer to ask for non-deduplicated results if v, ok := filters["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } query, err := BuildDecisionRequestWithFilter(query, filters) - if err != nil { c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters") @@ -130,19 +138,20 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err) return 
[]*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters") } + return data, nil } -func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), ) - //Allow a bouncer to ask for non-deduplicated results + // Allow a bouncer to ask for non-deduplicated results if v, ok := filters["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } @@ -155,20 +164,22 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ( c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters") } - data, err := query.All(c.CTX) + + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions") } + return data, nil } -func (c *Client) QueryDecisionCountByScenario(filters map[string][]string) ([]*DecisionsByScenario, error) { +func (c *Client) QueryDecisionCountByScenario(ctx context.Context) ([]*DecisionsByScenario, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) - query, err := BuildDecisionRequestWithFilter(query, filters) + query, err := BuildDecisionRequestWithFilter(query, make(map[string][]string)) if err != nil { c.Log.Warningf("QueryDecisionCountByScenario : %s", err) return nil, errors.Wrap(QueryFail, "count all decisions with filters") @@ -176,8 +187,7 @@ func (c *Client) QueryDecisionCountByScenario(filters map[string][]string) ([]*D var r []*DecisionsByScenario - err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(c.CTX, &r) - + err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(ctx, &r) if err != nil { c.Log.Warningf("QueryDecisionCountByScenario : %s", err) return nil, errors.Wrap(QueryFail, "count all decisions with filters") @@ -186,7 +196,7 @@ func (c *Client) QueryDecisionCountByScenario(filters map[string][]string) ([]*D return r, nil } -func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryDecisionWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Decision, error) { var data []*ent.Decision var err error @@ -208,7 +218,7 @@ func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Dec decision.FieldValue, decision.FieldScope, decision.FieldOrigin, - ).Scan(c.CTX, &data) + ).Scan(ctx, &data) if err != nil { c.Log.Warningf("QueryDecisionWithFilter : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "query decision failed") @@ -245,15 +255,20 @@ func longestDecisionForScopeTypeValue(s *sql.Selector) { ) } -func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), - decision.UntilGT(since), ) - //Allow a bouncer to ask for non-deduplicated results + + if since != nil { + query = query.Where(decision.UntilGT(*since)) + } + + // 
Allow a bouncer to ask for non-deduplicated results if v, ok := filters["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } + query, err := BuildDecisionRequestWithFilter(query, filters) if err != nil { c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) @@ -262,7 +277,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters") @@ -271,15 +286,20 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters return data, nil } -func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( - decision.CreatedAtGT(since), decision.UntilGT(time.Now().UTC()), ) - //Allow a bouncer to ask for non-deduplicated results + + if since != nil { + query = query.Where(decision.CreatedAtGT(*since)) + } + + // Allow a bouncer to ask for non-deduplicated results if v, ok := filters["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } + query, err := BuildDecisionRequestWithFilter(query, filters) if err != nil { c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err) @@ -288,34 +308,25 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[ query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String()) } - return data, nil -} -func (c *Client) DeleteDecisionById(decisionId int) ([]*ent.Decision, error) { - toDelete, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionId)).All(c.CTX) - if err != nil { - c.Log.Warningf("DeleteDecisionById : %s", err) - return nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionId) - } - count, err := c.BulkDeleteDecisions(toDelete, false) - c.Log.Debugf("deleted %d decisions", count) - return toDelete, err + return data, nil } -func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer) */ decisions := c.Ent.Decision.Query() + for param, value := range filter { switch param { case "contains": @@ -358,48 +369,48 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, } else if ip_sz == 16 { if contains { /*decision contains {start_ip,end_ip}*/ decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip < query.start_ip + // decision.start_ip < query.start_ip 
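// For ip_sz == 16 an address is stored as two int64 halves (ip, suffix), so
// "decision range contains query range" is expressed by comparing the
// (start_ip, start_sfx) and (end_ip, end_sfx) pairs lexicographically, as
// spelled out in the comments below.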
decision.StartIPLT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix <= query.start_suffix + // decision.start_suffix <= query.start_suffix decision.StartSuffixLTE(start_sfx), )), decision.Or( - //decision.end_ip > query.end_ip + // decision.end_ip > query.end_ip decision.EndIPGT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix >= query.end_suffix + // decision.end_suffix >= query.end_suffix decision.EndSuffixGTE(end_sfx), ), ), )) } else { decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip > query.start_ip + // decision.start_ip > query.start_ip decision.StartIPGT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix >= query.start_suffix + // decision.start_suffix >= query.start_suffix decision.StartSuffixGTE(start_sfx), )), decision.Or( - //decision.end_ip < query.end_ip + // decision.end_ip < query.end_ip decision.EndIPLT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix <= query.end_suffix + // decision.end_suffix <= query.end_suffix decision.EndSuffixLTE(end_sfx), ), ), @@ -409,28 +420,31 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - toDelete, err := decisions.All(c.CTX) + toDelete, err := decisions.All(ctx) if err != nil { c.Log.Warningf("DeleteDecisionsWithFilter : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") } - count, err := c.BulkDeleteDecisions(toDelete, false) + + count, err := c.DeleteDecisions(ctx, toDelete) if err != nil { c.Log.Warningf("While deleting decisions : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") } + return strconv.Itoa(count), toDelete, nil } -// SoftDeleteDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items -func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +// ExpireDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items +func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ decisions := c.Ent.Decision.Query().Where(decision.UntilGT(time.Now().UTC())) + for param, value := range filter { switch param { case "contains": @@ -479,24 +493,24 @@ func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (stri /*decision contains {start_ip,end_ip}*/ if contains { decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip < query.start_ip + // decision.start_ip < query.start_ip 
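// Same lexicographic range-containment predicate as in
// DeleteDecisionsWithFilter above; only the final action differs
// (expire instead of hard delete).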
decision.StartIPLT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix <= query.start_suffix + // decision.start_suffix <= query.start_suffix decision.StartSuffixLTE(start_sfx), )), decision.Or( - //decision.end_ip > query.end_ip + // decision.end_ip > query.end_ip decision.EndIPGT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix >= query.end_suffix + // decision.end_suffix >= query.end_suffix decision.EndSuffixGTE(end_sfx), ), ), @@ -504,24 +518,24 @@ func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (stri } else { /*decision is contained within {start_ip,end_ip}*/ decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip > query.start_ip + // decision.start_ip > query.start_ip decision.StartIPGT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix >= query.start_suffix + // decision.start_suffix >= query.start_suffix decision.StartSuffixGTE(start_sfx), )), decision.Or( - //decision.end_ip < query.end_ip + // decision.end_ip < query.end_ip decision.EndIPLT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix <= query.end_suffix + // decision.end_suffix <= query.end_suffix decision.EndSuffixLTE(end_sfx), ), ), @@ -530,107 +544,132 @@ func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (stri } else if ip_sz != 0 { return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - DecisionsToDelete, err := decisions.All(c.CTX) + + DecisionsToDelete, err := decisions.All(ctx) if err != nil { - c.Log.Warningf("SoftDeleteDecisionsWithFilter : %s", err) - return "0", nil, errors.Wrap(DeleteFail, "soft delete decisions with provided filter") + c.Log.Warningf("ExpireDecisionsWithFilter : %s", err) + return "0", nil, errors.Wrap(DeleteFail, "expire decisions with provided filter") } - count, err := c.BulkDeleteDecisions(DecisionsToDelete, true) + count, err := c.ExpireDecisions(ctx, DecisionsToDelete) if err != nil { - return "0", nil, errors.Wrapf(DeleteFail, "soft delete decisions with provided filter : %s", err) + return "0", nil, errors.Wrapf(DeleteFail, "expire decisions with provided filter : %s", err) } + return strconv.Itoa(count), DecisionsToDelete, err } -// BulkDeleteDecisions set the expiration of a bulk of decisions to now() or hard deletes them. 
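// The replacement below splits this into decisionIDs plus the
// ExpireDecisions/DeleteDecisions pair: batches larger than
// decisionDeleteBulkSize are cut with slicetools.Chunks and handled
// recursively, so a single UPDATE or DELETE never carries more than
// decisionDeleteBulkSize IDs.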
-// We are doing it this way so we can return impacted decisions for sync with CAPI/PAPI -func (c *Client) BulkDeleteDecisions(DecisionsToDelete []*ent.Decision, softDelete bool) (int, error) { - bulkSize := 256 //scientifically proven to be the best value for bulk delete - idsToDelete := make([]int, 0, bulkSize) - totalUpdates := 0 - for i := 0; i < len(DecisionsToDelete); i++ { - idsToDelete = append(idsToDelete, DecisionsToDelete[i].ID) - if len(idsToDelete) == bulkSize { - - if softDelete { - nbUpdates, err := c.Ent.Decision.Update().Where( - decision.IDIn(idsToDelete...), - ).SetUntil(time.Now().UTC()).Save(c.CTX) - if err != nil { - return totalUpdates, errors.Wrap(err, "soft delete decisions with provided filter") - } - totalUpdates += nbUpdates - } else { - nbUpdates, err := c.Ent.Decision.Delete().Where( - decision.IDIn(idsToDelete...), - ).Exec(c.CTX) - if err != nil { - return totalUpdates, errors.Wrap(err, "hard delete decisions with provided filter") - } - totalUpdates += nbUpdates - } - idsToDelete = make([]int, 0, bulkSize) +func decisionIDs(decisions []*ent.Decision) []int { + ids := make([]int, len(decisions)) + for i, d := range decisions { + ids[i] = d.ID + } + + return ids +} + +// ExpireDecisions sets the expiration of a list of decisions to now() +// It returns the number of impacted decisions for the CAPI/PAPI +func (c *Client) ExpireDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { + if len(decisions) <= decisionDeleteBulkSize { + ids := decisionIDs(decisions) + + rows, err := c.Ent.Decision.Update().Where( + decision.IDIn(ids...), + ).SetUntil(time.Now().UTC()).Save(ctx) + if err != nil { + return 0, fmt.Errorf("expire decisions with provided filter: %w", err) } + + return rows, nil } - if len(idsToDelete) > 0 { - if softDelete { - nbUpdates, err := c.Ent.Decision.Update().Where( - decision.IDIn(idsToDelete...), - ).SetUntil(time.Now().UTC()).Save(c.CTX) - if err != nil { - return totalUpdates, errors.Wrap(err, "soft delete decisions with provided filter") - } - totalUpdates += nbUpdates - } else { - nbUpdates, err := c.Ent.Decision.Delete().Where( - decision.IDIn(idsToDelete...), - ).Exec(c.CTX) - if err != nil { - return totalUpdates, errors.Wrap(err, "hard delete decisions with provided filter") - } - totalUpdates += nbUpdates + // big batch, let's split it and recurse + + total := 0 + + for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { + rows, err := c.ExpireDecisions(ctx, chunk) + if err != nil { + return total, err } + total += rows } - return totalUpdates, nil + + return total, nil } -// SoftDeleteDecisionByID set the expiration of a decision to now() -func (c *Client) SoftDeleteDecisionByID(decisionID int) (int, []*ent.Decision, error) { - toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) +// DeleteDecisions removes a list of decisions from the database +// It returns the number of impacted decisions for the CAPI/PAPI +func (c *Client) DeleteDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { + if len(decisions) < decisionDeleteBulkSize { + ids := decisionIDs(decisions) + + rows, err := c.Ent.Decision.Delete().Where( + decision.IDIn(ids...), + ).Exec(ctx) + if err != nil { + return 0, fmt.Errorf("hard delete decisions with provided filter: %w", err) + } + + return rows, nil + } + + // big batch, let's split it and recurse + + tot := 0 + + for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { + rows, err := c.DeleteDecisions(ctx, 
chunk) + if err != nil { + return tot, err + } + + tot += rows + } + return tot, nil +} + +// ExpireDecisionByID sets the expiration of a decision to now() +func (c *Client) ExpireDecisionByID(ctx context.Context, decisionID int) (int, []*ent.Decision, error) { + toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(ctx) + + // XXX: do we want 500 or 404 here? if err != nil || len(toUpdate) == 0 { - c.Log.Warningf("SoftDeleteDecisionByID : %v (nb soft deleted: %d)", err, len(toUpdate)) + c.Log.Warningf("ExpireDecisionByID : %v (nb expired: %d)", err, len(toUpdate)) return 0, nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID) } if len(toUpdate) == 0 { return 0, nil, ItemNotFound } - count, err := c.BulkDeleteDecisions(toUpdate, true) + + count, err := c.ExpireDecisions(ctx, toUpdate) + return count, toUpdate, err } -func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { +func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz, count int - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) if err != nil { return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) } contains := true decisions := c.Ent.Decision.Query() + decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) if err != nil { return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err = decisions.Count(c.CTX) + + count, err := decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } @@ -638,12 +677,70 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { return count, nil } -func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) { +func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz, count int + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) + if err != nil { + return 0, fmt.Errorf("unable to convert '%s' to int: %w", decisionValue, err) + } + + contains := true + decisions := c.Ent.Decision.Query() + + decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) + if err != nil { + return 0, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err) + } + + decisions = decisions.Where(decision.UntilGT(time.Now().UTC())) + + count, err = decisions.Count(ctx) + if err != nil { + return 0, fmt.Errorf("fail to count decisions: %w", err) + } + + return count, nil +} + +func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decisionValue string) (time.Duration, error) { + var err error + var start_ip, start_sfx, end_ip, end_sfx int64 + var ip_sz int + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) + if err != nil { + return 0, fmt.Errorf("unable to convert '%s' to int: %w", decisionValue, err) + } + + contains := true + decisions := c.Ent.Decision.Query().Where( + decision.UntilGT(time.Now().UTC()), + ) + decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) + if err != nil { + return 0, fmt.Errorf("fail to apply StartIpEndIpFilter: 
%w", err) + } + + decisions = decisions.Order(ent.Desc(decision.FieldUntil)) + + decision, err := decisions.First(ctx) + if err != nil && !ent.IsNotFound(err) { + return 0, fmt.Errorf("fail to get decision: %w", err) + } + + if decision == nil { + return 0, nil + } + + return decision.Until.Sub(time.Now().UTC()), nil +} + +func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue string, since time.Time) (int, error) { + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue) if err != nil { return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) } @@ -652,11 +749,13 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim decisions := c.Ent.Decision.Query().Where( decision.CreatedAtGT(since), ) + decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) if err != nil { return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err = decisions.Count(c.CTX) + + count, err := decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } @@ -681,28 +780,32 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz decision.IPSizeEQ(int64(ip_sz)), )) } - } else if ip_sz == 16 { + + return decisions, nil + } + + if ip_sz == 16 { /*decision contains {start_ip,end_ip}*/ if contains { decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip < query.start_ip + // decision.start_ip < query.start_ip decision.StartIPLT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix <= query.start_suffix + // decision.start_suffix <= query.start_suffix decision.StartSuffixLTE(start_sfx), )), decision.Or( - //decision.end_ip > query.end_ip + // decision.end_ip > query.end_ip decision.EndIPGT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix >= query.end_suffix + // decision.end_suffix >= query.end_suffix decision.EndSuffixGTE(end_sfx), ), ), @@ -710,40 +813,47 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz } else { /*decision is contained within {start_ip,end_ip}*/ decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip > query.start_ip + // decision.start_ip > query.start_ip decision.StartIPGT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix >= query.start_suffix + // decision.start_suffix >= query.start_suffix decision.StartSuffixGTE(start_sfx), )), decision.Or( - //decision.end_ip < query.end_ip + // decision.end_ip < query.end_ip decision.EndIPLT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix <= query.end_suffix + // decision.end_suffix <= query.end_suffix decision.EndSuffixLTE(end_sfx), ), ), )) } - } else if ip_sz != 0 { + + return decisions, nil + } + + if ip_sz != 0 { return nil, errors.Wrapf(InvalidFilter, "unknown ip size %d", ip_sz) } + return decisions, nil } func decisionPredicatesFromStr(s string, predicateFunc func(string) 
predicate.Decision) []predicate.Decision { words := strings.Split(s, ",") predicates := make([]predicate.Decision, len(words)) + for i, word := range words { predicates[i] = predicateFunc(word) } + return predicates } diff --git a/pkg/database/ent/alert.go b/pkg/database/ent/alert.go index 2649923bf5e..eb0e1cb7612 100644 --- a/pkg/database/ent/alert.go +++ b/pkg/database/ent/alert.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" @@ -18,9 +19,9 @@ type Alert struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` // Scenario holds the value of the "scenario" field. Scenario string `json:"scenario,omitempty"` // BucketId holds the value of the "bucketId" field. @@ -63,10 +64,13 @@ type Alert struct { Simulated bool `json:"simulated,omitempty"` // UUID holds the value of the "uuid" field. UUID string `json:"uuid,omitempty"` + // Remediation holds the value of the "remediation" field. + Remediation bool `json:"remediation,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the AlertQuery when eager-loading is set. Edges AlertEdges `json:"edges"` machine_alerts *int + selectValues sql.SelectValues } // AlertEdges holds the relations/edges for other nodes in the graph. @@ -87,12 +91,10 @@ type AlertEdges struct { // OwnerOrErr returns the Owner value or an error if the edge // was not loaded in eager-loading, or loaded but was not found. func (e AlertEdges) OwnerOrErr() (*Machine, error) { - if e.loadedTypes[0] { - if e.Owner == nil { - // Edge was loaded but was not found. 
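// In the regenerated accessor just below, e.Owner is returned whenever the
// edge is populated; the NotFoundError is only produced when the edge was
// eager-loaded but came back empty.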
- return nil, &NotFoundError{label: machine.Label} - } + if e.Owner != nil { return e.Owner, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: machine.Label} } return nil, &NotLoadedError{edge: "owner"} } @@ -129,7 +131,7 @@ func (*Alert) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case alert.FieldSimulated: + case alert.FieldSimulated, alert.FieldRemediation: values[i] = new(sql.NullBool) case alert.FieldSourceLatitude, alert.FieldSourceLongitude: values[i] = new(sql.NullFloat64) @@ -142,7 +144,7 @@ func (*Alert) scanValues(columns []string) ([]any, error) { case alert.ForeignKeys[0]: // machine_alerts values[i] = new(sql.NullInt64) default: - return nil, fmt.Errorf("unexpected column %q for type Alert", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -166,15 +168,13 @@ func (a *Alert) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - a.CreatedAt = new(time.Time) - *a.CreatedAt = value.Time + a.CreatedAt = value.Time } case alert.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - a.UpdatedAt = new(time.Time) - *a.UpdatedAt = value.Time + a.UpdatedAt = value.Time } case alert.FieldScenario: if value, ok := values[i].(*sql.NullString); !ok { @@ -302,6 +302,12 @@ func (a *Alert) assignValues(columns []string, values []any) error { } else if value.Valid { a.UUID = value.String } + case alert.FieldRemediation: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field remediation", values[i]) + } else if value.Valid { + a.Remediation = value.Bool + } case alert.ForeignKeys[0]: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for edge-field machine_alerts", value) @@ -309,36 +315,44 @@ func (a *Alert) assignValues(columns []string, values []any) error { a.machine_alerts = new(int) *a.machine_alerts = int(value.Int64) } + default: + a.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Alert. +// This includes values selected through modifiers, order, etc. +func (a *Alert) Value(name string) (ent.Value, error) { + return a.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Alert entity. func (a *Alert) QueryOwner() *MachineQuery { - return (&AlertClient{config: a.config}).QueryOwner(a) + return NewAlertClient(a.config).QueryOwner(a) } // QueryDecisions queries the "decisions" edge of the Alert entity. func (a *Alert) QueryDecisions() *DecisionQuery { - return (&AlertClient{config: a.config}).QueryDecisions(a) + return NewAlertClient(a.config).QueryDecisions(a) } // QueryEvents queries the "events" edge of the Alert entity. func (a *Alert) QueryEvents() *EventQuery { - return (&AlertClient{config: a.config}).QueryEvents(a) + return NewAlertClient(a.config).QueryEvents(a) } // QueryMetas queries the "metas" edge of the Alert entity. func (a *Alert) QueryMetas() *MetaQuery { - return (&AlertClient{config: a.config}).QueryMetas(a) + return NewAlertClient(a.config).QueryMetas(a) } // Update returns a builder for updating this Alert. 
// Note that you need to call Alert.Unwrap() before calling this method if this Alert // was returned from a transaction, and the transaction was committed or rolled back. func (a *Alert) Update() *AlertUpdateOne { - return (&AlertClient{config: a.config}).UpdateOne(a) + return NewAlertClient(a.config).UpdateOne(a) } // Unwrap unwraps the Alert entity that was returned from a transaction after it was closed, @@ -357,15 +371,11 @@ func (a *Alert) String() string { var builder strings.Builder builder.WriteString("Alert(") builder.WriteString(fmt.Sprintf("id=%v, ", a.ID)) - if v := a.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(a.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := a.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(a.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") builder.WriteString("scenario=") builder.WriteString(a.Scenario) @@ -429,15 +439,12 @@ func (a *Alert) String() string { builder.WriteString(", ") builder.WriteString("uuid=") builder.WriteString(a.UUID) + builder.WriteString(", ") + builder.WriteString("remediation=") + builder.WriteString(fmt.Sprintf("%v", a.Remediation)) builder.WriteByte(')') return builder.String() } // Alerts is a parsable slice of Alert. type Alerts []*Alert - -func (a Alerts) config(cfg config) { - for _i := range a { - a[_i].config = cfg - } -} diff --git a/pkg/database/ent/alert/alert.go b/pkg/database/ent/alert/alert.go index abee13fb97a..62aade98e87 100644 --- a/pkg/database/ent/alert/alert.go +++ b/pkg/database/ent/alert/alert.go @@ -4,6 +4,9 @@ package alert import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -57,6 +60,8 @@ const ( FieldSimulated = "simulated" // FieldUUID holds the string denoting the uuid field in the database. FieldUUID = "uuid" + // FieldRemediation holds the string denoting the remediation field in the database. + FieldRemediation = "remediation" // EdgeOwner holds the string denoting the owner edge name in mutations. EdgeOwner = "owner" // EdgeDecisions holds the string denoting the decisions edge name in mutations. @@ -123,6 +128,7 @@ var Columns = []string{ FieldScenarioHash, FieldSimulated, FieldUUID, + FieldRemediation, } // ForeignKeys holds the SQL foreign-keys that are owned by the "alerts" @@ -149,8 +155,6 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. @@ -168,3 +172,208 @@ var ( // DefaultSimulated holds the default value on creation for the "simulated" field. DefaultSimulated bool ) + +// OrderOption defines the ordering options for the Alert queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. 
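// These OrderOption helpers come from the newer ent codegen; a typical use,
// assuming an *ent.Client named client and the entgo.io/ent/dialect/sql
// helpers:
//
//	client.Alert.Query().Order(alert.ByCreatedAt(sql.OrderDesc())).All(ctx)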
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByScenario orders the results by the scenario field. +func ByScenario(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenario, opts...).ToFunc() +} + +// ByBucketId orders the results by the bucketId field. +func ByBucketId(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBucketId, opts...).ToFunc() +} + +// ByMessage orders the results by the message field. +func ByMessage(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMessage, opts...).ToFunc() +} + +// ByEventsCountField orders the results by the eventsCount field. +func ByEventsCountField(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEventsCount, opts...).ToFunc() +} + +// ByStartedAt orders the results by the startedAt field. +func ByStartedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartedAt, opts...).ToFunc() +} + +// ByStoppedAt orders the results by the stoppedAt field. +func ByStoppedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStoppedAt, opts...).ToFunc() +} + +// BySourceIp orders the results by the sourceIp field. +func BySourceIp(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceIp, opts...).ToFunc() +} + +// BySourceRange orders the results by the sourceRange field. +func BySourceRange(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceRange, opts...).ToFunc() +} + +// BySourceAsNumber orders the results by the sourceAsNumber field. +func BySourceAsNumber(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceAsNumber, opts...).ToFunc() +} + +// BySourceAsName orders the results by the sourceAsName field. +func BySourceAsName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceAsName, opts...).ToFunc() +} + +// BySourceCountry orders the results by the sourceCountry field. +func BySourceCountry(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceCountry, opts...).ToFunc() +} + +// BySourceLatitude orders the results by the sourceLatitude field. +func BySourceLatitude(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceLatitude, opts...).ToFunc() +} + +// BySourceLongitude orders the results by the sourceLongitude field. +func BySourceLongitude(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceLongitude, opts...).ToFunc() +} + +// BySourceScope orders the results by the sourceScope field. +func BySourceScope(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceScope, opts...).ToFunc() +} + +// BySourceValue orders the results by the sourceValue field. +func BySourceValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceValue, opts...).ToFunc() +} + +// ByCapacity orders the results by the capacity field. +func ByCapacity(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCapacity, opts...).ToFunc() +} + +// ByLeakSpeed orders the results by the leakSpeed field. 
+func ByLeakSpeed(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLeakSpeed, opts...).ToFunc() +} + +// ByScenarioVersion orders the results by the scenarioVersion field. +func ByScenarioVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenarioVersion, opts...).ToFunc() +} + +// ByScenarioHash orders the results by the scenarioHash field. +func ByScenarioHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenarioHash, opts...).ToFunc() +} + +// BySimulated orders the results by the simulated field. +func BySimulated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSimulated, opts...).ToFunc() +} + +// ByUUID orders the results by the uuid field. +func ByUUID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUUID, opts...).ToFunc() +} + +// ByRemediation orders the results by the remediation field. +func ByRemediation(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRemediation, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} + +// ByDecisionsCount orders the results by decisions count. +func ByDecisionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newDecisionsStep(), opts...) + } +} + +// ByDecisions orders the results by decisions terms. +func ByDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByEventsCount orders the results by events count. +func ByEventsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newEventsStep(), opts...) + } +} + +// ByEvents orders the results by events terms. +func ByEvents(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newEventsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByMetasCount orders the results by metas count. +func ByMetasCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newMetasStep(), opts...) + } +} + +// ByMetas orders the results by metas terms. +func ByMetas(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newMetasStep(), append([]sql.OrderTerm{term}, terms...)...) 
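// Edge-based ordering delegates to sqlgraph, which joins the neighbor table;
// e.g. alert.ByMetasCount(sql.OrderDesc()) would sort alerts by how many
// meta entries they carry.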
+ } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} +func newDecisionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(DecisionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, DecisionsTable, DecisionsColumn), + ) +} +func newEventsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(EventsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, EventsTable, EventsColumn), + ) +} +func newMetasStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(MetasInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, MetasTable, MetasColumn), + ) +} diff --git a/pkg/database/ent/alert/where.go b/pkg/database/ent/alert/where.go index ef5b89b615f..da6080fffb9 100644 --- a/pkg/database/ent/alert/where.go +++ b/pkg/database/ent/alert/where.go @@ -12,2440 +12,1617 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. 
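// From here on, the closure-based predicates are collapsed by the codegen
// into the sql.Field* helpers; the generated SQL is semantically identical,
// only the Go plumbing is shorter.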
func CreatedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUpdatedAt, v)) } // Scenario applies equality check predicate on the "scenario" field. It's identical to ScenarioEQ. func Scenario(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenario, v)) } // BucketId applies equality check predicate on the "bucketId" field. It's identical to BucketIdEQ. func BucketId(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldBucketId, v)) } // Message applies equality check predicate on the "message" field. It's identical to MessageEQ. func Message(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldMessage, v)) } // EventsCount applies equality check predicate on the "eventsCount" field. It's identical to EventsCountEQ. func EventsCount(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldEventsCount, v)) } // StartedAt applies equality check predicate on the "startedAt" field. It's identical to StartedAtEQ. func StartedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStartedAt, v)) } // StoppedAt applies equality check predicate on the "stoppedAt" field. It's identical to StoppedAtEQ. func StoppedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStoppedAt, v)) } // SourceIp applies equality check predicate on the "sourceIp" field. It's identical to SourceIpEQ. func SourceIp(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceIp, v)) } // SourceRange applies equality check predicate on the "sourceRange" field. It's identical to SourceRangeEQ. func SourceRange(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceRange, v)) } // SourceAsNumber applies equality check predicate on the "sourceAsNumber" field. It's identical to SourceAsNumberEQ. func SourceAsNumber(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsNumber, v)) } // SourceAsName applies equality check predicate on the "sourceAsName" field. It's identical to SourceAsNameEQ. 
func SourceAsName(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsName, v)) } // SourceCountry applies equality check predicate on the "sourceCountry" field. It's identical to SourceCountryEQ. func SourceCountry(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceCountry, v)) } // SourceLatitude applies equality check predicate on the "sourceLatitude" field. It's identical to SourceLatitudeEQ. func SourceLatitude(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceLatitude, v)) } // SourceLongitude applies equality check predicate on the "sourceLongitude" field. It's identical to SourceLongitudeEQ. func SourceLongitude(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceLongitude, v)) } // SourceScope applies equality check predicate on the "sourceScope" field. It's identical to SourceScopeEQ. func SourceScope(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceScope, v)) } // SourceValue applies equality check predicate on the "sourceValue" field. It's identical to SourceValueEQ. func SourceValue(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceValue, v)) } // Capacity applies equality check predicate on the "capacity" field. It's identical to CapacityEQ. func Capacity(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCapacity, v)) } // LeakSpeed applies equality check predicate on the "leakSpeed" field. It's identical to LeakSpeedEQ. func LeakSpeed(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldLeakSpeed, v)) } // ScenarioVersion applies equality check predicate on the "scenarioVersion" field. It's identical to ScenarioVersionEQ. func ScenarioVersion(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioVersion, v)) } // ScenarioHash applies equality check predicate on the "scenarioHash" field. It's identical to ScenarioHashEQ. func ScenarioHash(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioHash, v)) } // Simulated applies equality check predicate on the "simulated" field. It's identical to SimulatedEQ. func Simulated(v bool) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSimulated, v)) } // UUID applies equality check predicate on the "uuid" field. It's identical to UUIDEQ. 
func UUID(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUUID, v)) +} + +// Remediation applies equality check predicate on the "remediation" field. It's identical to RemediationEQ. +func Remediation(v bool) predicate.Alert { + return predicate.Alert(sql.FieldEQ(FieldRemediation, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Alert(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. 
func UpdatedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Alert(sql.FieldLTE(FieldUpdatedAt, v)) } // ScenarioEQ applies the EQ predicate on the "scenario" field. func ScenarioEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenario, v)) } // ScenarioNEQ applies the NEQ predicate on the "scenario" field. func ScenarioNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldScenario, v)) } // ScenarioIn applies the In predicate on the "scenario" field. func ScenarioIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenario), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldScenario, vs...)) } // ScenarioNotIn applies the NotIn predicate on the "scenario" field. 
func ScenarioNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenario), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldScenario, vs...)) } // ScenarioGT applies the GT predicate on the "scenario" field. func ScenarioGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldGT(FieldScenario, v)) } // ScenarioGTE applies the GTE predicate on the "scenario" field. func ScenarioGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldScenario, v)) } // ScenarioLT applies the LT predicate on the "scenario" field. func ScenarioLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldLT(FieldScenario, v)) } // ScenarioLTE applies the LTE predicate on the "scenario" field. func ScenarioLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldScenario, v)) } // ScenarioContains applies the Contains predicate on the "scenario" field. func ScenarioContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldContains(FieldScenario, v)) } // ScenarioHasPrefix applies the HasPrefix predicate on the "scenario" field. func ScenarioHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldScenario, v)) } // ScenarioHasSuffix applies the HasSuffix predicate on the "scenario" field. func ScenarioHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldScenario, v)) } // ScenarioEqualFold applies the EqualFold predicate on the "scenario" field. func ScenarioEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldScenario, v)) } // ScenarioContainsFold applies the ContainsFold predicate on the "scenario" field. func ScenarioContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldScenario, v)) } // BucketIdEQ applies the EQ predicate on the "bucketId" field. func BucketIdEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldBucketId, v)) } // BucketIdNEQ applies the NEQ predicate on the "bucketId" field. func BucketIdNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldBucketId, v)) } // BucketIdIn applies the In predicate on the "bucketId" field. 
func BucketIdIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldBucketId), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldBucketId, vs...)) } // BucketIdNotIn applies the NotIn predicate on the "bucketId" field. func BucketIdNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldBucketId), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldBucketId, vs...)) } // BucketIdGT applies the GT predicate on the "bucketId" field. func BucketIdGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldGT(FieldBucketId, v)) } // BucketIdGTE applies the GTE predicate on the "bucketId" field. func BucketIdGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldBucketId, v)) } // BucketIdLT applies the LT predicate on the "bucketId" field. func BucketIdLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldLT(FieldBucketId, v)) } // BucketIdLTE applies the LTE predicate on the "bucketId" field. func BucketIdLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldBucketId, v)) } // BucketIdContains applies the Contains predicate on the "bucketId" field. func BucketIdContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldContains(FieldBucketId, v)) } // BucketIdHasPrefix applies the HasPrefix predicate on the "bucketId" field. func BucketIdHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldBucketId, v)) } // BucketIdHasSuffix applies the HasSuffix predicate on the "bucketId" field. func BucketIdHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldBucketId, v)) } // BucketIdIsNil applies the IsNil predicate on the "bucketId" field. func BucketIdIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldBucketId))) - }) + return predicate.Alert(sql.FieldIsNull(FieldBucketId)) } // BucketIdNotNil applies the NotNil predicate on the "bucketId" field. func BucketIdNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldBucketId))) - }) + return predicate.Alert(sql.FieldNotNull(FieldBucketId)) } // BucketIdEqualFold applies the EqualFold predicate on the "bucketId" field. func BucketIdEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldBucketId, v)) } // BucketIdContainsFold applies the ContainsFold predicate on the "bucketId" field. 
func BucketIdContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldBucketId, v)) } // MessageEQ applies the EQ predicate on the "message" field. func MessageEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldMessage, v)) } // MessageNEQ applies the NEQ predicate on the "message" field. func MessageNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldMessage, v)) } // MessageIn applies the In predicate on the "message" field. func MessageIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldMessage), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldMessage, vs...)) } // MessageNotIn applies the NotIn predicate on the "message" field. func MessageNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldMessage), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldMessage, vs...)) } // MessageGT applies the GT predicate on the "message" field. func MessageGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldGT(FieldMessage, v)) } // MessageGTE applies the GTE predicate on the "message" field. func MessageGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldMessage, v)) } // MessageLT applies the LT predicate on the "message" field. func MessageLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldLT(FieldMessage, v)) } // MessageLTE applies the LTE predicate on the "message" field. func MessageLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldMessage, v)) } // MessageContains applies the Contains predicate on the "message" field. func MessageContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldContains(FieldMessage, v)) } // MessageHasPrefix applies the HasPrefix predicate on the "message" field. func MessageHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldMessage, v)) } // MessageHasSuffix applies the HasSuffix predicate on the "message" field. func MessageHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldMessage, v)) } // MessageIsNil applies the IsNil predicate on the "message" field. 
func MessageIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldMessage))) - }) + return predicate.Alert(sql.FieldIsNull(FieldMessage)) } // MessageNotNil applies the NotNil predicate on the "message" field. func MessageNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldMessage))) - }) + return predicate.Alert(sql.FieldNotNull(FieldMessage)) } // MessageEqualFold applies the EqualFold predicate on the "message" field. func MessageEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldMessage, v)) } // MessageContainsFold applies the ContainsFold predicate on the "message" field. func MessageContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldMessage, v)) } // EventsCountEQ applies the EQ predicate on the "eventsCount" field. func EventsCountEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldEventsCount, v)) } // EventsCountNEQ applies the NEQ predicate on the "eventsCount" field. func EventsCountNEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldEventsCount, v)) } // EventsCountIn applies the In predicate on the "eventsCount" field. func EventsCountIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldEventsCount), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldEventsCount, vs...)) } // EventsCountNotIn applies the NotIn predicate on the "eventsCount" field. func EventsCountNotIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldEventsCount), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldEventsCount, vs...)) } // EventsCountGT applies the GT predicate on the "eventsCount" field. func EventsCountGT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldGT(FieldEventsCount, v)) } // EventsCountGTE applies the GTE predicate on the "eventsCount" field. func EventsCountGTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldEventsCount, v)) } // EventsCountLT applies the LT predicate on the "eventsCount" field. func EventsCountLT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldLT(FieldEventsCount, v)) } // EventsCountLTE applies the LTE predicate on the "eventsCount" field. func EventsCountLTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldEventsCount, v)) } // EventsCountIsNil applies the IsNil predicate on the "eventsCount" field. 
func EventsCountIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldEventsCount))) - }) + return predicate.Alert(sql.FieldIsNull(FieldEventsCount)) } // EventsCountNotNil applies the NotNil predicate on the "eventsCount" field. func EventsCountNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldEventsCount))) - }) + return predicate.Alert(sql.FieldNotNull(FieldEventsCount)) } // StartedAtEQ applies the EQ predicate on the "startedAt" field. func StartedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStartedAt, v)) } // StartedAtNEQ applies the NEQ predicate on the "startedAt" field. func StartedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldStartedAt, v)) } // StartedAtIn applies the In predicate on the "startedAt" field. func StartedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStartedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldStartedAt, vs...)) } // StartedAtNotIn applies the NotIn predicate on the "startedAt" field. func StartedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStartedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldStartedAt, vs...)) } // StartedAtGT applies the GT predicate on the "startedAt" field. func StartedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldStartedAt, v)) } // StartedAtGTE applies the GTE predicate on the "startedAt" field. func StartedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldStartedAt, v)) } // StartedAtLT applies the LT predicate on the "startedAt" field. func StartedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldStartedAt, v)) } // StartedAtLTE applies the LTE predicate on the "startedAt" field. func StartedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldStartedAt, v)) } // StartedAtIsNil applies the IsNil predicate on the "startedAt" field. func StartedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStartedAt))) - }) + return predicate.Alert(sql.FieldIsNull(FieldStartedAt)) } // StartedAtNotNil applies the NotNil predicate on the "startedAt" field. func StartedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStartedAt))) - }) + return predicate.Alert(sql.FieldNotNull(FieldStartedAt)) } // StoppedAtEQ applies the EQ predicate on the "stoppedAt" field. 
func StoppedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStoppedAt, v)) } // StoppedAtNEQ applies the NEQ predicate on the "stoppedAt" field. func StoppedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldStoppedAt, v)) } // StoppedAtIn applies the In predicate on the "stoppedAt" field. func StoppedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStoppedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldStoppedAt, vs...)) } // StoppedAtNotIn applies the NotIn predicate on the "stoppedAt" field. func StoppedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStoppedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldStoppedAt, vs...)) } // StoppedAtGT applies the GT predicate on the "stoppedAt" field. func StoppedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldStoppedAt, v)) } // StoppedAtGTE applies the GTE predicate on the "stoppedAt" field. func StoppedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldStoppedAt, v)) } // StoppedAtLT applies the LT predicate on the "stoppedAt" field. func StoppedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldStoppedAt, v)) } // StoppedAtLTE applies the LTE predicate on the "stoppedAt" field. func StoppedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldStoppedAt, v)) } // StoppedAtIsNil applies the IsNil predicate on the "stoppedAt" field. func StoppedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStoppedAt))) - }) + return predicate.Alert(sql.FieldIsNull(FieldStoppedAt)) } // StoppedAtNotNil applies the NotNil predicate on the "stoppedAt" field. func StoppedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStoppedAt))) - }) + return predicate.Alert(sql.FieldNotNull(FieldStoppedAt)) } // SourceIpEQ applies the EQ predicate on the "sourceIp" field. func SourceIpEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceIp, v)) } // SourceIpNEQ applies the NEQ predicate on the "sourceIp" field. func SourceIpNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceIp, v)) } // SourceIpIn applies the In predicate on the "sourceIp" field. 
func SourceIpIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceIp), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceIp, vs...)) } // SourceIpNotIn applies the NotIn predicate on the "sourceIp" field. func SourceIpNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceIp), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceIp, vs...)) } // SourceIpGT applies the GT predicate on the "sourceIp" field. func SourceIpGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceIp, v)) } // SourceIpGTE applies the GTE predicate on the "sourceIp" field. func SourceIpGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceIp, v)) } // SourceIpLT applies the LT predicate on the "sourceIp" field. func SourceIpLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceIp, v)) } // SourceIpLTE applies the LTE predicate on the "sourceIp" field. func SourceIpLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceIp, v)) } // SourceIpContains applies the Contains predicate on the "sourceIp" field. func SourceIpContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceIp, v)) } // SourceIpHasPrefix applies the HasPrefix predicate on the "sourceIp" field. func SourceIpHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceIp, v)) } // SourceIpHasSuffix applies the HasSuffix predicate on the "sourceIp" field. func SourceIpHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceIp, v)) } // SourceIpIsNil applies the IsNil predicate on the "sourceIp" field. func SourceIpIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceIp))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceIp)) } // SourceIpNotNil applies the NotNil predicate on the "sourceIp" field. func SourceIpNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceIp))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceIp)) } // SourceIpEqualFold applies the EqualFold predicate on the "sourceIp" field. func SourceIpEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceIp, v)) } // SourceIpContainsFold applies the ContainsFold predicate on the "sourceIp" field. 
func SourceIpContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceIp, v)) } // SourceRangeEQ applies the EQ predicate on the "sourceRange" field. func SourceRangeEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceRange, v)) } // SourceRangeNEQ applies the NEQ predicate on the "sourceRange" field. func SourceRangeNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceRange, v)) } // SourceRangeIn applies the In predicate on the "sourceRange" field. func SourceRangeIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceRange), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceRange, vs...)) } // SourceRangeNotIn applies the NotIn predicate on the "sourceRange" field. func SourceRangeNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceRange), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceRange, vs...)) } // SourceRangeGT applies the GT predicate on the "sourceRange" field. func SourceRangeGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceRange, v)) } // SourceRangeGTE applies the GTE predicate on the "sourceRange" field. func SourceRangeGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceRange, v)) } // SourceRangeLT applies the LT predicate on the "sourceRange" field. func SourceRangeLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceRange, v)) } // SourceRangeLTE applies the LTE predicate on the "sourceRange" field. func SourceRangeLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceRange, v)) } // SourceRangeContains applies the Contains predicate on the "sourceRange" field. func SourceRangeContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceRange, v)) } // SourceRangeHasPrefix applies the HasPrefix predicate on the "sourceRange" field. func SourceRangeHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceRange, v)) } // SourceRangeHasSuffix applies the HasSuffix predicate on the "sourceRange" field. 
func SourceRangeHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceRange, v)) } // SourceRangeIsNil applies the IsNil predicate on the "sourceRange" field. func SourceRangeIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceRange))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceRange)) } // SourceRangeNotNil applies the NotNil predicate on the "sourceRange" field. func SourceRangeNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceRange))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceRange)) } // SourceRangeEqualFold applies the EqualFold predicate on the "sourceRange" field. func SourceRangeEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceRange, v)) } // SourceRangeContainsFold applies the ContainsFold predicate on the "sourceRange" field. func SourceRangeContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceRange, v)) } // SourceAsNumberEQ applies the EQ predicate on the "sourceAsNumber" field. func SourceAsNumberEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsNumber, v)) } // SourceAsNumberNEQ applies the NEQ predicate on the "sourceAsNumber" field. func SourceAsNumberNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceAsNumber, v)) } // SourceAsNumberIn applies the In predicate on the "sourceAsNumber" field. func SourceAsNumberIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceAsNumber), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceAsNumber, vs...)) } // SourceAsNumberNotIn applies the NotIn predicate on the "sourceAsNumber" field. func SourceAsNumberNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceAsNumber), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceAsNumber, vs...)) } // SourceAsNumberGT applies the GT predicate on the "sourceAsNumber" field. func SourceAsNumberGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceAsNumber, v)) } // SourceAsNumberGTE applies the GTE predicate on the "sourceAsNumber" field. func SourceAsNumberGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceAsNumber, v)) } // SourceAsNumberLT applies the LT predicate on the "sourceAsNumber" field. 
func SourceAsNumberLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceAsNumber, v)) } // SourceAsNumberLTE applies the LTE predicate on the "sourceAsNumber" field. func SourceAsNumberLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceAsNumber, v)) } // SourceAsNumberContains applies the Contains predicate on the "sourceAsNumber" field. func SourceAsNumberContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceAsNumber, v)) } // SourceAsNumberHasPrefix applies the HasPrefix predicate on the "sourceAsNumber" field. func SourceAsNumberHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceAsNumber, v)) } // SourceAsNumberHasSuffix applies the HasSuffix predicate on the "sourceAsNumber" field. func SourceAsNumberHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceAsNumber, v)) } // SourceAsNumberIsNil applies the IsNil predicate on the "sourceAsNumber" field. func SourceAsNumberIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceAsNumber))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceAsNumber)) } // SourceAsNumberNotNil applies the NotNil predicate on the "sourceAsNumber" field. func SourceAsNumberNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceAsNumber))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceAsNumber)) } // SourceAsNumberEqualFold applies the EqualFold predicate on the "sourceAsNumber" field. func SourceAsNumberEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceAsNumber, v)) } // SourceAsNumberContainsFold applies the ContainsFold predicate on the "sourceAsNumber" field. func SourceAsNumberContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceAsNumber, v)) } // SourceAsNameEQ applies the EQ predicate on the "sourceAsName" field. func SourceAsNameEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsName, v)) } // SourceAsNameNEQ applies the NEQ predicate on the "sourceAsName" field. func SourceAsNameNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceAsName, v)) } // SourceAsNameIn applies the In predicate on the "sourceAsName" field. 
func SourceAsNameIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceAsName), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceAsName, vs...)) } // SourceAsNameNotIn applies the NotIn predicate on the "sourceAsName" field. func SourceAsNameNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceAsName), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceAsName, vs...)) } // SourceAsNameGT applies the GT predicate on the "sourceAsName" field. func SourceAsNameGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceAsName, v)) } // SourceAsNameGTE applies the GTE predicate on the "sourceAsName" field. func SourceAsNameGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceAsName, v)) } // SourceAsNameLT applies the LT predicate on the "sourceAsName" field. func SourceAsNameLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceAsName, v)) } // SourceAsNameLTE applies the LTE predicate on the "sourceAsName" field. func SourceAsNameLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceAsName, v)) } // SourceAsNameContains applies the Contains predicate on the "sourceAsName" field. func SourceAsNameContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceAsName, v)) } // SourceAsNameHasPrefix applies the HasPrefix predicate on the "sourceAsName" field. func SourceAsNameHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceAsName, v)) } // SourceAsNameHasSuffix applies the HasSuffix predicate on the "sourceAsName" field. func SourceAsNameHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceAsName, v)) } // SourceAsNameIsNil applies the IsNil predicate on the "sourceAsName" field. func SourceAsNameIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceAsName))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceAsName)) } // SourceAsNameNotNil applies the NotNil predicate on the "sourceAsName" field. func SourceAsNameNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceAsName))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceAsName)) } // SourceAsNameEqualFold applies the EqualFold predicate on the "sourceAsName" field. 
func SourceAsNameEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceAsName, v)) } // SourceAsNameContainsFold applies the ContainsFold predicate on the "sourceAsName" field. func SourceAsNameContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceAsName, v)) } // SourceCountryEQ applies the EQ predicate on the "sourceCountry" field. func SourceCountryEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceCountry, v)) } // SourceCountryNEQ applies the NEQ predicate on the "sourceCountry" field. func SourceCountryNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceCountry, v)) } // SourceCountryIn applies the In predicate on the "sourceCountry" field. func SourceCountryIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceCountry), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceCountry, vs...)) } // SourceCountryNotIn applies the NotIn predicate on the "sourceCountry" field. func SourceCountryNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceCountry), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceCountry, vs...)) } // SourceCountryGT applies the GT predicate on the "sourceCountry" field. func SourceCountryGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceCountry, v)) } // SourceCountryGTE applies the GTE predicate on the "sourceCountry" field. func SourceCountryGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceCountry, v)) } // SourceCountryLT applies the LT predicate on the "sourceCountry" field. func SourceCountryLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceCountry, v)) } // SourceCountryLTE applies the LTE predicate on the "sourceCountry" field. func SourceCountryLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceCountry, v)) } // SourceCountryContains applies the Contains predicate on the "sourceCountry" field. func SourceCountryContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceCountry, v)) } // SourceCountryHasPrefix applies the HasPrefix predicate on the "sourceCountry" field. 
func SourceCountryHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceCountry, v)) } // SourceCountryHasSuffix applies the HasSuffix predicate on the "sourceCountry" field. func SourceCountryHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceCountry, v)) } // SourceCountryIsNil applies the IsNil predicate on the "sourceCountry" field. func SourceCountryIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceCountry))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceCountry)) } // SourceCountryNotNil applies the NotNil predicate on the "sourceCountry" field. func SourceCountryNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceCountry))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceCountry)) } // SourceCountryEqualFold applies the EqualFold predicate on the "sourceCountry" field. func SourceCountryEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceCountry, v)) } // SourceCountryContainsFold applies the ContainsFold predicate on the "sourceCountry" field. func SourceCountryContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceCountry, v)) } // SourceLatitudeEQ applies the EQ predicate on the "sourceLatitude" field. func SourceLatitudeEQ(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceLatitude, v)) } // SourceLatitudeNEQ applies the NEQ predicate on the "sourceLatitude" field. func SourceLatitudeNEQ(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceLatitude, v)) } // SourceLatitudeIn applies the In predicate on the "sourceLatitude" field. func SourceLatitudeIn(vs ...float32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceLatitude), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceLatitude, vs...)) } // SourceLatitudeNotIn applies the NotIn predicate on the "sourceLatitude" field. func SourceLatitudeNotIn(vs ...float32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceLatitude), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceLatitude, vs...)) } // SourceLatitudeGT applies the GT predicate on the "sourceLatitude" field. func SourceLatitudeGT(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceLatitude, v)) } // SourceLatitudeGTE applies the GTE predicate on the "sourceLatitude" field. 
func SourceLatitudeGTE(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceLatitude, v)) } // SourceLatitudeLT applies the LT predicate on the "sourceLatitude" field. func SourceLatitudeLT(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceLatitude, v)) } // SourceLatitudeLTE applies the LTE predicate on the "sourceLatitude" field. func SourceLatitudeLTE(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceLatitude, v)) } // SourceLatitudeIsNil applies the IsNil predicate on the "sourceLatitude" field. func SourceLatitudeIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceLatitude))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceLatitude)) } // SourceLatitudeNotNil applies the NotNil predicate on the "sourceLatitude" field. func SourceLatitudeNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceLatitude))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceLatitude)) } // SourceLongitudeEQ applies the EQ predicate on the "sourceLongitude" field. func SourceLongitudeEQ(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceLongitude, v)) } // SourceLongitudeNEQ applies the NEQ predicate on the "sourceLongitude" field. func SourceLongitudeNEQ(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceLongitude, v)) } // SourceLongitudeIn applies the In predicate on the "sourceLongitude" field. func SourceLongitudeIn(vs ...float32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceLongitude), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceLongitude, vs...)) } // SourceLongitudeNotIn applies the NotIn predicate on the "sourceLongitude" field. func SourceLongitudeNotIn(vs ...float32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceLongitude), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceLongitude, vs...)) } // SourceLongitudeGT applies the GT predicate on the "sourceLongitude" field. func SourceLongitudeGT(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceLongitude, v)) } // SourceLongitudeGTE applies the GTE predicate on the "sourceLongitude" field. func SourceLongitudeGTE(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceLongitude, v)) } // SourceLongitudeLT applies the LT predicate on the "sourceLongitude" field.
func SourceLongitudeLT(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceLongitude, v)) } // SourceLongitudeLTE applies the LTE predicate on the "sourceLongitude" field. func SourceLongitudeLTE(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceLongitude, v)) } // SourceLongitudeIsNil applies the IsNil predicate on the "sourceLongitude" field. func SourceLongitudeIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceLongitude))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceLongitude)) } // SourceLongitudeNotNil applies the NotNil predicate on the "sourceLongitude" field. func SourceLongitudeNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceLongitude))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceLongitude)) } // SourceScopeEQ applies the EQ predicate on the "sourceScope" field. func SourceScopeEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceScope, v)) } // SourceScopeNEQ applies the NEQ predicate on the "sourceScope" field. func SourceScopeNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceScope, v)) } // SourceScopeIn applies the In predicate on the "sourceScope" field. func SourceScopeIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceScope), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceScope, vs...)) } // SourceScopeNotIn applies the NotIn predicate on the "sourceScope" field. func SourceScopeNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceScope), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceScope, vs...)) } // SourceScopeGT applies the GT predicate on the "sourceScope" field. func SourceScopeGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceScope, v)) } // SourceScopeGTE applies the GTE predicate on the "sourceScope" field. func SourceScopeGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceScope, v)) } // SourceScopeLT applies the LT predicate on the "sourceScope" field. func SourceScopeLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceScope, v)) } // SourceScopeLTE applies the LTE predicate on the "sourceScope" field. 
func SourceScopeLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceScope, v)) } // SourceScopeContains applies the Contains predicate on the "sourceScope" field. func SourceScopeContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceScope, v)) } // SourceScopeHasPrefix applies the HasPrefix predicate on the "sourceScope" field. func SourceScopeHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceScope, v)) } // SourceScopeHasSuffix applies the HasSuffix predicate on the "sourceScope" field. func SourceScopeHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceScope, v)) } // SourceScopeIsNil applies the IsNil predicate on the "sourceScope" field. func SourceScopeIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceScope))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceScope)) } // SourceScopeNotNil applies the NotNil predicate on the "sourceScope" field. func SourceScopeNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceScope))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceScope)) } // SourceScopeEqualFold applies the EqualFold predicate on the "sourceScope" field. func SourceScopeEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceScope, v)) } // SourceScopeContainsFold applies the ContainsFold predicate on the "sourceScope" field. func SourceScopeContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceScope, v)) } // SourceValueEQ applies the EQ predicate on the "sourceValue" field. func SourceValueEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceValue, v)) } // SourceValueNEQ applies the NEQ predicate on the "sourceValue" field. func SourceValueNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceValue, v)) } // SourceValueIn applies the In predicate on the "sourceValue" field. func SourceValueIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceValue), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceValue, vs...)) } // SourceValueNotIn applies the NotIn predicate on the "sourceValue" field. 
func SourceValueNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceValue), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceValue, vs...)) } // SourceValueGT applies the GT predicate on the "sourceValue" field. func SourceValueGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceValue, v)) } // SourceValueGTE applies the GTE predicate on the "sourceValue" field. func SourceValueGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceValue, v)) } // SourceValueLT applies the LT predicate on the "sourceValue" field. func SourceValueLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceValue, v)) } // SourceValueLTE applies the LTE predicate on the "sourceValue" field. func SourceValueLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceValue, v)) } // SourceValueContains applies the Contains predicate on the "sourceValue" field. func SourceValueContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceValue, v)) } // SourceValueHasPrefix applies the HasPrefix predicate on the "sourceValue" field. func SourceValueHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceValue, v)) } // SourceValueHasSuffix applies the HasSuffix predicate on the "sourceValue" field. func SourceValueHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceValue, v)) } // SourceValueIsNil applies the IsNil predicate on the "sourceValue" field. func SourceValueIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceValue))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceValue)) } // SourceValueNotNil applies the NotNil predicate on the "sourceValue" field. func SourceValueNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceValue))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceValue)) } // SourceValueEqualFold applies the EqualFold predicate on the "sourceValue" field. func SourceValueEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceValue, v)) } // SourceValueContainsFold applies the ContainsFold predicate on the "sourceValue" field. 
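// Not part of the diff, just a usage sketch: the IsNil/NotNil predicates compile to
// IS NULL / IS NOT NULL checks, exactly like the sql.IsNull/sql.NotNull calls they
// replace. Typical use is filtering alerts whose optional enrichment data was never
// set (client and ctx are assumed names for an *ent.Client and a context.Context):
//
//	missing, err := client.Alert.Query().
//		Where(alert.SourceValueIsNil()).
//		All(ctx)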
func SourceValueContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceValue, v)) } // CapacityEQ applies the EQ predicate on the "capacity" field. func CapacityEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCapacity, v)) } // CapacityNEQ applies the NEQ predicate on the "capacity" field. func CapacityNEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldCapacity, v)) } // CapacityIn applies the In predicate on the "capacity" field. func CapacityIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCapacity), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldCapacity, vs...)) } // CapacityNotIn applies the NotIn predicate on the "capacity" field. func CapacityNotIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCapacity), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldCapacity, vs...)) } // CapacityGT applies the GT predicate on the "capacity" field. func CapacityGT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldGT(FieldCapacity, v)) } // CapacityGTE applies the GTE predicate on the "capacity" field. func CapacityGTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldCapacity, v)) } // CapacityLT applies the LT predicate on the "capacity" field. func CapacityLT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldLT(FieldCapacity, v)) } // CapacityLTE applies the LTE predicate on the "capacity" field. func CapacityLTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldCapacity, v)) } // CapacityIsNil applies the IsNil predicate on the "capacity" field. func CapacityIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCapacity))) - }) + return predicate.Alert(sql.FieldIsNull(FieldCapacity)) } // CapacityNotNil applies the NotNil predicate on the "capacity" field. func CapacityNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCapacity))) - }) + return predicate.Alert(sql.FieldNotNull(FieldCapacity)) } // LeakSpeedEQ applies the EQ predicate on the "leakSpeed" field. func LeakSpeedEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldLeakSpeed, v)) } // LeakSpeedNEQ applies the NEQ predicate on the "leakSpeed" field. 
func LeakSpeedNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldLeakSpeed, v)) } // LeakSpeedIn applies the In predicate on the "leakSpeed" field. func LeakSpeedIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldLeakSpeed), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldLeakSpeed, vs...)) } // LeakSpeedNotIn applies the NotIn predicate on the "leakSpeed" field. func LeakSpeedNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldLeakSpeed), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldLeakSpeed, vs...)) } // LeakSpeedGT applies the GT predicate on the "leakSpeed" field. func LeakSpeedGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldGT(FieldLeakSpeed, v)) } // LeakSpeedGTE applies the GTE predicate on the "leakSpeed" field. func LeakSpeedGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldLeakSpeed, v)) } // LeakSpeedLT applies the LT predicate on the "leakSpeed" field. func LeakSpeedLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldLT(FieldLeakSpeed, v)) } // LeakSpeedLTE applies the LTE predicate on the "leakSpeed" field. func LeakSpeedLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldLeakSpeed, v)) } // LeakSpeedContains applies the Contains predicate on the "leakSpeed" field. func LeakSpeedContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldContains(FieldLeakSpeed, v)) } // LeakSpeedHasPrefix applies the HasPrefix predicate on the "leakSpeed" field. func LeakSpeedHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldLeakSpeed, v)) } // LeakSpeedHasSuffix applies the HasSuffix predicate on the "leakSpeed" field. func LeakSpeedHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldLeakSpeed, v)) } // LeakSpeedIsNil applies the IsNil predicate on the "leakSpeed" field. func LeakSpeedIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldLeakSpeed))) - }) + return predicate.Alert(sql.FieldIsNull(FieldLeakSpeed)) } // LeakSpeedNotNil applies the NotNil predicate on the "leakSpeed" field. func LeakSpeedNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldLeakSpeed))) - }) + return predicate.Alert(sql.FieldNotNull(FieldLeakSpeed)) } // LeakSpeedEqualFold applies the EqualFold predicate on the "leakSpeed" field. 
func LeakSpeedEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldLeakSpeed, v)) } // LeakSpeedContainsFold applies the ContainsFold predicate on the "leakSpeed" field. func LeakSpeedContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldLeakSpeed, v)) } // ScenarioVersionEQ applies the EQ predicate on the "scenarioVersion" field. func ScenarioVersionEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioVersion, v)) } // ScenarioVersionNEQ applies the NEQ predicate on the "scenarioVersion" field. func ScenarioVersionNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldScenarioVersion, v)) } // ScenarioVersionIn applies the In predicate on the "scenarioVersion" field. func ScenarioVersionIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenarioVersion), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldScenarioVersion, vs...)) } // ScenarioVersionNotIn applies the NotIn predicate on the "scenarioVersion" field. func ScenarioVersionNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenarioVersion), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldScenarioVersion, vs...)) } // ScenarioVersionGT applies the GT predicate on the "scenarioVersion" field. func ScenarioVersionGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldGT(FieldScenarioVersion, v)) } // ScenarioVersionGTE applies the GTE predicate on the "scenarioVersion" field. func ScenarioVersionGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldScenarioVersion, v)) } // ScenarioVersionLT applies the LT predicate on the "scenarioVersion" field. func ScenarioVersionLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldLT(FieldScenarioVersion, v)) } // ScenarioVersionLTE applies the LTE predicate on the "scenarioVersion" field. func ScenarioVersionLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldScenarioVersion, v)) } // ScenarioVersionContains applies the Contains predicate on the "scenarioVersion" field. func ScenarioVersionContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldContains(FieldScenarioVersion, v)) } // ScenarioVersionHasPrefix applies the HasPrefix predicate on the "scenarioVersion" field. 
func ScenarioVersionHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldScenarioVersion, v)) } // ScenarioVersionHasSuffix applies the HasSuffix predicate on the "scenarioVersion" field. func ScenarioVersionHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldScenarioVersion, v)) } // ScenarioVersionIsNil applies the IsNil predicate on the "scenarioVersion" field. func ScenarioVersionIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldScenarioVersion))) - }) + return predicate.Alert(sql.FieldIsNull(FieldScenarioVersion)) } // ScenarioVersionNotNil applies the NotNil predicate on the "scenarioVersion" field. func ScenarioVersionNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldScenarioVersion))) - }) + return predicate.Alert(sql.FieldNotNull(FieldScenarioVersion)) } // ScenarioVersionEqualFold applies the EqualFold predicate on the "scenarioVersion" field. func ScenarioVersionEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldScenarioVersion, v)) } // ScenarioVersionContainsFold applies the ContainsFold predicate on the "scenarioVersion" field. func ScenarioVersionContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldScenarioVersion, v)) } // ScenarioHashEQ applies the EQ predicate on the "scenarioHash" field. func ScenarioHashEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioHash, v)) } // ScenarioHashNEQ applies the NEQ predicate on the "scenarioHash" field. func ScenarioHashNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldScenarioHash, v)) } // ScenarioHashIn applies the In predicate on the "scenarioHash" field. func ScenarioHashIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenarioHash), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldScenarioHash, vs...)) } // ScenarioHashNotIn applies the NotIn predicate on the "scenarioHash" field. func ScenarioHashNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenarioHash), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldScenarioHash, vs...)) } // ScenarioHashGT applies the GT predicate on the "scenarioHash" field. func ScenarioHashGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldGT(FieldScenarioHash, v)) } // ScenarioHashGTE applies the GTE predicate on the "scenarioHash" field. 
func ScenarioHashGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldScenarioHash, v)) } // ScenarioHashLT applies the LT predicate on the "scenarioHash" field. func ScenarioHashLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldLT(FieldScenarioHash, v)) } // ScenarioHashLTE applies the LTE predicate on the "scenarioHash" field. func ScenarioHashLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldScenarioHash, v)) } // ScenarioHashContains applies the Contains predicate on the "scenarioHash" field. func ScenarioHashContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldContains(FieldScenarioHash, v)) } // ScenarioHashHasPrefix applies the HasPrefix predicate on the "scenarioHash" field. func ScenarioHashHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldScenarioHash, v)) } // ScenarioHashHasSuffix applies the HasSuffix predicate on the "scenarioHash" field. func ScenarioHashHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldScenarioHash, v)) } // ScenarioHashIsNil applies the IsNil predicate on the "scenarioHash" field. func ScenarioHashIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldScenarioHash))) - }) + return predicate.Alert(sql.FieldIsNull(FieldScenarioHash)) } // ScenarioHashNotNil applies the NotNil predicate on the "scenarioHash" field. func ScenarioHashNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldScenarioHash))) - }) + return predicate.Alert(sql.FieldNotNull(FieldScenarioHash)) } // ScenarioHashEqualFold applies the EqualFold predicate on the "scenarioHash" field. func ScenarioHashEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldScenarioHash, v)) } // ScenarioHashContainsFold applies the ContainsFold predicate on the "scenarioHash" field. func ScenarioHashContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldScenarioHash, v)) } // SimulatedEQ applies the EQ predicate on the "simulated" field. func SimulatedEQ(v bool) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSimulated, v)) } // SimulatedNEQ applies the NEQ predicate on the "simulated" field. func SimulatedNEQ(v bool) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSimulated), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSimulated, v)) } // UUIDEQ applies the EQ predicate on the "uuid" field. 
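// Usage sketch (not in the diff): boolean fields such as "simulated" only get EQ/NEQ
// predicates; no ordering or string matching is generated for them. Optional booleans
// additionally get IsNil/NotNil, as the new "remediation" field does further down.
// Excluding simulated alerts is therefore simply (client/ctx assumed):
//
//	notSimulated, err := client.Alert.Query().
//		Where(alert.SimulatedEQ(false)).
//		All(ctx)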
func UUIDEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUUID, v)) } // UUIDNEQ applies the NEQ predicate on the "uuid" field. func UUIDNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldUUID, v)) } // UUIDIn applies the In predicate on the "uuid" field. func UUIDIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUUID), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldUUID, vs...)) } // UUIDNotIn applies the NotIn predicate on the "uuid" field. func UUIDNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUUID), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldUUID, vs...)) } // UUIDGT applies the GT predicate on the "uuid" field. func UUIDGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldGT(FieldUUID, v)) } // UUIDGTE applies the GTE predicate on the "uuid" field. func UUIDGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldUUID, v)) } // UUIDLT applies the LT predicate on the "uuid" field. func UUIDLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldLT(FieldUUID, v)) } // UUIDLTE applies the LTE predicate on the "uuid" field. func UUIDLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldUUID, v)) } // UUIDContains applies the Contains predicate on the "uuid" field. func UUIDContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldContains(FieldUUID, v)) } // UUIDHasPrefix applies the HasPrefix predicate on the "uuid" field. func UUIDHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldUUID, v)) } // UUIDHasSuffix applies the HasSuffix predicate on the "uuid" field. func UUIDHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldUUID, v)) } // UUIDIsNil applies the IsNil predicate on the "uuid" field. func UUIDIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUUID))) - }) + return predicate.Alert(sql.FieldIsNull(FieldUUID)) } // UUIDNotNil applies the NotNil predicate on the "uuid" field. func UUIDNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUUID))) - }) + return predicate.Alert(sql.FieldNotNull(FieldUUID)) } // UUIDEqualFold applies the EqualFold predicate on the "uuid" field. 
func UUIDEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldUUID, v)) } // UUIDContainsFold applies the ContainsFold predicate on the "uuid" field. func UUIDContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldUUID, v)) +} + +// RemediationEQ applies the EQ predicate on the "remediation" field. +func RemediationEQ(v bool) predicate.Alert { + return predicate.Alert(sql.FieldEQ(FieldRemediation, v)) +} + +// RemediationNEQ applies the NEQ predicate on the "remediation" field. +func RemediationNEQ(v bool) predicate.Alert { + return predicate.Alert(sql.FieldNEQ(FieldRemediation, v)) +} + +// RemediationIsNil applies the IsNil predicate on the "remediation" field. +func RemediationIsNil() predicate.Alert { + return predicate.Alert(sql.FieldIsNull(FieldRemediation)) +} + +// RemediationNotNil applies the NotNil predicate on the "remediation" field. +func RemediationNotNil() predicate.Alert { + return predicate.Alert(sql.FieldNotNull(FieldRemediation)) } // HasOwner applies the HasEdge predicate on the "owner" edge. @@ -2453,7 +1630,6 @@ func HasOwner() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2463,11 +1639,7 @@ func HasOwner() predicate.Alert { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). func HasOwnerWith(preds ...predicate.Machine) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2481,7 +1653,6 @@ func HasDecisions() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(DecisionsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, DecisionsTable, DecisionsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2491,11 +1662,7 @@ func HasDecisions() predicate.Alert { // HasDecisionsWith applies the HasEdge predicate on the "decisions" edge with a given conditions (other predicates). func HasDecisionsWith(preds ...predicate.Decision) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(DecisionsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, DecisionsTable, DecisionsColumn), - ) + step := newDecisionsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2509,7 +1676,6 @@ func HasEvents() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(EventsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, EventsTable, EventsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2519,11 +1685,7 @@ func HasEvents() predicate.Alert { // HasEventsWith applies the HasEdge predicate on the "events" edge with a given conditions (other predicates). 
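// Context sketch, not in the diff: "remediation" is the one genuinely new field in
// this file; as an optional bool it gets exactly EQ/NEQ/IsNil/NotNil. The edge
// predicates that follow are refactored rather than removed: HasOwner and friends
// drop the sqlgraph.To() hop that HasNeighbors does not need, and the *With variants
// reuse generated newOwnerStep()/newDecisionsStep()/... constructors (defined
// elsewhere in this generated file) instead of rebuilding the step inline.
// Downstream usage is unchanged, e.g. (client assumed):
//
//	// remediated alerts that still have decisions attached
//	q := client.Alert.Query().
//		Where(alert.RemediationEQ(true), alert.HasDecisions())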
func HasEventsWith(preds ...predicate.Event) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(EventsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, EventsTable, EventsColumn), - ) + step := newEventsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2537,7 +1699,6 @@ func HasMetas() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(MetasTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, MetasTable, MetasColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2547,11 +1708,7 @@ func HasMetas() predicate.Alert { // HasMetasWith applies the HasEdge predicate on the "metas" edge with a given conditions (other predicates). func HasMetasWith(preds ...predicate.Meta) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(MetasInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, MetasTable, MetasColumn), - ) + step := newMetasStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2562,32 +1719,15 @@ func HasMetasWith(preds ...predicate.Meta) predicate.Alert { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Alert) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Alert(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Alert) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Alert(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Alert) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Alert(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/alert_create.go b/pkg/database/ent/alert_create.go index 42da5b137ba..753183a9eb9 100644 --- a/pkg/database/ent/alert_create.go +++ b/pkg/database/ent/alert_create.go @@ -338,6 +338,20 @@ func (ac *AlertCreate) SetNillableUUID(s *string) *AlertCreate { return ac } +// SetRemediation sets the "remediation" field. +func (ac *AlertCreate) SetRemediation(b bool) *AlertCreate { + ac.mutation.SetRemediation(b) + return ac +} + +// SetNillableRemediation sets the "remediation" field if the given value is not nil. +func (ac *AlertCreate) SetNillableRemediation(b *bool) *AlertCreate { + if b != nil { + ac.SetRemediation(*b) + } + return ac +} + // SetOwnerID sets the "owner" edge to the Machine entity by ID. func (ac *AlertCreate) SetOwnerID(id int) *AlertCreate { ac.mutation.SetOwnerID(id) @@ -409,50 +423,8 @@ func (ac *AlertCreate) Mutation() *AlertMutation { // Save creates the Alert in the database. 
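// Not in the diff; a note on the pattern: the Save rewrite below is the same
// mechanical change applied to every builder in this PR. The per-builder Mutator
// boilerplate moves into a shared withHooks helper with identical semantics, as the
// removed loop shows (hooks still wrap sqlSave so that the first registered hook
// runs outermost). The new generated body, in full:
//
//	func (ac *AlertCreate) Save(ctx context.Context) (*Alert, error) {
//		ac.defaults()
//		return withHooks(ctx, ac.sqlSave, ac.mutation, ac.hooks)
//	}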
func (ac *AlertCreate) Save(ctx context.Context) (*Alert, error) { - var ( - err error - node *Alert - ) ac.defaults() - if len(ac.hooks) == 0 { - if err = ac.check(); err != nil { - return nil, err - } - node, err = ac.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = ac.check(); err != nil { - return nil, err - } - ac.mutation = mutation - if node, err = ac.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(ac.hooks) - 1; i >= 0; i-- { - if ac.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ac.hooks[i](mut) - } - v, err := mut.Mutate(ctx, ac.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Alert) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from AlertMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, ac.sqlSave, ac.mutation, ac.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -515,6 +487,12 @@ func (ac *AlertCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (ac *AlertCreate) check() error { + if _, ok := ac.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Alert.created_at"`)} + } + if _, ok := ac.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Alert.updated_at"`)} + } if _, ok := ac.mutation.Scenario(); !ok { return &ValidationError{Name: "scenario", err: errors.New(`ent: missing required field "Alert.scenario"`)} } @@ -525,6 +503,9 @@ func (ac *AlertCreate) check() error { } func (ac *AlertCreate) sqlSave(ctx context.Context) (*Alert, error) { + if err := ac.check(); err != nil { + return nil, err + } _node, _spec := ac.createSpec() if err := sqlgraph.CreateNode(ctx, ac.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -534,204 +515,112 @@ func (ac *AlertCreate) sqlSave(ctx context.Context) (*Alert, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + ac.mutation.id = &_node.ID + ac.mutation.done = true return _node, nil } func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { var ( _node = &Alert{config: ac.config} - _spec = &sqlgraph.CreateSpec{ - Table: alert.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(alert.Table, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) ) if value, ok := ac.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(alert.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := ac.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(alert.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := ac.mutation.Scenario(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenario, - }) + 
_spec.SetField(alert.FieldScenario, field.TypeString, value) _node.Scenario = value } if value, ok := ac.mutation.BucketId(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldBucketId, - }) + _spec.SetField(alert.FieldBucketId, field.TypeString, value) _node.BucketId = value } if value, ok := ac.mutation.Message(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldMessage, - }) + _spec.SetField(alert.FieldMessage, field.TypeString, value) _node.Message = value } if value, ok := ac.mutation.EventsCount(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) + _spec.SetField(alert.FieldEventsCount, field.TypeInt32, value) _node.EventsCount = value } if value, ok := ac.mutation.StartedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStartedAt, - }) + _spec.SetField(alert.FieldStartedAt, field.TypeTime, value) _node.StartedAt = value } if value, ok := ac.mutation.StoppedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStoppedAt, - }) + _spec.SetField(alert.FieldStoppedAt, field.TypeTime, value) _node.StoppedAt = value } if value, ok := ac.mutation.SourceIp(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceIp, - }) + _spec.SetField(alert.FieldSourceIp, field.TypeString, value) _node.SourceIp = value } if value, ok := ac.mutation.SourceRange(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceRange, - }) + _spec.SetField(alert.FieldSourceRange, field.TypeString, value) _node.SourceRange = value } if value, ok := ac.mutation.SourceAsNumber(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsNumber, - }) + _spec.SetField(alert.FieldSourceAsNumber, field.TypeString, value) _node.SourceAsNumber = value } if value, ok := ac.mutation.SourceAsName(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsName, - }) + _spec.SetField(alert.FieldSourceAsName, field.TypeString, value) _node.SourceAsName = value } if value, ok := ac.mutation.SourceCountry(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceCountry, - }) + _spec.SetField(alert.FieldSourceCountry, field.TypeString, value) _node.SourceCountry = value } if value, ok := ac.mutation.SourceLatitude(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) + _spec.SetField(alert.FieldSourceLatitude, field.TypeFloat32, value) _node.SourceLatitude = value } if value, ok := ac.mutation.SourceLongitude(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) + _spec.SetField(alert.FieldSourceLongitude, field.TypeFloat32, value) _node.SourceLongitude = value } if value, ok := ac.mutation.SourceScope(); ok { - _spec.Fields = append(_spec.Fields, 
&sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceScope, - }) + _spec.SetField(alert.FieldSourceScope, field.TypeString, value) _node.SourceScope = value } if value, ok := ac.mutation.SourceValue(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceValue, - }) + _spec.SetField(alert.FieldSourceValue, field.TypeString, value) _node.SourceValue = value } if value, ok := ac.mutation.Capacity(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) + _spec.SetField(alert.FieldCapacity, field.TypeInt32, value) _node.Capacity = value } if value, ok := ac.mutation.LeakSpeed(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldLeakSpeed, - }) + _spec.SetField(alert.FieldLeakSpeed, field.TypeString, value) _node.LeakSpeed = value } if value, ok := ac.mutation.ScenarioVersion(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioVersion, - }) + _spec.SetField(alert.FieldScenarioVersion, field.TypeString, value) _node.ScenarioVersion = value } if value, ok := ac.mutation.ScenarioHash(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioHash, - }) + _spec.SetField(alert.FieldScenarioHash, field.TypeString, value) _node.ScenarioHash = value } if value, ok := ac.mutation.Simulated(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: alert.FieldSimulated, - }) + _spec.SetField(alert.FieldSimulated, field.TypeBool, value) _node.Simulated = value } if value, ok := ac.mutation.UUID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldUUID, - }) + _spec.SetField(alert.FieldUUID, field.TypeString, value) _node.UUID = value } + if value, ok := ac.mutation.Remediation(); ok { + _spec.SetField(alert.FieldRemediation, field.TypeBool, value) + _node.Remediation = value + } if nodes := ac.mutation.OwnerIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -740,10 +629,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -760,10 +646,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -779,10 +662,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -798,10 +678,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: 
[]string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -815,11 +692,15 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { // AlertCreateBulk is the builder for creating many Alert entities in bulk. type AlertCreateBulk struct { config + err error builders []*AlertCreate } // Save creates the Alert entities in the database. func (acb *AlertCreateBulk) Save(ctx context.Context) ([]*Alert, error) { + if acb.err != nil { + return nil, acb.err + } specs := make([]*sqlgraph.CreateSpec, len(acb.builders)) nodes := make([]*Alert, len(acb.builders)) mutators := make([]Mutator, len(acb.builders)) @@ -836,8 +717,8 @@ func (acb *AlertCreateBulk) Save(ctx context.Context) ([]*Alert, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, acb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/alert_delete.go b/pkg/database/ent/alert_delete.go index 014bcc2e0c6..15b3a4c822a 100644 --- a/pkg/database/ent/alert_delete.go +++ b/pkg/database/ent/alert_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (ad *AlertDelete) Where(ps ...predicate.Alert) *AlertDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (ad *AlertDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(ad.hooks) == 0 { - affected, err = ad.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ad.mutation = mutation - affected, err = ad.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(ad.hooks) - 1; i >= 0; i-- { - if ad.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ad.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ad.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, ad.sqlExec, ad.mutation, ad.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (ad *AlertDelete) ExecX(ctx context.Context) int { } func (ad *AlertDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(alert.Table, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) if ps := ad.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (ad *AlertDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + ad.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type AlertDeleteOne struct { ad *AlertDelete } +// Where appends a list predicates to the AlertDelete builder. +func (ado *AlertDeleteOne) Where(ps ...predicate.Alert) *AlertDeleteOne { + ado.ad.mutation.Where(ps...) 
+ return ado +} + // Exec executes the deletion query. func (ado *AlertDeleteOne) Exec(ctx context.Context) error { n, err := ado.ad.Exec(ctx) @@ -111,5 +82,7 @@ func (ado *AlertDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (ado *AlertDeleteOne) ExecX(ctx context.Context) { - ado.ad.ExecX(ctx) + if err := ado.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/alert_query.go b/pkg/database/ent/alert_query.go index 68789196d24..7eddb6ce024 100644 --- a/pkg/database/ent/alert_query.go +++ b/pkg/database/ent/alert_query.go @@ -22,11 +22,9 @@ import ( // AlertQuery is the builder for querying Alert entities. type AlertQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []alert.OrderOption + inters []Interceptor predicates []predicate.Alert withOwner *MachineQuery withDecisions *DecisionQuery @@ -44,34 +42,34 @@ func (aq *AlertQuery) Where(ps ...predicate.Alert) *AlertQuery { return aq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (aq *AlertQuery) Limit(limit int) *AlertQuery { - aq.limit = &limit + aq.ctx.Limit = &limit return aq } -// Offset adds an offset step to the query. +// Offset to start from. func (aq *AlertQuery) Offset(offset int) *AlertQuery { - aq.offset = &offset + aq.ctx.Offset = &offset return aq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (aq *AlertQuery) Unique(unique bool) *AlertQuery { - aq.unique = &unique + aq.ctx.Unique = &unique return aq } -// Order adds an order step to the query. -func (aq *AlertQuery) Order(o ...OrderFunc) *AlertQuery { +// Order specifies how the records should be ordered. +func (aq *AlertQuery) Order(o ...alert.OrderOption) *AlertQuery { aq.order = append(aq.order, o...) return aq } // QueryOwner chains the current query on the "owner" edge. func (aq *AlertQuery) QueryOwner() *MachineQuery { - query := &MachineQuery{config: aq.config} + query := (&MachineClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -93,7 +91,7 @@ func (aq *AlertQuery) QueryOwner() *MachineQuery { // QueryDecisions chains the current query on the "decisions" edge. func (aq *AlertQuery) QueryDecisions() *DecisionQuery { - query := &DecisionQuery{config: aq.config} + query := (&DecisionClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -115,7 +113,7 @@ func (aq *AlertQuery) QueryDecisions() *DecisionQuery { // QueryEvents chains the current query on the "events" edge. func (aq *AlertQuery) QueryEvents() *EventQuery { - query := &EventQuery{config: aq.config} + query := (&EventClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -137,7 +135,7 @@ func (aq *AlertQuery) QueryEvents() *EventQuery { // QueryMetas chains the current query on the "metas" edge. 
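// Orientation sketch, not in the diff: AlertQuery's loose limit/offset/unique/fields
// members are consolidated into a single *QueryContext, each terminal method
// (First/Only/All/Count/...) tags the context with its operation name via
// setContextOp, results flow through the new Interceptor chain, and Order now takes
// the typed alert.OrderOption instead of OrderFunc. None of this changes the
// call-side API (client/ctx assumed):
//
//	// ten most recent alerts
//	latest, err := client.Alert.Query().
//		Order(ent.Desc(alert.FieldCreatedAt)).
//		Limit(10).
//		All(ctx)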
func (aq *AlertQuery) QueryMetas() *MetaQuery { - query := &MetaQuery{config: aq.config} + query := (&MetaClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -160,7 +158,7 @@ func (aq *AlertQuery) QueryMetas() *MetaQuery { // First returns the first Alert entity from the query. // Returns a *NotFoundError when no Alert was found. func (aq *AlertQuery) First(ctx context.Context) (*Alert, error) { - nodes, err := aq.Limit(1).All(ctx) + nodes, err := aq.Limit(1).All(setContextOp(ctx, aq.ctx, "First")) if err != nil { return nil, err } @@ -183,7 +181,7 @@ func (aq *AlertQuery) FirstX(ctx context.Context) *Alert { // Returns a *NotFoundError when no Alert ID was found. func (aq *AlertQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = aq.Limit(1).IDs(ctx); err != nil { + if ids, err = aq.Limit(1).IDs(setContextOp(ctx, aq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -206,7 +204,7 @@ func (aq *AlertQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Alert entity is found. // Returns a *NotFoundError when no Alert entities are found. func (aq *AlertQuery) Only(ctx context.Context) (*Alert, error) { - nodes, err := aq.Limit(2).All(ctx) + nodes, err := aq.Limit(2).All(setContextOp(ctx, aq.ctx, "Only")) if err != nil { return nil, err } @@ -234,7 +232,7 @@ func (aq *AlertQuery) OnlyX(ctx context.Context) *Alert { // Returns a *NotFoundError when no entities are found. func (aq *AlertQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = aq.Limit(2).IDs(ctx); err != nil { + if ids, err = aq.Limit(2).IDs(setContextOp(ctx, aq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -259,10 +257,12 @@ func (aq *AlertQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Alerts. func (aq *AlertQuery) All(ctx context.Context) ([]*Alert, error) { + ctx = setContextOp(ctx, aq.ctx, "All") if err := aq.prepareQuery(ctx); err != nil { return nil, err } - return aq.sqlAll(ctx) + qr := querierAll[[]*Alert, *AlertQuery]() + return withInterceptors[[]*Alert](ctx, aq, qr, aq.inters) } // AllX is like All, but panics if an error occurs. @@ -275,9 +275,12 @@ func (aq *AlertQuery) AllX(ctx context.Context) []*Alert { } // IDs executes the query and returns a list of Alert IDs. -func (aq *AlertQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := aq.Select(alert.FieldID).Scan(ctx, &ids); err != nil { +func (aq *AlertQuery) IDs(ctx context.Context) (ids []int, err error) { + if aq.ctx.Unique == nil && aq.path != nil { + aq.Unique(true) + } + ctx = setContextOp(ctx, aq.ctx, "IDs") + if err = aq.Select(alert.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -294,10 +297,11 @@ func (aq *AlertQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (aq *AlertQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, aq.ctx, "Count") if err := aq.prepareQuery(ctx); err != nil { return 0, err } - return aq.sqlCount(ctx) + return withInterceptors[int](ctx, aq, querierCount[*AlertQuery](), aq.inters) } // CountX is like Count, but panics if an error occurs. @@ -311,10 +315,15 @@ func (aq *AlertQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. 
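First, Only, All, IDs, and Count now stamp the operation name onto a QueryContext that rides the context.Context, which is what lets interceptors tell the terminal operations apart. A rough sketch of the mechanism, under the assumption that the context is carried via context.WithValue (the real setContextOp and QueryContext are in the generated runtime):

package sketch

import "context"

// QueryContext mirrors the state the builder previously held in dedicated
// struct fields (Limit, Offset, Unique, Fields) plus the terminal operation
// name. Illustrative stand-in, not ent's actual definition.
type QueryContext struct {
	Op     string
	Limit  *int
	Offset *int
	Unique *bool
	Fields []string
}

type qcKey struct{}

// setContextOpSketch tags the query context with the running operation
// ("First", "Only", "Count", ...) and stores it on the context.
func setContextOpSketch(ctx context.Context, qc *QueryContext, op string) context.Context {
	qc.Op = op
	return context.WithValue(ctx, qcKey{}, qc)
}

// queryContextFrom lets an interceptor read it back out.
func queryContextFrom(ctx context.Context) *QueryContext {
	qc, _ := ctx.Value(qcKey{}).(*QueryContext)
	return qc
}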
func (aq *AlertQuery) Exist(ctx context.Context) (bool, error) { - if err := aq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, aq.ctx, "Exist") + switch _, err := aq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return aq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -334,25 +343,24 @@ func (aq *AlertQuery) Clone() *AlertQuery { } return &AlertQuery{ config: aq.config, - limit: aq.limit, - offset: aq.offset, - order: append([]OrderFunc{}, aq.order...), + ctx: aq.ctx.Clone(), + order: append([]alert.OrderOption{}, aq.order...), + inters: append([]Interceptor{}, aq.inters...), predicates: append([]predicate.Alert{}, aq.predicates...), withOwner: aq.withOwner.Clone(), withDecisions: aq.withDecisions.Clone(), withEvents: aq.withEvents.Clone(), withMetas: aq.withMetas.Clone(), // clone intermediate query. - sql: aq.sql.Clone(), - path: aq.path, - unique: aq.unique, + sql: aq.sql.Clone(), + path: aq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AlertQuery) WithOwner(opts ...func(*MachineQuery)) *AlertQuery { - query := &MachineQuery{config: aq.config} + query := (&MachineClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -363,7 +371,7 @@ func (aq *AlertQuery) WithOwner(opts ...func(*MachineQuery)) *AlertQuery { // WithDecisions tells the query-builder to eager-load the nodes that are connected to // the "decisions" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AlertQuery) WithDecisions(opts ...func(*DecisionQuery)) *AlertQuery { - query := &DecisionQuery{config: aq.config} + query := (&DecisionClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -374,7 +382,7 @@ func (aq *AlertQuery) WithDecisions(opts ...func(*DecisionQuery)) *AlertQuery { // WithEvents tells the query-builder to eager-load the nodes that are connected to // the "events" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AlertQuery) WithEvents(opts ...func(*EventQuery)) *AlertQuery { - query := &EventQuery{config: aq.config} + query := (&EventClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -385,7 +393,7 @@ func (aq *AlertQuery) WithEvents(opts ...func(*EventQuery)) *AlertQuery { // WithMetas tells the query-builder to eager-load the nodes that are connected to // the "metas" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AlertQuery) WithMetas(opts ...func(*MetaQuery)) *AlertQuery { - query := &MetaQuery{config: aq.config} + query := (&MetaClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -408,16 +416,11 @@ func (aq *AlertQuery) WithMetas(opts ...func(*MetaQuery)) *AlertQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (aq *AlertQuery) GroupBy(field string, fields ...string) *AlertGroupBy { - grbuild := &AlertGroupBy{config: aq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := aq.prepareQuery(ctx); err != nil { - return nil, err - } - return aq.sqlQuery(ctx), nil - } + aq.ctx.Fields = append([]string{field}, fields...) 
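Exist is now a LIMIT 1 probe through FirstID rather than a dedicated sqlExist round trip, with NotFound mapped to false; callers see no difference. A hypothetical caller (the helper name and the ip parameter are assumptions for illustration):

package sketch

import (
	"context"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
)

// hasAlertFor is a hypothetical helper: Exist resolves to the FirstID probe
// shown above, so a miss costs one LIMIT 1 query instead of a count.
func hasAlertFor(ctx context.Context, client *ent.Client, ip string) (bool, error) {
	exists, err := client.Alert.Query().
		Where(alert.SourceIpEQ(ip)).
		Exist(ctx)
	if err != nil {
		return false, err
	}
	if exists {
		log.Printf("alert already recorded for %s", ip)
	}
	return exists, nil
}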
+ grbuild := &AlertGroupBy{build: aq} + grbuild.flds = &aq.ctx.Fields grbuild.label = alert.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -434,15 +437,30 @@ func (aq *AlertQuery) GroupBy(field string, fields ...string) *AlertGroupBy { // Select(alert.FieldCreatedAt). // Scan(ctx, &v) func (aq *AlertQuery) Select(fields ...string) *AlertSelect { - aq.fields = append(aq.fields, fields...) - selbuild := &AlertSelect{AlertQuery: aq} - selbuild.label = alert.Label - selbuild.flds, selbuild.scan = &aq.fields, selbuild.Scan - return selbuild + aq.ctx.Fields = append(aq.ctx.Fields, fields...) + sbuild := &AlertSelect{AlertQuery: aq} + sbuild.label = alert.Label + sbuild.flds, sbuild.scan = &aq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AlertSelect configured with the given aggregations. +func (aq *AlertQuery) Aggregate(fns ...AggregateFunc) *AlertSelect { + return aq.Select().Aggregate(fns...) } func (aq *AlertQuery) prepareQuery(ctx context.Context) error { - for _, f := range aq.fields { + for _, inter := range aq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, aq); err != nil { + return err + } + } + } + for _, f := range aq.ctx.Fields { if !alert.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -536,6 +554,9 @@ func (aq *AlertQuery) loadOwner(ctx context.Context, query *MachineQuery, nodes } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(machine.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -562,8 +583,11 @@ func (aq *AlertQuery) loadDecisions(ctx context.Context, query *DecisionQuery, n init(nodes[i]) } } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(decision.FieldAlertDecisions) + } query.Where(predicate.Decision(func(s *sql.Selector) { - s.Where(sql.InValues(alert.DecisionsColumn, fks...)) + s.Where(sql.InValues(s.C(alert.DecisionsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -573,7 +597,7 @@ func (aq *AlertQuery) loadDecisions(ctx context.Context, query *DecisionQuery, n fk := n.AlertDecisions node, ok := nodeids[fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "alert_decisions" returned %v for node %v`, fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "alert_decisions" returned %v for node %v`, fk, n.ID) } assign(node, n) } @@ -589,8 +613,11 @@ func (aq *AlertQuery) loadEvents(ctx context.Context, query *EventQuery, nodes [ init(nodes[i]) } } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(event.FieldAlertEvents) + } query.Where(predicate.Event(func(s *sql.Selector) { - s.Where(sql.InValues(alert.EventsColumn, fks...)) + s.Where(sql.InValues(s.C(alert.EventsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -600,7 +627,7 @@ func (aq *AlertQuery) loadEvents(ctx context.Context, query *EventQuery, nodes [ fk := n.AlertEvents node, ok := nodeids[fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "alert_events" returned %v for node %v`, fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "alert_events" returned %v for node %v`, fk, n.ID) } assign(node, n) } @@ -616,8 +643,11 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []* init(nodes[i]) } } + if len(query.ctx.Fields) > 0 { + 
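prepareQuery now runs every registered interceptor and gives any Traverser a chance to rewrite the query before the SQL is built. What registering one could look like, treating TraverseFunc, Query, and client.Intercept as assumptions drawn from ent v0.12's generated runtime API:

package sketch

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// registerTraverser is a hypothetical setup step: a client-wide traverser
// that forces DISTINCT on Alert queries before prepareQuery builds the SQL.
func registerTraverser(client *ent.Client) {
	client.Intercept(ent.TraverseFunc(func(ctx context.Context, q ent.Query) error {
		if aq, ok := q.(*ent.AlertQuery); ok {
			aq.Unique(true)
		}
		return nil
	}))
}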
query.ctx.AppendFieldOnce(meta.FieldAlertMetas) + } query.Where(predicate.Meta(func(s *sql.Selector) { - s.Where(sql.InValues(alert.MetasColumn, fks...)) + s.Where(sql.InValues(s.C(alert.MetasColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -627,7 +657,7 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []* fk := n.AlertMetas node, ok := nodeids[fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "alert_metas" returned %v for node %v`, fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "alert_metas" returned %v for node %v`, fk, n.ID) } assign(node, n) } @@ -636,41 +666,22 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []* func (aq *AlertQuery) sqlCount(ctx context.Context) (int, error) { _spec := aq.querySpec() - _spec.Node.Columns = aq.fields - if len(aq.fields) > 0 { - _spec.Unique = aq.unique != nil && *aq.unique + _spec.Node.Columns = aq.ctx.Fields + if len(aq.ctx.Fields) > 0 { + _spec.Unique = aq.ctx.Unique != nil && *aq.ctx.Unique } return sqlgraph.CountNodes(ctx, aq.driver, _spec) } -func (aq *AlertQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := aq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - Columns: alert.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - From: aq.sql, - Unique: true, - } - if unique := aq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(alert.Table, alert.Columns, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) + _spec.From = aq.sql + if unique := aq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if aq.path != nil { + _spec.Unique = true } - if fields := aq.fields; len(fields) > 0 { + if fields := aq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, alert.FieldID) for i := range fields { @@ -686,10 +697,10 @@ func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := aq.order; len(ps) > 0 { @@ -705,7 +716,7 @@ func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec { func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(aq.driver.Dialect()) t1 := builder.Table(alert.Table) - columns := aq.fields + columns := aq.ctx.Fields if len(columns) == 0 { columns = alert.Columns } @@ -714,7 +725,7 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = aq.sql selector.Select(selector.Columns(columns...)...) } - if aq.unique != nil && *aq.unique { + if aq.ctx.Unique != nil && *aq.ctx.Unique { selector.Distinct() } for _, p := range aq.predicates { @@ -723,12 +734,12 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range aq.order { p(selector) } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. 
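The hand-rolled QuerySpec/NodeSpec/FieldSpec literals give way to the constructor helpers added to the sqlgraph package in entgo.io/ent v0.12. The diff's before/after amounts to this sketch, with literal table and column strings standing in for the generated alert constants:

package sketch

import (
	"entgo.io/ent/dialect/sql/sqlgraph"
	"entgo.io/ent/schema/field"
)

// querySpecs shows the same spec spelled both ways. Note the removed
// hard-coded Unique: true; the new code derives it from the query context,
// falling back to true only for traversal paths.
func querySpecs() (*sqlgraph.QuerySpec, *sqlgraph.QuerySpec) {
	old := &sqlgraph.QuerySpec{
		Node: &sqlgraph.NodeSpec{
			Table:   "alerts",
			Columns: []string{"id", "created_at"},
			ID:      &sqlgraph.FieldSpec{Type: field.TypeInt, Column: "id"},
		},
		Unique: true,
	}
	neu := sqlgraph.NewQuerySpec("alerts", []string{"id", "created_at"},
		sqlgraph.NewFieldSpec("id", field.TypeInt))
	return old, neu
}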
selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -736,13 +747,8 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { // AlertGroupBy is the group-by builder for Alert entities. type AlertGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *AlertQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -751,74 +757,77 @@ func (agb *AlertGroupBy) Aggregate(fns ...AggregateFunc) *AlertGroupBy { return agb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (agb *AlertGroupBy) Scan(ctx context.Context, v any) error { - query, err := agb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, agb.build.ctx, "GroupBy") + if err := agb.build.prepareQuery(ctx); err != nil { return err } - agb.sql = query - return agb.sqlScan(ctx, v) + return scanWithInterceptors[*AlertQuery, *AlertGroupBy](ctx, agb.build, agb, agb.build.inters, v) } -func (agb *AlertGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range agb.fields { - if !alert.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (agb *AlertGroupBy) sqlScan(ctx context.Context, root *AlertQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(agb.fns)) + for _, fn := range agb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*agb.flds)+len(agb.fns)) + for _, f := range *agb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := agb.sqlQuery() + selector.GroupBy(selector.Columns(*agb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := agb.driver.Query(ctx, query, args, rows); err != nil { + if err := agb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (agb *AlertGroupBy) sqlQuery() *sql.Selector { - selector := agb.sql.Select() - aggregation := make([]string, 0, len(agb.fns)) - for _, fn := range agb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(agb.fields)+len(agb.fns)) - for _, f := range agb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(agb.fields...)...) -} - // AlertSelect is the builder for selecting fields of Alert entities. type AlertSelect struct { *AlertQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (as *AlertSelect) Aggregate(fns ...AggregateFunc) *AlertSelect { + as.fns = append(as.fns, fns...) 
+ return as } // Scan applies the selector query and scans the result into the given value. func (as *AlertSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, as.ctx, "Select") if err := as.prepareQuery(ctx); err != nil { return err } - as.sql = as.AlertQuery.sqlQuery(ctx) - return as.sqlScan(ctx, v) + return scanWithInterceptors[*AlertQuery, *AlertSelect](ctx, as.AlertQuery, as, as.inters, v) } -func (as *AlertSelect) sqlScan(ctx context.Context, v any) error { +func (as *AlertSelect) sqlScan(ctx context.Context, root *AlertQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(as.fns)) + for _, fn := range as.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*as.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := as.sql.Query() + query, args := selector.Query() if err := as.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/alert_update.go b/pkg/database/ent/alert_update.go index aaa12ef20a3..5f0e01ac09f 100644 --- a/pkg/database/ent/alert_update.go +++ b/pkg/database/ent/alert_update.go @@ -32,458 +32,12 @@ func (au *AlertUpdate) Where(ps ...predicate.Alert) *AlertUpdate { return au } -// SetCreatedAt sets the "created_at" field. -func (au *AlertUpdate) SetCreatedAt(t time.Time) *AlertUpdate { - au.mutation.SetCreatedAt(t) - return au -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (au *AlertUpdate) ClearCreatedAt() *AlertUpdate { - au.mutation.ClearCreatedAt() - return au -} - // SetUpdatedAt sets the "updated_at" field. func (au *AlertUpdate) SetUpdatedAt(t time.Time) *AlertUpdate { au.mutation.SetUpdatedAt(t) return au } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (au *AlertUpdate) ClearUpdatedAt() *AlertUpdate { - au.mutation.ClearUpdatedAt() - return au -} - -// SetScenario sets the "scenario" field. -func (au *AlertUpdate) SetScenario(s string) *AlertUpdate { - au.mutation.SetScenario(s) - return au -} - -// SetBucketId sets the "bucketId" field. -func (au *AlertUpdate) SetBucketId(s string) *AlertUpdate { - au.mutation.SetBucketId(s) - return au -} - -// SetNillableBucketId sets the "bucketId" field if the given value is not nil. -func (au *AlertUpdate) SetNillableBucketId(s *string) *AlertUpdate { - if s != nil { - au.SetBucketId(*s) - } - return au -} - -// ClearBucketId clears the value of the "bucketId" field. -func (au *AlertUpdate) ClearBucketId() *AlertUpdate { - au.mutation.ClearBucketId() - return au -} - -// SetMessage sets the "message" field. -func (au *AlertUpdate) SetMessage(s string) *AlertUpdate { - au.mutation.SetMessage(s) - return au -} - -// SetNillableMessage sets the "message" field if the given value is not nil. -func (au *AlertUpdate) SetNillableMessage(s *string) *AlertUpdate { - if s != nil { - au.SetMessage(*s) - } - return au -} - -// ClearMessage clears the value of the "message" field. -func (au *AlertUpdate) ClearMessage() *AlertUpdate { - au.mutation.ClearMessage() - return au -} - -// SetEventsCount sets the "eventsCount" field. -func (au *AlertUpdate) SetEventsCount(i int32) *AlertUpdate { - au.mutation.ResetEventsCount() - au.mutation.SetEventsCount(i) - return au -} - -// SetNillableEventsCount sets the "eventsCount" field if the given value is not nil. 
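GroupBy and Select are rebuilt on scanWithInterceptors and Select gains an Aggregate method, but the caller-facing API stays exactly what the doc comments above show. For instance, counting alerts per scenario (a hypothetical caller; the json tags are how ent's ScanSlice maps result columns onto struct fields):

package sketch

import (
	"context"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
)

// countByScenario is a hypothetical caller exercising the GroupBy +
// Aggregate path that AlertGroupBy.sqlScan now executes under interceptors.
func countByScenario(ctx context.Context, client *ent.Client) error {
	var rows []struct {
		Scenario string `json:"scenario"`
		Count    int    `json:"count"`
	}
	if err := client.Alert.Query().
		GroupBy(alert.FieldScenario).
		Aggregate(ent.Count()).
		Scan(ctx, &rows); err != nil {
		return err
	}
	for _, r := range rows {
		log.Printf("%s: %d alerts", r.Scenario, r.Count)
	}
	return nil
}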
-func (au *AlertUpdate) SetNillableEventsCount(i *int32) *AlertUpdate { - if i != nil { - au.SetEventsCount(*i) - } - return au -} - -// AddEventsCount adds i to the "eventsCount" field. -func (au *AlertUpdate) AddEventsCount(i int32) *AlertUpdate { - au.mutation.AddEventsCount(i) - return au -} - -// ClearEventsCount clears the value of the "eventsCount" field. -func (au *AlertUpdate) ClearEventsCount() *AlertUpdate { - au.mutation.ClearEventsCount() - return au -} - -// SetStartedAt sets the "startedAt" field. -func (au *AlertUpdate) SetStartedAt(t time.Time) *AlertUpdate { - au.mutation.SetStartedAt(t) - return au -} - -// SetNillableStartedAt sets the "startedAt" field if the given value is not nil. -func (au *AlertUpdate) SetNillableStartedAt(t *time.Time) *AlertUpdate { - if t != nil { - au.SetStartedAt(*t) - } - return au -} - -// ClearStartedAt clears the value of the "startedAt" field. -func (au *AlertUpdate) ClearStartedAt() *AlertUpdate { - au.mutation.ClearStartedAt() - return au -} - -// SetStoppedAt sets the "stoppedAt" field. -func (au *AlertUpdate) SetStoppedAt(t time.Time) *AlertUpdate { - au.mutation.SetStoppedAt(t) - return au -} - -// SetNillableStoppedAt sets the "stoppedAt" field if the given value is not nil. -func (au *AlertUpdate) SetNillableStoppedAt(t *time.Time) *AlertUpdate { - if t != nil { - au.SetStoppedAt(*t) - } - return au -} - -// ClearStoppedAt clears the value of the "stoppedAt" field. -func (au *AlertUpdate) ClearStoppedAt() *AlertUpdate { - au.mutation.ClearStoppedAt() - return au -} - -// SetSourceIp sets the "sourceIp" field. -func (au *AlertUpdate) SetSourceIp(s string) *AlertUpdate { - au.mutation.SetSourceIp(s) - return au -} - -// SetNillableSourceIp sets the "sourceIp" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceIp(s *string) *AlertUpdate { - if s != nil { - au.SetSourceIp(*s) - } - return au -} - -// ClearSourceIp clears the value of the "sourceIp" field. -func (au *AlertUpdate) ClearSourceIp() *AlertUpdate { - au.mutation.ClearSourceIp() - return au -} - -// SetSourceRange sets the "sourceRange" field. -func (au *AlertUpdate) SetSourceRange(s string) *AlertUpdate { - au.mutation.SetSourceRange(s) - return au -} - -// SetNillableSourceRange sets the "sourceRange" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceRange(s *string) *AlertUpdate { - if s != nil { - au.SetSourceRange(*s) - } - return au -} - -// ClearSourceRange clears the value of the "sourceRange" field. -func (au *AlertUpdate) ClearSourceRange() *AlertUpdate { - au.mutation.ClearSourceRange() - return au -} - -// SetSourceAsNumber sets the "sourceAsNumber" field. -func (au *AlertUpdate) SetSourceAsNumber(s string) *AlertUpdate { - au.mutation.SetSourceAsNumber(s) - return au -} - -// SetNillableSourceAsNumber sets the "sourceAsNumber" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceAsNumber(s *string) *AlertUpdate { - if s != nil { - au.SetSourceAsNumber(*s) - } - return au -} - -// ClearSourceAsNumber clears the value of the "sourceAsNumber" field. -func (au *AlertUpdate) ClearSourceAsNumber() *AlertUpdate { - au.mutation.ClearSourceAsNumber() - return au -} - -// SetSourceAsName sets the "sourceAsName" field. -func (au *AlertUpdate) SetSourceAsName(s string) *AlertUpdate { - au.mutation.SetSourceAsName(s) - return au -} - -// SetNillableSourceAsName sets the "sourceAsName" field if the given value is not nil. 
-func (au *AlertUpdate) SetNillableSourceAsName(s *string) *AlertUpdate { - if s != nil { - au.SetSourceAsName(*s) - } - return au -} - -// ClearSourceAsName clears the value of the "sourceAsName" field. -func (au *AlertUpdate) ClearSourceAsName() *AlertUpdate { - au.mutation.ClearSourceAsName() - return au -} - -// SetSourceCountry sets the "sourceCountry" field. -func (au *AlertUpdate) SetSourceCountry(s string) *AlertUpdate { - au.mutation.SetSourceCountry(s) - return au -} - -// SetNillableSourceCountry sets the "sourceCountry" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceCountry(s *string) *AlertUpdate { - if s != nil { - au.SetSourceCountry(*s) - } - return au -} - -// ClearSourceCountry clears the value of the "sourceCountry" field. -func (au *AlertUpdate) ClearSourceCountry() *AlertUpdate { - au.mutation.ClearSourceCountry() - return au -} - -// SetSourceLatitude sets the "sourceLatitude" field. -func (au *AlertUpdate) SetSourceLatitude(f float32) *AlertUpdate { - au.mutation.ResetSourceLatitude() - au.mutation.SetSourceLatitude(f) - return au -} - -// SetNillableSourceLatitude sets the "sourceLatitude" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceLatitude(f *float32) *AlertUpdate { - if f != nil { - au.SetSourceLatitude(*f) - } - return au -} - -// AddSourceLatitude adds f to the "sourceLatitude" field. -func (au *AlertUpdate) AddSourceLatitude(f float32) *AlertUpdate { - au.mutation.AddSourceLatitude(f) - return au -} - -// ClearSourceLatitude clears the value of the "sourceLatitude" field. -func (au *AlertUpdate) ClearSourceLatitude() *AlertUpdate { - au.mutation.ClearSourceLatitude() - return au -} - -// SetSourceLongitude sets the "sourceLongitude" field. -func (au *AlertUpdate) SetSourceLongitude(f float32) *AlertUpdate { - au.mutation.ResetSourceLongitude() - au.mutation.SetSourceLongitude(f) - return au -} - -// SetNillableSourceLongitude sets the "sourceLongitude" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceLongitude(f *float32) *AlertUpdate { - if f != nil { - au.SetSourceLongitude(*f) - } - return au -} - -// AddSourceLongitude adds f to the "sourceLongitude" field. -func (au *AlertUpdate) AddSourceLongitude(f float32) *AlertUpdate { - au.mutation.AddSourceLongitude(f) - return au -} - -// ClearSourceLongitude clears the value of the "sourceLongitude" field. -func (au *AlertUpdate) ClearSourceLongitude() *AlertUpdate { - au.mutation.ClearSourceLongitude() - return au -} - -// SetSourceScope sets the "sourceScope" field. -func (au *AlertUpdate) SetSourceScope(s string) *AlertUpdate { - au.mutation.SetSourceScope(s) - return au -} - -// SetNillableSourceScope sets the "sourceScope" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceScope(s *string) *AlertUpdate { - if s != nil { - au.SetSourceScope(*s) - } - return au -} - -// ClearSourceScope clears the value of the "sourceScope" field. -func (au *AlertUpdate) ClearSourceScope() *AlertUpdate { - au.mutation.ClearSourceScope() - return au -} - -// SetSourceValue sets the "sourceValue" field. -func (au *AlertUpdate) SetSourceValue(s string) *AlertUpdate { - au.mutation.SetSourceValue(s) - return au -} - -// SetNillableSourceValue sets the "sourceValue" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceValue(s *string) *AlertUpdate { - if s != nil { - au.SetSourceValue(*s) - } - return au -} - -// ClearSourceValue clears the value of the "sourceValue" field. 
-func (au *AlertUpdate) ClearSourceValue() *AlertUpdate { - au.mutation.ClearSourceValue() - return au -} - -// SetCapacity sets the "capacity" field. -func (au *AlertUpdate) SetCapacity(i int32) *AlertUpdate { - au.mutation.ResetCapacity() - au.mutation.SetCapacity(i) - return au -} - -// SetNillableCapacity sets the "capacity" field if the given value is not nil. -func (au *AlertUpdate) SetNillableCapacity(i *int32) *AlertUpdate { - if i != nil { - au.SetCapacity(*i) - } - return au -} - -// AddCapacity adds i to the "capacity" field. -func (au *AlertUpdate) AddCapacity(i int32) *AlertUpdate { - au.mutation.AddCapacity(i) - return au -} - -// ClearCapacity clears the value of the "capacity" field. -func (au *AlertUpdate) ClearCapacity() *AlertUpdate { - au.mutation.ClearCapacity() - return au -} - -// SetLeakSpeed sets the "leakSpeed" field. -func (au *AlertUpdate) SetLeakSpeed(s string) *AlertUpdate { - au.mutation.SetLeakSpeed(s) - return au -} - -// SetNillableLeakSpeed sets the "leakSpeed" field if the given value is not nil. -func (au *AlertUpdate) SetNillableLeakSpeed(s *string) *AlertUpdate { - if s != nil { - au.SetLeakSpeed(*s) - } - return au -} - -// ClearLeakSpeed clears the value of the "leakSpeed" field. -func (au *AlertUpdate) ClearLeakSpeed() *AlertUpdate { - au.mutation.ClearLeakSpeed() - return au -} - -// SetScenarioVersion sets the "scenarioVersion" field. -func (au *AlertUpdate) SetScenarioVersion(s string) *AlertUpdate { - au.mutation.SetScenarioVersion(s) - return au -} - -// SetNillableScenarioVersion sets the "scenarioVersion" field if the given value is not nil. -func (au *AlertUpdate) SetNillableScenarioVersion(s *string) *AlertUpdate { - if s != nil { - au.SetScenarioVersion(*s) - } - return au -} - -// ClearScenarioVersion clears the value of the "scenarioVersion" field. -func (au *AlertUpdate) ClearScenarioVersion() *AlertUpdate { - au.mutation.ClearScenarioVersion() - return au -} - -// SetScenarioHash sets the "scenarioHash" field. -func (au *AlertUpdate) SetScenarioHash(s string) *AlertUpdate { - au.mutation.SetScenarioHash(s) - return au -} - -// SetNillableScenarioHash sets the "scenarioHash" field if the given value is not nil. -func (au *AlertUpdate) SetNillableScenarioHash(s *string) *AlertUpdate { - if s != nil { - au.SetScenarioHash(*s) - } - return au -} - -// ClearScenarioHash clears the value of the "scenarioHash" field. -func (au *AlertUpdate) ClearScenarioHash() *AlertUpdate { - au.mutation.ClearScenarioHash() - return au -} - -// SetSimulated sets the "simulated" field. -func (au *AlertUpdate) SetSimulated(b bool) *AlertUpdate { - au.mutation.SetSimulated(b) - return au -} - -// SetNillableSimulated sets the "simulated" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSimulated(b *bool) *AlertUpdate { - if b != nil { - au.SetSimulated(*b) - } - return au -} - -// SetUUID sets the "uuid" field. -func (au *AlertUpdate) SetUUID(s string) *AlertUpdate { - au.mutation.SetUUID(s) - return au -} - -// SetNillableUUID sets the "uuid" field if the given value is not nil. -func (au *AlertUpdate) SetNillableUUID(s *string) *AlertUpdate { - if s != nil { - au.SetUUID(*s) - } - return au -} - -// ClearUUID clears the value of the "uuid" field. -func (au *AlertUpdate) ClearUUID() *AlertUpdate { - au.mutation.ClearUUID() - return au -} - // SetOwnerID sets the "owner" edge to the Machine entity by ID. 
func (au *AlertUpdate) SetOwnerID(id int) *AlertUpdate { au.mutation.SetOwnerID(id) @@ -624,35 +178,8 @@ func (au *AlertUpdate) RemoveMetas(m ...*Meta) *AlertUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (au *AlertUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) au.defaults() - if len(au.hooks) == 0 { - affected, err = au.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - au.mutation = mutation - affected, err = au.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(au.hooks) - 1; i >= 0; i-- { - if au.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = au.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, au.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, au.sqlSave, au.mutation, au.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -679,27 +206,14 @@ func (au *AlertUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (au *AlertUpdate) defaults() { - if _, ok := au.mutation.CreatedAt(); !ok && !au.mutation.CreatedAtCleared() { - v := alert.UpdateDefaultCreatedAt() - au.mutation.SetCreatedAt(v) - } - if _, ok := au.mutation.UpdatedAt(); !ok && !au.mutation.UpdatedAtCleared() { + if _, ok := au.mutation.UpdatedAt(); !ok { v := alert.UpdateDefaultUpdatedAt() au.mutation.SetUpdatedAt(v) } } func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - Columns: alert.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(alert.Table, alert.Columns, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) if ps := au.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -707,320 +221,68 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := au.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldCreatedAt, - }) - } - if au.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldCreatedAt, - }) - } if value, ok := au.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldUpdatedAt, - }) - } - if au.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldUpdatedAt, - }) - } - if value, ok := au.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenario, - }) - } - if value, ok := au.mutation.BucketId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldBucketId, - }) + _spec.SetField(alert.FieldUpdatedAt, field.TypeTime, value) } if au.mutation.BucketIdCleared() { - _spec.Fields.Clear = 
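Every field write and clear in sqlSave shrinks from an appended FieldSpec literal to a single helper call on the UpdateSpec. The correspondence, sketched with a literal column name in place of the generated constant:

package sketch

import (
	"time"

	"entgo.io/ent/dialect/sql/sqlgraph"
	"entgo.io/ent/schema/field"
)

// setAndClear shows the v0.12 helpers that replace the removed
// _spec.Fields.Set / _spec.Fields.Clear appends of hand-built
// &sqlgraph.FieldSpec{...} values seen throughout this hunk.
func setAndClear(spec *sqlgraph.UpdateSpec, value time.Time) {
	// One call instead of appending {Type, Value, Column} to Fields.Set.
	spec.SetField("updated_at", field.TypeTime, value)
	// And the clear side, instead of appending to Fields.Clear.
	spec.ClearField("updated_at", field.TypeTime)
}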
append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldBucketId, - }) - } - if value, ok := au.mutation.Message(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldMessage, - }) + _spec.ClearField(alert.FieldBucketId, field.TypeString) } if au.mutation.MessageCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldMessage, - }) - } - if value, ok := au.mutation.EventsCount(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) - } - if value, ok := au.mutation.AddedEventsCount(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) + _spec.ClearField(alert.FieldMessage, field.TypeString) } if au.mutation.EventsCountCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Column: alert.FieldEventsCount, - }) - } - if value, ok := au.mutation.StartedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStartedAt, - }) + _spec.ClearField(alert.FieldEventsCount, field.TypeInt32) } if au.mutation.StartedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStartedAt, - }) - } - if value, ok := au.mutation.StoppedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStoppedAt, - }) + _spec.ClearField(alert.FieldStartedAt, field.TypeTime) } if au.mutation.StoppedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStoppedAt, - }) - } - if value, ok := au.mutation.SourceIp(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceIp, - }) + _spec.ClearField(alert.FieldStoppedAt, field.TypeTime) } if au.mutation.SourceIpCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceIp, - }) - } - if value, ok := au.mutation.SourceRange(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceRange, - }) + _spec.ClearField(alert.FieldSourceIp, field.TypeString) } if au.mutation.SourceRangeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceRange, - }) - } - if value, ok := au.mutation.SourceAsNumber(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsNumber, - }) + _spec.ClearField(alert.FieldSourceRange, field.TypeString) } if au.mutation.SourceAsNumberCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsNumber, - }) - } - if value, ok := au.mutation.SourceAsName(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsName, - }) + 
_spec.ClearField(alert.FieldSourceAsNumber, field.TypeString) } if au.mutation.SourceAsNameCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsName, - }) - } - if value, ok := au.mutation.SourceCountry(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceCountry, - }) + _spec.ClearField(alert.FieldSourceAsName, field.TypeString) } if au.mutation.SourceCountryCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceCountry, - }) - } - if value, ok := au.mutation.SourceLatitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) - } - if value, ok := au.mutation.AddedSourceLatitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) + _spec.ClearField(alert.FieldSourceCountry, field.TypeString) } if au.mutation.SourceLatitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLatitude, - }) - } - if value, ok := au.mutation.SourceLongitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) - } - if value, ok := au.mutation.AddedSourceLongitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) + _spec.ClearField(alert.FieldSourceLatitude, field.TypeFloat32) } if au.mutation.SourceLongitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLongitude, - }) - } - if value, ok := au.mutation.SourceScope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceScope, - }) + _spec.ClearField(alert.FieldSourceLongitude, field.TypeFloat32) } if au.mutation.SourceScopeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceScope, - }) - } - if value, ok := au.mutation.SourceValue(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceValue, - }) + _spec.ClearField(alert.FieldSourceScope, field.TypeString) } if au.mutation.SourceValueCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceValue, - }) - } - if value, ok := au.mutation.Capacity(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) - } - if value, ok := au.mutation.AddedCapacity(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) + _spec.ClearField(alert.FieldSourceValue, field.TypeString) } if au.mutation.CapacityCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Column: alert.FieldCapacity, - }) - } - if 
value, ok := au.mutation.LeakSpeed(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldLeakSpeed, - }) + _spec.ClearField(alert.FieldCapacity, field.TypeInt32) } if au.mutation.LeakSpeedCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldLeakSpeed, - }) - } - if value, ok := au.mutation.ScenarioVersion(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioVersion, - }) + _spec.ClearField(alert.FieldLeakSpeed, field.TypeString) } if au.mutation.ScenarioVersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioVersion, - }) - } - if value, ok := au.mutation.ScenarioHash(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioHash, - }) + _spec.ClearField(alert.FieldScenarioVersion, field.TypeString) } if au.mutation.ScenarioHashCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioHash, - }) - } - if value, ok := au.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: alert.FieldSimulated, - }) - } - if value, ok := au.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldUUID, - }) + _spec.ClearField(alert.FieldScenarioHash, field.TypeString) } if au.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldUUID, - }) + _spec.ClearField(alert.FieldUUID, field.TypeString) + } + if au.mutation.RemediationCleared() { + _spec.ClearField(alert.FieldRemediation, field.TypeBool) } if au.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -1030,10 +292,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1046,10 +305,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1065,10 +321,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1081,10 +334,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: 
sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1100,10 +350,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1119,10 +366,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1135,10 +379,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1154,10 +395,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1173,10 +411,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1189,10 +424,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1208,10 +440,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1227,6 +456,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + au.mutation.done = true return n, nil } @@ -1238,458 +468,12 @@ type AlertUpdateOne struct { mutation *AlertMutation } -// SetCreatedAt sets the "created_at" field. -func (auo *AlertUpdateOne) SetCreatedAt(t time.Time) *AlertUpdateOne { - auo.mutation.SetCreatedAt(t) - return auo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (auo *AlertUpdateOne) ClearCreatedAt() *AlertUpdateOne { - auo.mutation.ClearCreatedAt() - return auo -} - // SetUpdatedAt sets the "updated_at" field. func (auo *AlertUpdateOne) SetUpdatedAt(t time.Time) *AlertUpdateOne { auo.mutation.SetUpdatedAt(t) return auo } -// ClearUpdatedAt clears the value of the "updated_at" field. 
-func (auo *AlertUpdateOne) ClearUpdatedAt() *AlertUpdateOne { - auo.mutation.ClearUpdatedAt() - return auo -} - -// SetScenario sets the "scenario" field. -func (auo *AlertUpdateOne) SetScenario(s string) *AlertUpdateOne { - auo.mutation.SetScenario(s) - return auo -} - -// SetBucketId sets the "bucketId" field. -func (auo *AlertUpdateOne) SetBucketId(s string) *AlertUpdateOne { - auo.mutation.SetBucketId(s) - return auo -} - -// SetNillableBucketId sets the "bucketId" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableBucketId(s *string) *AlertUpdateOne { - if s != nil { - auo.SetBucketId(*s) - } - return auo -} - -// ClearBucketId clears the value of the "bucketId" field. -func (auo *AlertUpdateOne) ClearBucketId() *AlertUpdateOne { - auo.mutation.ClearBucketId() - return auo -} - -// SetMessage sets the "message" field. -func (auo *AlertUpdateOne) SetMessage(s string) *AlertUpdateOne { - auo.mutation.SetMessage(s) - return auo -} - -// SetNillableMessage sets the "message" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableMessage(s *string) *AlertUpdateOne { - if s != nil { - auo.SetMessage(*s) - } - return auo -} - -// ClearMessage clears the value of the "message" field. -func (auo *AlertUpdateOne) ClearMessage() *AlertUpdateOne { - auo.mutation.ClearMessage() - return auo -} - -// SetEventsCount sets the "eventsCount" field. -func (auo *AlertUpdateOne) SetEventsCount(i int32) *AlertUpdateOne { - auo.mutation.ResetEventsCount() - auo.mutation.SetEventsCount(i) - return auo -} - -// SetNillableEventsCount sets the "eventsCount" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableEventsCount(i *int32) *AlertUpdateOne { - if i != nil { - auo.SetEventsCount(*i) - } - return auo -} - -// AddEventsCount adds i to the "eventsCount" field. -func (auo *AlertUpdateOne) AddEventsCount(i int32) *AlertUpdateOne { - auo.mutation.AddEventsCount(i) - return auo -} - -// ClearEventsCount clears the value of the "eventsCount" field. -func (auo *AlertUpdateOne) ClearEventsCount() *AlertUpdateOne { - auo.mutation.ClearEventsCount() - return auo -} - -// SetStartedAt sets the "startedAt" field. -func (auo *AlertUpdateOne) SetStartedAt(t time.Time) *AlertUpdateOne { - auo.mutation.SetStartedAt(t) - return auo -} - -// SetNillableStartedAt sets the "startedAt" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableStartedAt(t *time.Time) *AlertUpdateOne { - if t != nil { - auo.SetStartedAt(*t) - } - return auo -} - -// ClearStartedAt clears the value of the "startedAt" field. -func (auo *AlertUpdateOne) ClearStartedAt() *AlertUpdateOne { - auo.mutation.ClearStartedAt() - return auo -} - -// SetStoppedAt sets the "stoppedAt" field. -func (auo *AlertUpdateOne) SetStoppedAt(t time.Time) *AlertUpdateOne { - auo.mutation.SetStoppedAt(t) - return auo -} - -// SetNillableStoppedAt sets the "stoppedAt" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableStoppedAt(t *time.Time) *AlertUpdateOne { - if t != nil { - auo.SetStoppedAt(*t) - } - return auo -} - -// ClearStoppedAt clears the value of the "stoppedAt" field. -func (auo *AlertUpdateOne) ClearStoppedAt() *AlertUpdateOne { - auo.mutation.ClearStoppedAt() - return auo -} - -// SetSourceIp sets the "sourceIp" field. -func (auo *AlertUpdateOne) SetSourceIp(s string) *AlertUpdateOne { - auo.mutation.SetSourceIp(s) - return auo -} - -// SetNillableSourceIp sets the "sourceIp" field if the given value is not nil. 
-func (auo *AlertUpdateOne) SetNillableSourceIp(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceIp(*s) - } - return auo -} - -// ClearSourceIp clears the value of the "sourceIp" field. -func (auo *AlertUpdateOne) ClearSourceIp() *AlertUpdateOne { - auo.mutation.ClearSourceIp() - return auo -} - -// SetSourceRange sets the "sourceRange" field. -func (auo *AlertUpdateOne) SetSourceRange(s string) *AlertUpdateOne { - auo.mutation.SetSourceRange(s) - return auo -} - -// SetNillableSourceRange sets the "sourceRange" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceRange(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceRange(*s) - } - return auo -} - -// ClearSourceRange clears the value of the "sourceRange" field. -func (auo *AlertUpdateOne) ClearSourceRange() *AlertUpdateOne { - auo.mutation.ClearSourceRange() - return auo -} - -// SetSourceAsNumber sets the "sourceAsNumber" field. -func (auo *AlertUpdateOne) SetSourceAsNumber(s string) *AlertUpdateOne { - auo.mutation.SetSourceAsNumber(s) - return auo -} - -// SetNillableSourceAsNumber sets the "sourceAsNumber" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceAsNumber(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceAsNumber(*s) - } - return auo -} - -// ClearSourceAsNumber clears the value of the "sourceAsNumber" field. -func (auo *AlertUpdateOne) ClearSourceAsNumber() *AlertUpdateOne { - auo.mutation.ClearSourceAsNumber() - return auo -} - -// SetSourceAsName sets the "sourceAsName" field. -func (auo *AlertUpdateOne) SetSourceAsName(s string) *AlertUpdateOne { - auo.mutation.SetSourceAsName(s) - return auo -} - -// SetNillableSourceAsName sets the "sourceAsName" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceAsName(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceAsName(*s) - } - return auo -} - -// ClearSourceAsName clears the value of the "sourceAsName" field. -func (auo *AlertUpdateOne) ClearSourceAsName() *AlertUpdateOne { - auo.mutation.ClearSourceAsName() - return auo -} - -// SetSourceCountry sets the "sourceCountry" field. -func (auo *AlertUpdateOne) SetSourceCountry(s string) *AlertUpdateOne { - auo.mutation.SetSourceCountry(s) - return auo -} - -// SetNillableSourceCountry sets the "sourceCountry" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceCountry(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceCountry(*s) - } - return auo -} - -// ClearSourceCountry clears the value of the "sourceCountry" field. -func (auo *AlertUpdateOne) ClearSourceCountry() *AlertUpdateOne { - auo.mutation.ClearSourceCountry() - return auo -} - -// SetSourceLatitude sets the "sourceLatitude" field. -func (auo *AlertUpdateOne) SetSourceLatitude(f float32) *AlertUpdateOne { - auo.mutation.ResetSourceLatitude() - auo.mutation.SetSourceLatitude(f) - return auo -} - -// SetNillableSourceLatitude sets the "sourceLatitude" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceLatitude(f *float32) *AlertUpdateOne { - if f != nil { - auo.SetSourceLatitude(*f) - } - return auo -} - -// AddSourceLatitude adds f to the "sourceLatitude" field. -func (auo *AlertUpdateOne) AddSourceLatitude(f float32) *AlertUpdateOne { - auo.mutation.AddSourceLatitude(f) - return auo -} - -// ClearSourceLatitude clears the value of the "sourceLatitude" field. 
-func (auo *AlertUpdateOne) ClearSourceLatitude() *AlertUpdateOne { - auo.mutation.ClearSourceLatitude() - return auo -} - -// SetSourceLongitude sets the "sourceLongitude" field. -func (auo *AlertUpdateOne) SetSourceLongitude(f float32) *AlertUpdateOne { - auo.mutation.ResetSourceLongitude() - auo.mutation.SetSourceLongitude(f) - return auo -} - -// SetNillableSourceLongitude sets the "sourceLongitude" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceLongitude(f *float32) *AlertUpdateOne { - if f != nil { - auo.SetSourceLongitude(*f) - } - return auo -} - -// AddSourceLongitude adds f to the "sourceLongitude" field. -func (auo *AlertUpdateOne) AddSourceLongitude(f float32) *AlertUpdateOne { - auo.mutation.AddSourceLongitude(f) - return auo -} - -// ClearSourceLongitude clears the value of the "sourceLongitude" field. -func (auo *AlertUpdateOne) ClearSourceLongitude() *AlertUpdateOne { - auo.mutation.ClearSourceLongitude() - return auo -} - -// SetSourceScope sets the "sourceScope" field. -func (auo *AlertUpdateOne) SetSourceScope(s string) *AlertUpdateOne { - auo.mutation.SetSourceScope(s) - return auo -} - -// SetNillableSourceScope sets the "sourceScope" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceScope(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceScope(*s) - } - return auo -} - -// ClearSourceScope clears the value of the "sourceScope" field. -func (auo *AlertUpdateOne) ClearSourceScope() *AlertUpdateOne { - auo.mutation.ClearSourceScope() - return auo -} - -// SetSourceValue sets the "sourceValue" field. -func (auo *AlertUpdateOne) SetSourceValue(s string) *AlertUpdateOne { - auo.mutation.SetSourceValue(s) - return auo -} - -// SetNillableSourceValue sets the "sourceValue" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceValue(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceValue(*s) - } - return auo -} - -// ClearSourceValue clears the value of the "sourceValue" field. -func (auo *AlertUpdateOne) ClearSourceValue() *AlertUpdateOne { - auo.mutation.ClearSourceValue() - return auo -} - -// SetCapacity sets the "capacity" field. -func (auo *AlertUpdateOne) SetCapacity(i int32) *AlertUpdateOne { - auo.mutation.ResetCapacity() - auo.mutation.SetCapacity(i) - return auo -} - -// SetNillableCapacity sets the "capacity" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableCapacity(i *int32) *AlertUpdateOne { - if i != nil { - auo.SetCapacity(*i) - } - return auo -} - -// AddCapacity adds i to the "capacity" field. -func (auo *AlertUpdateOne) AddCapacity(i int32) *AlertUpdateOne { - auo.mutation.AddCapacity(i) - return auo -} - -// ClearCapacity clears the value of the "capacity" field. -func (auo *AlertUpdateOne) ClearCapacity() *AlertUpdateOne { - auo.mutation.ClearCapacity() - return auo -} - -// SetLeakSpeed sets the "leakSpeed" field. -func (auo *AlertUpdateOne) SetLeakSpeed(s string) *AlertUpdateOne { - auo.mutation.SetLeakSpeed(s) - return auo -} - -// SetNillableLeakSpeed sets the "leakSpeed" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableLeakSpeed(s *string) *AlertUpdateOne { - if s != nil { - auo.SetLeakSpeed(*s) - } - return auo -} - -// ClearLeakSpeed clears the value of the "leakSpeed" field. -func (auo *AlertUpdateOne) ClearLeakSpeed() *AlertUpdateOne { - auo.mutation.ClearLeakSpeed() - return auo -} - -// SetScenarioVersion sets the "scenarioVersion" field. 
-func (auo *AlertUpdateOne) SetScenarioVersion(s string) *AlertUpdateOne { - auo.mutation.SetScenarioVersion(s) - return auo -} - -// SetNillableScenarioVersion sets the "scenarioVersion" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableScenarioVersion(s *string) *AlertUpdateOne { - if s != nil { - auo.SetScenarioVersion(*s) - } - return auo -} - -// ClearScenarioVersion clears the value of the "scenarioVersion" field. -func (auo *AlertUpdateOne) ClearScenarioVersion() *AlertUpdateOne { - auo.mutation.ClearScenarioVersion() - return auo -} - -// SetScenarioHash sets the "scenarioHash" field. -func (auo *AlertUpdateOne) SetScenarioHash(s string) *AlertUpdateOne { - auo.mutation.SetScenarioHash(s) - return auo -} - -// SetNillableScenarioHash sets the "scenarioHash" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableScenarioHash(s *string) *AlertUpdateOne { - if s != nil { - auo.SetScenarioHash(*s) - } - return auo -} - -// ClearScenarioHash clears the value of the "scenarioHash" field. -func (auo *AlertUpdateOne) ClearScenarioHash() *AlertUpdateOne { - auo.mutation.ClearScenarioHash() - return auo -} - -// SetSimulated sets the "simulated" field. -func (auo *AlertUpdateOne) SetSimulated(b bool) *AlertUpdateOne { - auo.mutation.SetSimulated(b) - return auo -} - -// SetNillableSimulated sets the "simulated" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSimulated(b *bool) *AlertUpdateOne { - if b != nil { - auo.SetSimulated(*b) - } - return auo -} - -// SetUUID sets the "uuid" field. -func (auo *AlertUpdateOne) SetUUID(s string) *AlertUpdateOne { - auo.mutation.SetUUID(s) - return auo -} - -// SetNillableUUID sets the "uuid" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableUUID(s *string) *AlertUpdateOne { - if s != nil { - auo.SetUUID(*s) - } - return auo -} - -// ClearUUID clears the value of the "uuid" field. -func (auo *AlertUpdateOne) ClearUUID() *AlertUpdateOne { - auo.mutation.ClearUUID() - return auo -} - // SetOwnerID sets the "owner" edge to the Machine entity by ID. func (auo *AlertUpdateOne) SetOwnerID(id int) *AlertUpdateOne { auo.mutation.SetOwnerID(id) @@ -1828,6 +612,12 @@ func (auo *AlertUpdateOne) RemoveMetas(m ...*Meta) *AlertUpdateOne { return auo.RemoveMetaIDs(ids...) } +// Where appends a list predicates to the AlertUpdate builder. +func (auo *AlertUpdateOne) Where(ps ...predicate.Alert) *AlertUpdateOne { + auo.mutation.Where(ps...) + return auo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (auo *AlertUpdateOne) Select(field string, fields ...string) *AlertUpdateOne { @@ -1837,41 +627,8 @@ func (auo *AlertUpdateOne) Select(field string, fields ...string) *AlertUpdateOn // Save executes the query and returns the updated Alert entity. 
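Two behavioural changes land in the Save/sqlSave hunks that follow: Save() now delegates the hook chain to the runtime helper withHooks instead of unrolling it inline, and sqlSave() builds its update spec with sqlgraph.NewUpdateSpec plus SetField/ClearField rather than appending raw sqlgraph.FieldSpec values. Together with the Where() method added above, a single-row update can now carry extra predicates. A hedged usage sketch (client, ctx and the ID are illustrative, not from this diff; assumes the generated ent and ent/alert packages are imported):

func unassignAlert(ctx context.Context, client *ent.Client) error {
    // Where() narrows the UpdateOne: if the predicate filters the row
    // out, Save() reports not-found instead of silently updating.
    _, err := client.Alert.
        UpdateOneID(42).
        Where(alert.SimulatedEQ(false)).
        ClearOwner().
        Save(ctx)
    if ent.IsNotFound(err) {
        return nil // no row matched the ID plus predicate
    }
    return err
}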
func (auo *AlertUpdateOne) Save(ctx context.Context) (*Alert, error) { - var ( - err error - node *Alert - ) auo.defaults() - if len(auo.hooks) == 0 { - node, err = auo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - auo.mutation = mutation - node, err = auo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(auo.hooks) - 1; i >= 0; i-- { - if auo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = auo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, auo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Alert) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from AlertMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, auo.sqlSave, auo.mutation, auo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -1898,27 +655,14 @@ func (auo *AlertUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (auo *AlertUpdateOne) defaults() { - if _, ok := auo.mutation.CreatedAt(); !ok && !auo.mutation.CreatedAtCleared() { - v := alert.UpdateDefaultCreatedAt() - auo.mutation.SetCreatedAt(v) - } - if _, ok := auo.mutation.UpdatedAt(); !ok && !auo.mutation.UpdatedAtCleared() { + if _, ok := auo.mutation.UpdatedAt(); !ok { v := alert.UpdateDefaultUpdatedAt() auo.mutation.SetUpdatedAt(v) } } func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - Columns: alert.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(alert.Table, alert.Columns, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) id, ok := auo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Alert.id" for update`)} @@ -1943,320 +687,68 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error } } } - if value, ok := auo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldCreatedAt, - }) - } - if auo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldCreatedAt, - }) - } if value, ok := auo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldUpdatedAt, - }) - } - if auo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldUpdatedAt, - }) - } - if value, ok := auo.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenario, - }) - } - if value, ok := auo.mutation.BucketId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldBucketId, - }) + _spec.SetField(alert.FieldUpdatedAt, field.TypeTime, value) } if auo.mutation.BucketIdCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: 
alert.FieldBucketId, - }) - } - if value, ok := auo.mutation.Message(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldMessage, - }) + _spec.ClearField(alert.FieldBucketId, field.TypeString) } if auo.mutation.MessageCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldMessage, - }) - } - if value, ok := auo.mutation.EventsCount(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) - } - if value, ok := auo.mutation.AddedEventsCount(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) + _spec.ClearField(alert.FieldMessage, field.TypeString) } if auo.mutation.EventsCountCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Column: alert.FieldEventsCount, - }) - } - if value, ok := auo.mutation.StartedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStartedAt, - }) + _spec.ClearField(alert.FieldEventsCount, field.TypeInt32) } if auo.mutation.StartedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStartedAt, - }) - } - if value, ok := auo.mutation.StoppedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStoppedAt, - }) + _spec.ClearField(alert.FieldStartedAt, field.TypeTime) } if auo.mutation.StoppedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStoppedAt, - }) - } - if value, ok := auo.mutation.SourceIp(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceIp, - }) + _spec.ClearField(alert.FieldStoppedAt, field.TypeTime) } if auo.mutation.SourceIpCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceIp, - }) - } - if value, ok := auo.mutation.SourceRange(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceRange, - }) + _spec.ClearField(alert.FieldSourceIp, field.TypeString) } if auo.mutation.SourceRangeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceRange, - }) - } - if value, ok := auo.mutation.SourceAsNumber(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsNumber, - }) + _spec.ClearField(alert.FieldSourceRange, field.TypeString) } if auo.mutation.SourceAsNumberCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsNumber, - }) - } - if value, ok := auo.mutation.SourceAsName(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsName, - }) + _spec.ClearField(alert.FieldSourceAsNumber, field.TypeString) } if 
auo.mutation.SourceAsNameCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsName, - }) - } - if value, ok := auo.mutation.SourceCountry(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceCountry, - }) + _spec.ClearField(alert.FieldSourceAsName, field.TypeString) } if auo.mutation.SourceCountryCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceCountry, - }) - } - if value, ok := auo.mutation.SourceLatitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) - } - if value, ok := auo.mutation.AddedSourceLatitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) + _spec.ClearField(alert.FieldSourceCountry, field.TypeString) } if auo.mutation.SourceLatitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLatitude, - }) - } - if value, ok := auo.mutation.SourceLongitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) - } - if value, ok := auo.mutation.AddedSourceLongitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) + _spec.ClearField(alert.FieldSourceLatitude, field.TypeFloat32) } if auo.mutation.SourceLongitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLongitude, - }) - } - if value, ok := auo.mutation.SourceScope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceScope, - }) + _spec.ClearField(alert.FieldSourceLongitude, field.TypeFloat32) } if auo.mutation.SourceScopeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceScope, - }) - } - if value, ok := auo.mutation.SourceValue(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceValue, - }) + _spec.ClearField(alert.FieldSourceScope, field.TypeString) } if auo.mutation.SourceValueCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceValue, - }) - } - if value, ok := auo.mutation.Capacity(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) - } - if value, ok := auo.mutation.AddedCapacity(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) + _spec.ClearField(alert.FieldSourceValue, field.TypeString) } if auo.mutation.CapacityCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Column: alert.FieldCapacity, - }) - } - if value, ok := auo.mutation.LeakSpeed(); ok { - 
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldLeakSpeed, - }) + _spec.ClearField(alert.FieldCapacity, field.TypeInt32) } if auo.mutation.LeakSpeedCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldLeakSpeed, - }) - } - if value, ok := auo.mutation.ScenarioVersion(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioVersion, - }) + _spec.ClearField(alert.FieldLeakSpeed, field.TypeString) } if auo.mutation.ScenarioVersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioVersion, - }) - } - if value, ok := auo.mutation.ScenarioHash(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioHash, - }) + _spec.ClearField(alert.FieldScenarioVersion, field.TypeString) } if auo.mutation.ScenarioHashCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioHash, - }) - } - if value, ok := auo.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: alert.FieldSimulated, - }) - } - if value, ok := auo.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldUUID, - }) + _spec.ClearField(alert.FieldScenarioHash, field.TypeString) } if auo.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldUUID, - }) + _spec.ClearField(alert.FieldUUID, field.TypeString) + } + if auo.mutation.RemediationCleared() { + _spec.ClearField(alert.FieldRemediation, field.TypeBool) } if auo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -2266,10 +758,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2282,10 +771,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2301,10 +787,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2317,10 +800,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: 
sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2336,10 +816,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2355,10 +832,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2371,10 +845,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2390,10 +861,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2409,10 +877,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2425,10 +890,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2444,10 +906,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2466,5 +925,6 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error } return nil, err } + auo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/bouncer.go b/pkg/database/ent/bouncer.go index 068fc6c6713..3b4d619e384 100644 --- a/pkg/database/ent/bouncer.go +++ b/pkg/database/ent/bouncer.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" ) @@ -17,13 +18,13 @@ type Bouncer struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at"` + CreatedAt time.Time `json:"created_at"` // UpdatedAt holds the value of the "updated_at" field. 
- UpdatedAt *time.Time `json:"updated_at"` + UpdatedAt time.Time `json:"updated_at"` // Name holds the value of the "name" field. Name string `json:"name"` // APIKey holds the value of the "api_key" field. - APIKey string `json:"api_key"` + APIKey string `json:"-"` // Revoked holds the value of the "revoked" field. Revoked bool `json:"revoked"` // IPAddress holds the value of the "ip_address" field. @@ -32,12 +33,17 @@ type Bouncer struct { Type string `json:"type"` // Version holds the value of the "version" field. Version string `json:"version"` - // Until holds the value of the "until" field. - Until time.Time `json:"until"` // LastPull holds the value of the "last_pull" field. - LastPull time.Time `json:"last_pull"` + LastPull *time.Time `json:"last_pull"` // AuthType holds the value of the "auth_type" field. AuthType string `json:"auth_type"` + // Osname holds the value of the "osname" field. + Osname string `json:"osname,omitempty"` + // Osversion holds the value of the "osversion" field. + Osversion string `json:"osversion,omitempty"` + // Featureflags holds the value of the "featureflags" field. + Featureflags string `json:"featureflags,omitempty"` + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -49,12 +55,12 @@ func (*Bouncer) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case bouncer.FieldID: values[i] = new(sql.NullInt64) - case bouncer.FieldName, bouncer.FieldAPIKey, bouncer.FieldIPAddress, bouncer.FieldType, bouncer.FieldVersion, bouncer.FieldAuthType: + case bouncer.FieldName, bouncer.FieldAPIKey, bouncer.FieldIPAddress, bouncer.FieldType, bouncer.FieldVersion, bouncer.FieldAuthType, bouncer.FieldOsname, bouncer.FieldOsversion, bouncer.FieldFeatureflags: values[i] = new(sql.NullString) - case bouncer.FieldCreatedAt, bouncer.FieldUpdatedAt, bouncer.FieldUntil, bouncer.FieldLastPull: + case bouncer.FieldCreatedAt, bouncer.FieldUpdatedAt, bouncer.FieldLastPull: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Bouncer", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -78,15 +84,13 @@ func (b *Bouncer) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - b.CreatedAt = new(time.Time) - *b.CreatedAt = value.Time + b.CreatedAt = value.Time } case bouncer.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - b.UpdatedAt = new(time.Time) - *b.UpdatedAt = value.Time + b.UpdatedAt = value.Time } case bouncer.FieldName: if value, ok := values[i].(*sql.NullString); !ok { @@ -124,17 +128,12 @@ func (b *Bouncer) assignValues(columns []string, values []any) error { } else if value.Valid { b.Version = value.String } - case bouncer.FieldUntil: - if value, ok := values[i].(*sql.NullTime); !ok { - return fmt.Errorf("unexpected type %T for field until", values[i]) - } else if value.Valid { - b.Until = value.Time - } case bouncer.FieldLastPull: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field last_pull", values[i]) } else if value.Valid { - b.LastPull = value.Time + b.LastPull = new(time.Time) + *b.LastPull = value.Time } case bouncer.FieldAuthType: if value, ok := values[i].(*sql.NullString); !ok { @@ -142,16 +141,42 @@ func (b *Bouncer) 
assignValues(columns []string, values []any) error { } else if value.Valid { b.AuthType = value.String } + case bouncer.FieldOsname: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field osname", values[i]) + } else if value.Valid { + b.Osname = value.String + } + case bouncer.FieldOsversion: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field osversion", values[i]) + } else if value.Valid { + b.Osversion = value.String + } + case bouncer.FieldFeatureflags: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field featureflags", values[i]) + } else if value.Valid { + b.Featureflags = value.String + } + default: + b.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Bouncer. +// This includes values selected through modifiers, order, etc. +func (b *Bouncer) Value(name string) (ent.Value, error) { + return b.selectValues.Get(name) +} + // Update returns a builder for updating this Bouncer. // Note that you need to call Bouncer.Unwrap() before calling this method if this Bouncer // was returned from a transaction, and the transaction was committed or rolled back. func (b *Bouncer) Update() *BouncerUpdateOne { - return (&BouncerClient{config: b.config}).UpdateOne(b) + return NewBouncerClient(b.config).UpdateOne(b) } // Unwrap unwraps the Bouncer entity that was returned from a transaction after it was closed, @@ -170,21 +195,16 @@ func (b *Bouncer) String() string { var builder strings.Builder builder.WriteString("Bouncer(") builder.WriteString(fmt.Sprintf("id=%v, ", b.ID)) - if v := b.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(b.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := b.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(b.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") builder.WriteString("name=") builder.WriteString(b.Name) builder.WriteString(", ") - builder.WriteString("api_key=") - builder.WriteString(b.APIKey) + builder.WriteString("api_key=") builder.WriteString(", ") builder.WriteString("revoked=") builder.WriteString(fmt.Sprintf("%v", b.Revoked)) @@ -198,23 +218,25 @@ func (b *Bouncer) String() string { builder.WriteString("version=") builder.WriteString(b.Version) builder.WriteString(", ") - builder.WriteString("until=") - builder.WriteString(b.Until.Format(time.ANSIC)) - builder.WriteString(", ") - builder.WriteString("last_pull=") - builder.WriteString(b.LastPull.Format(time.ANSIC)) + if v := b.LastPull; v != nil { + builder.WriteString("last_pull=") + builder.WriteString(v.Format(time.ANSIC)) + } builder.WriteString(", ") builder.WriteString("auth_type=") builder.WriteString(b.AuthType) + builder.WriteString(", ") + builder.WriteString("osname=") + builder.WriteString(b.Osname) + builder.WriteString(", ") + builder.WriteString("osversion=") + builder.WriteString(b.Osversion) + builder.WriteString(", ") + builder.WriteString("featureflags=") + builder.WriteString(b.Featureflags) builder.WriteByte(')') return builder.String() } // Bouncers is a parsable slice of Bouncer. 
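Three reader-visible changes in the Bouncer entity above: the API key is now tagged json:"-" so it can never leak through JSON serialization, last_pull becomes a nullable *time.Time (while created_at/updated_at lose their pointer indirection), and the new selectValues field with its Value() accessor exposes columns picked up through custom selects or modifiers. A hedged sketch of the caller-side effect (client/ctx and the bouncer name are illustrative; assumes encoding/json, fmt and the generated ent packages):

func showBouncer(ctx context.Context, client *ent.Client) error {
    b, err := client.Bouncer.Query().
        Where(bouncer.NameEQ("fw-1")).
        Only(ctx)
    if err != nil {
        return err
    }
    out, err := json.Marshal(b) // "api_key" is absent from the output
    if err != nil {
        return err
    }
    if b.LastPull == nil {
        fmt.Println("bouncer has never pulled decisions")
    }
    fmt.Println(string(out))
    return nil
}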
type Bouncers []*Bouncer - -func (b Bouncers) config(cfg config) { - for _i := range b { - b[_i].config = cfg - } -} diff --git a/pkg/database/ent/bouncer/bouncer.go b/pkg/database/ent/bouncer/bouncer.go index b688594ece4..a6f62aeadd5 100644 --- a/pkg/database/ent/bouncer/bouncer.go +++ b/pkg/database/ent/bouncer/bouncer.go @@ -4,6 +4,8 @@ package bouncer import ( "time" + + "entgo.io/ent/dialect/sql" ) const ( @@ -27,12 +29,16 @@ const ( FieldType = "type" // FieldVersion holds the string denoting the version field in the database. FieldVersion = "version" - // FieldUntil holds the string denoting the until field in the database. - FieldUntil = "until" // FieldLastPull holds the string denoting the last_pull field in the database. FieldLastPull = "last_pull" // FieldAuthType holds the string denoting the auth_type field in the database. FieldAuthType = "auth_type" + // FieldOsname holds the string denoting the osname field in the database. + FieldOsname = "osname" + // FieldOsversion holds the string denoting the osversion field in the database. + FieldOsversion = "osversion" + // FieldFeatureflags holds the string denoting the featureflags field in the database. + FieldFeatureflags = "featureflags" // Table holds the table name of the bouncer in the database. Table = "bouncers" ) @@ -48,9 +54,11 @@ var Columns = []string{ FieldIPAddress, FieldType, FieldVersion, - FieldUntil, FieldLastPull, FieldAuthType, + FieldOsname, + FieldOsversion, + FieldFeatureflags, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -66,18 +74,85 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. UpdateDefaultUpdatedAt func() time.Time // DefaultIPAddress holds the default value on creation for the "ip_address" field. DefaultIPAddress string - // DefaultUntil holds the default value on creation for the "until" field. - DefaultUntil func() time.Time - // DefaultLastPull holds the default value on creation for the "last_pull" field. - DefaultLastPull func() time.Time // DefaultAuthType holds the default value on creation for the "auth_type" field. DefaultAuthType string ) + +// OrderOption defines the ordering options for the Bouncer queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByAPIKey orders the results by the api_key field. 
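The OrderOption block starting here (ByID, ByCreatedAt, and the remaining By* helpers below) gives each sortable column a typed order function instead of a raw column string, with sql.OrderTermOption modifiers controlling direction. A hedged usage sketch (client/ctx illustrative; sql is entgo.io/ent/dialect/sql):

func newestBouncers(ctx context.Context, client *ent.Client) ([]*ent.Bouncer, error) {
    // sql.OrderDesc() is an OrderTermOption; the typed helper replaces
    // passing the "created_at" column name as a string.
    return client.Bouncer.Query().
        Order(bouncer.ByCreatedAt(sql.OrderDesc())).
        All(ctx)
}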
+func ByAPIKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAPIKey, opts...).ToFunc() +} + +// ByRevoked orders the results by the revoked field. +func ByRevoked(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRevoked, opts...).ToFunc() +} + +// ByIPAddress orders the results by the ip_address field. +func ByIPAddress(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIPAddress, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByVersion orders the results by the version field. +func ByVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVersion, opts...).ToFunc() +} + +// ByLastPull orders the results by the last_pull field. +func ByLastPull(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastPull, opts...).ToFunc() +} + +// ByAuthType orders the results by the auth_type field. +func ByAuthType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAuthType, opts...).ToFunc() +} + +// ByOsname orders the results by the osname field. +func ByOsname(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOsname, opts...).ToFunc() +} + +// ByOsversion orders the results by the osversion field. +func ByOsversion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOsversion, opts...).ToFunc() +} + +// ByFeatureflags orders the results by the featureflags field. +func ByFeatureflags(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFeatureflags, opts...).ToFunc() +} diff --git a/pkg/database/ent/bouncer/where.go b/pkg/database/ent/bouncer/where.go index 03a543f6d4f..e02199bc0a9 100644 --- a/pkg/database/ent/bouncer/where.go +++ b/pkg/database/ent/bouncer/where.go @@ -11,1128 +11,910 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. 
func IDGT(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldUpdatedAt, v)) } // Name applies equality check predicate on the "name" field. It's identical to NameEQ. func Name(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldName, v)) } // APIKey applies equality check predicate on the "api_key" field. It's identical to APIKeyEQ. func APIKey(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldAPIKey, v)) } // Revoked applies equality check predicate on the "revoked" field. It's identical to RevokedEQ. func Revoked(v bool) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRevoked), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldRevoked, v)) } // IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ. func IPAddress(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldIPAddress, v)) } // Type applies equality check predicate on the "type" field. It's identical to TypeEQ. func Type(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldType, v)) } // Version applies equality check predicate on the "version" field. It's identical to VersionEQ. func Version(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldVersion), v)) - }) -} - -// Until applies equality check predicate on the "until" field. It's identical to UntilEQ. 
-func Until(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUntil), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldVersion, v)) } // LastPull applies equality check predicate on the "last_pull" field. It's identical to LastPullEQ. func LastPull(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldLastPull, v)) } // AuthType applies equality check predicate on the "auth_type" field. It's identical to AuthTypeEQ. func AuthType(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldAuthType, v)) +} + +// Osname applies equality check predicate on the "osname" field. It's identical to OsnameEQ. +func Osname(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldOsname, v)) +} + +// Osversion applies equality check predicate on the "osversion" field. It's identical to OsversionEQ. +func Osversion(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldOsversion, v)) +} + +// Featureflags applies equality check predicate on the "featureflags" field. It's identical to FeatureflagsEQ. +func Featureflags(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldFeatureflags, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. 
func CreatedAtLT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Bouncer(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. 
-func UpdatedAtIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Bouncer(sql.FieldLTE(FieldUpdatedAt, v)) } // NameEQ applies the EQ predicate on the "name" field. func NameEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "name" field. func NameNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "name" field. func NameIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "name" field. func NameNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "name" field. func NameGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "name" field. func NameGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "name" field. func NameLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "name" field. func NameLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "name" field. func NameContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "name" field. func NameHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "name" field. func NameHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "name" field. 
func NameEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "name" field. func NameContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldName, v)) } // APIKeyEQ applies the EQ predicate on the "api_key" field. func APIKeyEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldAPIKey, v)) } // APIKeyNEQ applies the NEQ predicate on the "api_key" field. func APIKeyNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldAPIKey, v)) } // APIKeyIn applies the In predicate on the "api_key" field. func APIKeyIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAPIKey), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldAPIKey, vs...)) } // APIKeyNotIn applies the NotIn predicate on the "api_key" field. func APIKeyNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAPIKey), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldAPIKey, vs...)) } // APIKeyGT applies the GT predicate on the "api_key" field. func APIKeyGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldAPIKey, v)) } // APIKeyGTE applies the GTE predicate on the "api_key" field. func APIKeyGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldAPIKey, v)) } // APIKeyLT applies the LT predicate on the "api_key" field. func APIKeyLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldAPIKey, v)) } // APIKeyLTE applies the LTE predicate on the "api_key" field. func APIKeyLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldAPIKey, v)) } // APIKeyContains applies the Contains predicate on the "api_key" field. func APIKeyContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldAPIKey, v)) } // APIKeyHasPrefix applies the HasPrefix predicate on the "api_key" field. func APIKeyHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldAPIKey, v)) } // APIKeyHasSuffix applies the HasSuffix predicate on the "api_key" field. 
func APIKeyHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldAPIKey, v)) } // APIKeyEqualFold applies the EqualFold predicate on the "api_key" field. func APIKeyEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldAPIKey, v)) } // APIKeyContainsFold applies the ContainsFold predicate on the "api_key" field. func APIKeyContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldAPIKey, v)) } // RevokedEQ applies the EQ predicate on the "revoked" field. func RevokedEQ(v bool) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRevoked), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldRevoked, v)) } // RevokedNEQ applies the NEQ predicate on the "revoked" field. func RevokedNEQ(v bool) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldRevoked), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldRevoked, v)) } // IPAddressEQ applies the EQ predicate on the "ip_address" field. func IPAddressEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldIPAddress, v)) } // IPAddressNEQ applies the NEQ predicate on the "ip_address" field. func IPAddressNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldIPAddress, v)) } // IPAddressIn applies the In predicate on the "ip_address" field. func IPAddressIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldIPAddress), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldIPAddress, vs...)) } // IPAddressNotIn applies the NotIn predicate on the "ip_address" field. func IPAddressNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldIPAddress), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldIPAddress, vs...)) } // IPAddressGT applies the GT predicate on the "ip_address" field. func IPAddressGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldIPAddress, v)) } // IPAddressGTE applies the GTE predicate on the "ip_address" field. func IPAddressGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldIPAddress, v)) } // IPAddressLT applies the LT predicate on the "ip_address" field. func IPAddressLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldIPAddress, v)) } // IPAddressLTE applies the LTE predicate on the "ip_address" field. 
func IPAddressLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldIPAddress, v)) } // IPAddressContains applies the Contains predicate on the "ip_address" field. func IPAddressContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldIPAddress, v)) } // IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field. func IPAddressHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldIPAddress, v)) } // IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field. func IPAddressHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldIPAddress, v)) } // IPAddressIsNil applies the IsNil predicate on the "ip_address" field. func IPAddressIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldIPAddress))) - }) + return predicate.Bouncer(sql.FieldIsNull(FieldIPAddress)) } // IPAddressNotNil applies the NotNil predicate on the "ip_address" field. func IPAddressNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldIPAddress))) - }) + return predicate.Bouncer(sql.FieldNotNull(FieldIPAddress)) } // IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field. func IPAddressEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldIPAddress, v)) } // IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field. func IPAddressContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldIPAddress, v)) } // TypeEQ applies the EQ predicate on the "type" field. func TypeEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldType, v)) } // TypeNEQ applies the NEQ predicate on the "type" field. func TypeNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldType, v)) } // TypeIn applies the In predicate on the "type" field. func TypeIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldType), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldType, vs...)) } // TypeNotIn applies the NotIn predicate on the "type" field. func TypeNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldType), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldType, vs...)) } // TypeGT applies the GT predicate on the "type" field. 
 func TypeGT(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.GT(s.C(FieldType), v))
-	})
+	return predicate.Bouncer(sql.FieldGT(FieldType, v))
 }
 
 // TypeGTE applies the GTE predicate on the "type" field.
 func TypeGTE(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.GTE(s.C(FieldType), v))
-	})
+	return predicate.Bouncer(sql.FieldGTE(FieldType, v))
 }
 
 // TypeLT applies the LT predicate on the "type" field.
 func TypeLT(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.LT(s.C(FieldType), v))
-	})
+	return predicate.Bouncer(sql.FieldLT(FieldType, v))
 }
 
 // TypeLTE applies the LTE predicate on the "type" field.
 func TypeLTE(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.LTE(s.C(FieldType), v))
-	})
+	return predicate.Bouncer(sql.FieldLTE(FieldType, v))
 }
 
 // TypeContains applies the Contains predicate on the "type" field.
 func TypeContains(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.Contains(s.C(FieldType), v))
-	})
+	return predicate.Bouncer(sql.FieldContains(FieldType, v))
 }
 
 // TypeHasPrefix applies the HasPrefix predicate on the "type" field.
 func TypeHasPrefix(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.HasPrefix(s.C(FieldType), v))
-	})
+	return predicate.Bouncer(sql.FieldHasPrefix(FieldType, v))
 }
 
 // TypeHasSuffix applies the HasSuffix predicate on the "type" field.
 func TypeHasSuffix(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.HasSuffix(s.C(FieldType), v))
-	})
+	return predicate.Bouncer(sql.FieldHasSuffix(FieldType, v))
 }
 
 // TypeIsNil applies the IsNil predicate on the "type" field.
 func TypeIsNil() predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.IsNull(s.C(FieldType)))
-	})
+	return predicate.Bouncer(sql.FieldIsNull(FieldType))
 }
 
 // TypeNotNil applies the NotNil predicate on the "type" field.
 func TypeNotNil() predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.NotNull(s.C(FieldType)))
-	})
+	return predicate.Bouncer(sql.FieldNotNull(FieldType))
 }
 
 // TypeEqualFold applies the EqualFold predicate on the "type" field.
 func TypeEqualFold(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.EqualFold(s.C(FieldType), v))
-	})
+	return predicate.Bouncer(sql.FieldEqualFold(FieldType, v))
 }
 
 // TypeContainsFold applies the ContainsFold predicate on the "type" field.
 func TypeContainsFold(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.ContainsFold(s.C(FieldType), v))
-	})
+	return predicate.Bouncer(sql.FieldContainsFold(FieldType, v))
 }
 
 // VersionEQ applies the EQ predicate on the "version" field.
 func VersionEQ(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.EQ(s.C(FieldVersion), v))
-	})
+	return predicate.Bouncer(sql.FieldEQ(FieldVersion, v))
 }
 
 // VersionNEQ applies the NEQ predicate on the "version" field.
 func VersionNEQ(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.NEQ(s.C(FieldVersion), v))
-	})
+	return predicate.Bouncer(sql.FieldNEQ(FieldVersion, v))
 }
 
 // VersionIn applies the In predicate on the "version" field.
 func VersionIn(vs ...string) predicate.Bouncer {
-	v := make([]any, len(vs))
-	for i := range v {
-		v[i] = vs[i]
-	}
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.In(s.C(FieldVersion), v...))
-	})
+	return predicate.Bouncer(sql.FieldIn(FieldVersion, vs...))
 }
 
 // VersionNotIn applies the NotIn predicate on the "version" field.
 func VersionNotIn(vs ...string) predicate.Bouncer {
-	v := make([]any, len(vs))
-	for i := range v {
-		v[i] = vs[i]
-	}
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.NotIn(s.C(FieldVersion), v...))
-	})
+	return predicate.Bouncer(sql.FieldNotIn(FieldVersion, vs...))
 }
 
 // VersionGT applies the GT predicate on the "version" field.
 func VersionGT(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.GT(s.C(FieldVersion), v))
-	})
+	return predicate.Bouncer(sql.FieldGT(FieldVersion, v))
 }
 
 // VersionGTE applies the GTE predicate on the "version" field.
 func VersionGTE(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.GTE(s.C(FieldVersion), v))
-	})
+	return predicate.Bouncer(sql.FieldGTE(FieldVersion, v))
 }
 
 // VersionLT applies the LT predicate on the "version" field.
 func VersionLT(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.LT(s.C(FieldVersion), v))
-	})
+	return predicate.Bouncer(sql.FieldLT(FieldVersion, v))
 }
 
 // VersionLTE applies the LTE predicate on the "version" field.
 func VersionLTE(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.LTE(s.C(FieldVersion), v))
-	})
+	return predicate.Bouncer(sql.FieldLTE(FieldVersion, v))
 }
 
 // VersionContains applies the Contains predicate on the "version" field.
 func VersionContains(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.Contains(s.C(FieldVersion), v))
-	})
+	return predicate.Bouncer(sql.FieldContains(FieldVersion, v))
 }
 
 // VersionHasPrefix applies the HasPrefix predicate on the "version" field.
 func VersionHasPrefix(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.HasPrefix(s.C(FieldVersion), v))
-	})
+	return predicate.Bouncer(sql.FieldHasPrefix(FieldVersion, v))
 }
 
 // VersionHasSuffix applies the HasSuffix predicate on the "version" field.
 func VersionHasSuffix(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.HasSuffix(s.C(FieldVersion), v))
-	})
+	return predicate.Bouncer(sql.FieldHasSuffix(FieldVersion, v))
 }
 
 // VersionIsNil applies the IsNil predicate on the "version" field.
 func VersionIsNil() predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.IsNull(s.C(FieldVersion)))
-	})
+	return predicate.Bouncer(sql.FieldIsNull(FieldVersion))
 }
 
 // VersionNotNil applies the NotNil predicate on the "version" field.
 func VersionNotNil() predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.NotNull(s.C(FieldVersion)))
-	})
+	return predicate.Bouncer(sql.FieldNotNull(FieldVersion))
 }
 
 // VersionEqualFold applies the EqualFold predicate on the "version" field.
 func VersionEqualFold(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.EqualFold(s.C(FieldVersion), v))
-	})
+	return predicate.Bouncer(sql.FieldEqualFold(FieldVersion, v))
 }
 
 // VersionContainsFold applies the ContainsFold predicate on the "version" field.
 func VersionContainsFold(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.ContainsFold(s.C(FieldVersion), v))
-	})
-}
-
-// UntilEQ applies the EQ predicate on the "until" field.
-func UntilEQ(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.EQ(s.C(FieldUntil), v))
-	})
-}
-
-// UntilNEQ applies the NEQ predicate on the "until" field.
-func UntilNEQ(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.NEQ(s.C(FieldUntil), v))
-	})
-}
-
-// UntilIn applies the In predicate on the "until" field.
-func UntilIn(vs ...time.Time) predicate.Bouncer {
-	v := make([]any, len(vs))
-	for i := range v {
-		v[i] = vs[i]
-	}
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.In(s.C(FieldUntil), v...))
-	})
-}
-
-// UntilNotIn applies the NotIn predicate on the "until" field.
-func UntilNotIn(vs ...time.Time) predicate.Bouncer {
-	v := make([]any, len(vs))
-	for i := range v {
-		v[i] = vs[i]
-	}
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.NotIn(s.C(FieldUntil), v...))
-	})
-}
-
-// UntilGT applies the GT predicate on the "until" field.
-func UntilGT(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.GT(s.C(FieldUntil), v))
-	})
-}
-
-// UntilGTE applies the GTE predicate on the "until" field.
-func UntilGTE(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.GTE(s.C(FieldUntil), v))
-	})
-}
-
-// UntilLT applies the LT predicate on the "until" field.
-func UntilLT(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.LT(s.C(FieldUntil), v))
-	})
-}
-
-// UntilLTE applies the LTE predicate on the "until" field.
-func UntilLTE(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.LTE(s.C(FieldUntil), v))
-	})
-}
-
-// UntilIsNil applies the IsNil predicate on the "until" field.
-func UntilIsNil() predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.IsNull(s.C(FieldUntil)))
-	})
-}
-
-// UntilNotNil applies the NotNil predicate on the "until" field.
-func UntilNotNil() predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.NotNull(s.C(FieldUntil)))
-	})
+	return predicate.Bouncer(sql.FieldContainsFold(FieldVersion, v))
 }
 
 // LastPullEQ applies the EQ predicate on the "last_pull" field.
 func LastPullEQ(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.EQ(s.C(FieldLastPull), v))
-	})
+	return predicate.Bouncer(sql.FieldEQ(FieldLastPull, v))
 }
 
 // LastPullNEQ applies the NEQ predicate on the "last_pull" field.
 func LastPullNEQ(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.NEQ(s.C(FieldLastPull), v))
-	})
+	return predicate.Bouncer(sql.FieldNEQ(FieldLastPull, v))
 }
 
 // LastPullIn applies the In predicate on the "last_pull" field.
 func LastPullIn(vs ...time.Time) predicate.Bouncer {
-	v := make([]any, len(vs))
-	for i := range v {
-		v[i] = vs[i]
-	}
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.In(s.C(FieldLastPull), v...))
-	})
+	return predicate.Bouncer(sql.FieldIn(FieldLastPull, vs...))
 }
 
 // LastPullNotIn applies the NotIn predicate on the "last_pull" field.
 func LastPullNotIn(vs ...time.Time) predicate.Bouncer {
-	v := make([]any, len(vs))
-	for i := range v {
-		v[i] = vs[i]
-	}
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.NotIn(s.C(FieldLastPull), v...))
-	})
+	return predicate.Bouncer(sql.FieldNotIn(FieldLastPull, vs...))
 }
 
 // LastPullGT applies the GT predicate on the "last_pull" field.
 func LastPullGT(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.GT(s.C(FieldLastPull), v))
-	})
+	return predicate.Bouncer(sql.FieldGT(FieldLastPull, v))
 }
 
 // LastPullGTE applies the GTE predicate on the "last_pull" field.
 func LastPullGTE(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.GTE(s.C(FieldLastPull), v))
-	})
+	return predicate.Bouncer(sql.FieldGTE(FieldLastPull, v))
 }
 
 // LastPullLT applies the LT predicate on the "last_pull" field.
 func LastPullLT(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.LT(s.C(FieldLastPull), v))
-	})
+	return predicate.Bouncer(sql.FieldLT(FieldLastPull, v))
 }
 
 // LastPullLTE applies the LTE predicate on the "last_pull" field.
 func LastPullLTE(v time.Time) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.LTE(s.C(FieldLastPull), v))
-	})
+	return predicate.Bouncer(sql.FieldLTE(FieldLastPull, v))
+}
+
+// LastPullIsNil applies the IsNil predicate on the "last_pull" field.
+func LastPullIsNil() predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldIsNull(FieldLastPull))
+}
+
+// LastPullNotNil applies the NotNil predicate on the "last_pull" field.
+func LastPullNotNil() predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldNotNull(FieldLastPull))
 }
 
 // AuthTypeEQ applies the EQ predicate on the "auth_type" field.
 func AuthTypeEQ(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.EQ(s.C(FieldAuthType), v))
-	})
+	return predicate.Bouncer(sql.FieldEQ(FieldAuthType, v))
 }
 
 // AuthTypeNEQ applies the NEQ predicate on the "auth_type" field.
 func AuthTypeNEQ(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.NEQ(s.C(FieldAuthType), v))
-	})
+	return predicate.Bouncer(sql.FieldNEQ(FieldAuthType, v))
 }
 
 // AuthTypeIn applies the In predicate on the "auth_type" field.
 func AuthTypeIn(vs ...string) predicate.Bouncer {
-	v := make([]any, len(vs))
-	for i := range v {
-		v[i] = vs[i]
-	}
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.In(s.C(FieldAuthType), v...))
-	})
+	return predicate.Bouncer(sql.FieldIn(FieldAuthType, vs...))
 }
 
 // AuthTypeNotIn applies the NotIn predicate on the "auth_type" field.
 func AuthTypeNotIn(vs ...string) predicate.Bouncer {
-	v := make([]any, len(vs))
-	for i := range v {
-		v[i] = vs[i]
-	}
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.NotIn(s.C(FieldAuthType), v...))
-	})
+	return predicate.Bouncer(sql.FieldNotIn(FieldAuthType, vs...))
 }
 
 // AuthTypeGT applies the GT predicate on the "auth_type" field.
 func AuthTypeGT(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.GT(s.C(FieldAuthType), v))
-	})
+	return predicate.Bouncer(sql.FieldGT(FieldAuthType, v))
 }
 
 // AuthTypeGTE applies the GTE predicate on the "auth_type" field.
 func AuthTypeGTE(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.GTE(s.C(FieldAuthType), v))
-	})
+	return predicate.Bouncer(sql.FieldGTE(FieldAuthType, v))
 }
 
 // AuthTypeLT applies the LT predicate on the "auth_type" field.
 func AuthTypeLT(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.LT(s.C(FieldAuthType), v))
-	})
+	return predicate.Bouncer(sql.FieldLT(FieldAuthType, v))
 }
 
 // AuthTypeLTE applies the LTE predicate on the "auth_type" field.
 func AuthTypeLTE(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.LTE(s.C(FieldAuthType), v))
-	})
+	return predicate.Bouncer(sql.FieldLTE(FieldAuthType, v))
 }
 
 // AuthTypeContains applies the Contains predicate on the "auth_type" field.
 func AuthTypeContains(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.Contains(s.C(FieldAuthType), v))
-	})
+	return predicate.Bouncer(sql.FieldContains(FieldAuthType, v))
 }
 
 // AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field.
 func AuthTypeHasPrefix(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.HasPrefix(s.C(FieldAuthType), v))
-	})
+	return predicate.Bouncer(sql.FieldHasPrefix(FieldAuthType, v))
 }
 
 // AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field.
 func AuthTypeHasSuffix(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.HasSuffix(s.C(FieldAuthType), v))
-	})
+	return predicate.Bouncer(sql.FieldHasSuffix(FieldAuthType, v))
 }
 
 // AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field.
 func AuthTypeEqualFold(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.EqualFold(s.C(FieldAuthType), v))
-	})
+	return predicate.Bouncer(sql.FieldEqualFold(FieldAuthType, v))
 }
 
 // AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field.
 func AuthTypeContainsFold(v string) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s.Where(sql.ContainsFold(s.C(FieldAuthType), v))
-	})
+	return predicate.Bouncer(sql.FieldContainsFold(FieldAuthType, v))
+}
+
+// OsnameEQ applies the EQ predicate on the "osname" field.
+func OsnameEQ(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldEQ(FieldOsname, v))
+}
+
+// OsnameNEQ applies the NEQ predicate on the "osname" field.
+func OsnameNEQ(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldNEQ(FieldOsname, v))
+}
+
+// OsnameIn applies the In predicate on the "osname" field.
+func OsnameIn(vs ...string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldIn(FieldOsname, vs...))
+}
+
+// OsnameNotIn applies the NotIn predicate on the "osname" field.
+func OsnameNotIn(vs ...string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldNotIn(FieldOsname, vs...))
+}
+
+// OsnameGT applies the GT predicate on the "osname" field.
+func OsnameGT(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldGT(FieldOsname, v))
+}
+
+// OsnameGTE applies the GTE predicate on the "osname" field.
+func OsnameGTE(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldGTE(FieldOsname, v))
+}
+
+// OsnameLT applies the LT predicate on the "osname" field.
+func OsnameLT(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldLT(FieldOsname, v))
+}
+
+// OsnameLTE applies the LTE predicate on the "osname" field.
+func OsnameLTE(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldLTE(FieldOsname, v))
+}
+
+// OsnameContains applies the Contains predicate on the "osname" field.
+func OsnameContains(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldContains(FieldOsname, v))
+}
+
+// OsnameHasPrefix applies the HasPrefix predicate on the "osname" field.
+func OsnameHasPrefix(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldHasPrefix(FieldOsname, v))
+}
+
+// OsnameHasSuffix applies the HasSuffix predicate on the "osname" field.
+func OsnameHasSuffix(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldHasSuffix(FieldOsname, v))
+}
+
+// OsnameIsNil applies the IsNil predicate on the "osname" field.
+func OsnameIsNil() predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldIsNull(FieldOsname))
+}
+
+// OsnameNotNil applies the NotNil predicate on the "osname" field.
+func OsnameNotNil() predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldNotNull(FieldOsname))
+}
+
+// OsnameEqualFold applies the EqualFold predicate on the "osname" field.
+func OsnameEqualFold(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldEqualFold(FieldOsname, v))
+}
+
+// OsnameContainsFold applies the ContainsFold predicate on the "osname" field.
+func OsnameContainsFold(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldContainsFold(FieldOsname, v))
+}
+
+// OsversionEQ applies the EQ predicate on the "osversion" field.
+func OsversionEQ(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldEQ(FieldOsversion, v))
+}
+
+// OsversionNEQ applies the NEQ predicate on the "osversion" field.
+func OsversionNEQ(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldNEQ(FieldOsversion, v))
+}
+
+// OsversionIn applies the In predicate on the "osversion" field.
+func OsversionIn(vs ...string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldIn(FieldOsversion, vs...))
+}
+
+// OsversionNotIn applies the NotIn predicate on the "osversion" field.
+func OsversionNotIn(vs ...string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldNotIn(FieldOsversion, vs...))
+}
+
+// OsversionGT applies the GT predicate on the "osversion" field.
+func OsversionGT(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldGT(FieldOsversion, v))
+}
+
+// OsversionGTE applies the GTE predicate on the "osversion" field.
+func OsversionGTE(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldGTE(FieldOsversion, v))
+}
+
+// OsversionLT applies the LT predicate on the "osversion" field.
+func OsversionLT(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldLT(FieldOsversion, v))
+}
+
+// OsversionLTE applies the LTE predicate on the "osversion" field.
+func OsversionLTE(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldLTE(FieldOsversion, v))
+}
+
+// OsversionContains applies the Contains predicate on the "osversion" field.
+func OsversionContains(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldContains(FieldOsversion, v))
+}
+
+// OsversionHasPrefix applies the HasPrefix predicate on the "osversion" field.
+func OsversionHasPrefix(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldHasPrefix(FieldOsversion, v))
+}
+
+// OsversionHasSuffix applies the HasSuffix predicate on the "osversion" field.
+func OsversionHasSuffix(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldHasSuffix(FieldOsversion, v))
+}
+
+// OsversionIsNil applies the IsNil predicate on the "osversion" field.
+func OsversionIsNil() predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldIsNull(FieldOsversion))
+}
+
+// OsversionNotNil applies the NotNil predicate on the "osversion" field.
+func OsversionNotNil() predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldNotNull(FieldOsversion))
+}
+
+// OsversionEqualFold applies the EqualFold predicate on the "osversion" field.
+func OsversionEqualFold(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldEqualFold(FieldOsversion, v))
+}
+
+// OsversionContainsFold applies the ContainsFold predicate on the "osversion" field.
+func OsversionContainsFold(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldContainsFold(FieldOsversion, v))
+}
+
+// FeatureflagsEQ applies the EQ predicate on the "featureflags" field.
+func FeatureflagsEQ(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldEQ(FieldFeatureflags, v))
+}
+
+// FeatureflagsNEQ applies the NEQ predicate on the "featureflags" field.
+func FeatureflagsNEQ(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldNEQ(FieldFeatureflags, v))
+}
+
+// FeatureflagsIn applies the In predicate on the "featureflags" field.
+func FeatureflagsIn(vs ...string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldIn(FieldFeatureflags, vs...))
+}
+
+// FeatureflagsNotIn applies the NotIn predicate on the "featureflags" field.
+func FeatureflagsNotIn(vs ...string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldNotIn(FieldFeatureflags, vs...))
+}
+
+// FeatureflagsGT applies the GT predicate on the "featureflags" field.
+func FeatureflagsGT(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldGT(FieldFeatureflags, v))
+}
+
+// FeatureflagsGTE applies the GTE predicate on the "featureflags" field.
+func FeatureflagsGTE(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldGTE(FieldFeatureflags, v))
+}
+
+// FeatureflagsLT applies the LT predicate on the "featureflags" field.
+func FeatureflagsLT(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldLT(FieldFeatureflags, v))
+}
+
+// FeatureflagsLTE applies the LTE predicate on the "featureflags" field.
+func FeatureflagsLTE(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldLTE(FieldFeatureflags, v))
+}
+
+// FeatureflagsContains applies the Contains predicate on the "featureflags" field.
+func FeatureflagsContains(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldContains(FieldFeatureflags, v))
+}
+
+// FeatureflagsHasPrefix applies the HasPrefix predicate on the "featureflags" field.
+func FeatureflagsHasPrefix(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldHasPrefix(FieldFeatureflags, v))
+}
+
+// FeatureflagsHasSuffix applies the HasSuffix predicate on the "featureflags" field.
+func FeatureflagsHasSuffix(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldHasSuffix(FieldFeatureflags, v))
+}
+
+// FeatureflagsIsNil applies the IsNil predicate on the "featureflags" field.
+func FeatureflagsIsNil() predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldIsNull(FieldFeatureflags))
+}
+
+// FeatureflagsNotNil applies the NotNil predicate on the "featureflags" field.
+func FeatureflagsNotNil() predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldNotNull(FieldFeatureflags))
+}
+
+// FeatureflagsEqualFold applies the EqualFold predicate on the "featureflags" field.
+func FeatureflagsEqualFold(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldEqualFold(FieldFeatureflags, v))
+}
+
+// FeatureflagsContainsFold applies the ContainsFold predicate on the "featureflags" field.
+func FeatureflagsContainsFold(v string) predicate.Bouncer {
+	return predicate.Bouncer(sql.FieldContainsFold(FieldFeatureflags, v))
 }
 
 // And groups predicates with the AND operator between them.
 func And(predicates ...predicate.Bouncer) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s1 := s.Clone().SetP(nil)
-		for _, p := range predicates {
-			p(s1)
-		}
-		s.Where(s1.P())
-	})
+	return predicate.Bouncer(sql.AndPredicates(predicates...))
 }
 
 // Or groups predicates with the OR operator between them.
 func Or(predicates ...predicate.Bouncer) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		s1 := s.Clone().SetP(nil)
-		for i, p := range predicates {
-			if i > 0 {
-				s1.Or()
-			}
-			p(s1)
-		}
-		s.Where(s1.P())
-	})
+	return predicate.Bouncer(sql.OrPredicates(predicates...))
 }
 
 // Not applies the not operator on the given predicate.
 func Not(p predicate.Bouncer) predicate.Bouncer {
-	return predicate.Bouncer(func(s *sql.Selector) {
-		p(s.Not())
-	})
+	return predicate.Bouncer(sql.NotPredicates(p))
 }
diff --git a/pkg/database/ent/bouncer_create.go b/pkg/database/ent/bouncer_create.go
index 685ce089d1e..29b23f87cf1 100644
--- a/pkg/database/ent/bouncer_create.go
+++ b/pkg/database/ent/bouncer_create.go
@@ -108,20 +108,6 @@ func (bc *BouncerCreate) SetNillableVersion(s *string) *BouncerCreate {
 	return bc
 }
 
-// SetUntil sets the "until" field.
-func (bc *BouncerCreate) SetUntil(t time.Time) *BouncerCreate {
-	bc.mutation.SetUntil(t)
-	return bc
-}
-
-// SetNillableUntil sets the "until" field if the given value is not nil.
-func (bc *BouncerCreate) SetNillableUntil(t *time.Time) *BouncerCreate {
-	if t != nil {
-		bc.SetUntil(*t)
-	}
-	return bc
-}
-
 // SetLastPull sets the "last_pull" field.
 func (bc *BouncerCreate) SetLastPull(t time.Time) *BouncerCreate {
 	bc.mutation.SetLastPull(t)
@@ -150,6 +136,48 @@ func (bc *BouncerCreate) SetNillableAuthType(s *string) *BouncerCreate {
 	return bc
 }
 
+// SetOsname sets the "osname" field.
+func (bc *BouncerCreate) SetOsname(s string) *BouncerCreate {
+	bc.mutation.SetOsname(s)
+	return bc
+}
+
+// SetNillableOsname sets the "osname" field if the given value is not nil.
+func (bc *BouncerCreate) SetNillableOsname(s *string) *BouncerCreate {
+	if s != nil {
+		bc.SetOsname(*s)
+	}
+	return bc
+}
+
+// SetOsversion sets the "osversion" field.
+func (bc *BouncerCreate) SetOsversion(s string) *BouncerCreate {
+	bc.mutation.SetOsversion(s)
+	return bc
+}
+
+// SetNillableOsversion sets the "osversion" field if the given value is not nil.
+func (bc *BouncerCreate) SetNillableOsversion(s *string) *BouncerCreate {
+	if s != nil {
+		bc.SetOsversion(*s)
+	}
+	return bc
+}
+
+// SetFeatureflags sets the "featureflags" field.
+func (bc *BouncerCreate) SetFeatureflags(s string) *BouncerCreate {
+	bc.mutation.SetFeatureflags(s)
+	return bc
+}
+
+// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil.
+func (bc *BouncerCreate) SetNillableFeatureflags(s *string) *BouncerCreate {
+	if s != nil {
+		bc.SetFeatureflags(*s)
+	}
+	return bc
+}
+
 // Mutation returns the BouncerMutation object of the builder.
 func (bc *BouncerCreate) Mutation() *BouncerMutation {
 	return bc.mutation
@@ -157,50 +185,8 @@ func (bc *BouncerCreate) Mutation() *BouncerMutation {
 
 // Save creates the Bouncer in the database.
 func (bc *BouncerCreate) Save(ctx context.Context) (*Bouncer, error) {
-	var (
-		err  error
-		node *Bouncer
-	)
 	bc.defaults()
-	if len(bc.hooks) == 0 {
-		if err = bc.check(); err != nil {
-			return nil, err
-		}
-		node, err = bc.sqlSave(ctx)
-	} else {
-		var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
-			mutation, ok := m.(*BouncerMutation)
-			if !ok {
-				return nil, fmt.Errorf("unexpected mutation type %T", m)
-			}
-			if err = bc.check(); err != nil {
-				return nil, err
-			}
-			bc.mutation = mutation
-			if node, err = bc.sqlSave(ctx); err != nil {
-				return nil, err
-			}
-			mutation.id = &node.ID
-			mutation.done = true
-			return node, err
-		})
-		for i := len(bc.hooks) - 1; i >= 0; i-- {
-			if bc.hooks[i] == nil {
-				return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
-			}
-			mut = bc.hooks[i](mut)
-		}
-		v, err := mut.Mutate(ctx, bc.mutation)
-		if err != nil {
-			return nil, err
-		}
-		nv, ok := v.(*Bouncer)
-		if !ok {
-			return nil, fmt.Errorf("unexpected node type %T returned from BouncerMutation", v)
-		}
-		node = nv
-	}
-	return node, err
+	return withHooks(ctx, bc.sqlSave, bc.mutation, bc.hooks)
 }
 
 // SaveX calls Save and panics if Save returns an error.
@@ -239,14 +225,6 @@ func (bc *BouncerCreate) defaults() {
 		v := bouncer.DefaultIPAddress
 		bc.mutation.SetIPAddress(v)
 	}
-	if _, ok := bc.mutation.Until(); !ok {
-		v := bouncer.DefaultUntil()
-		bc.mutation.SetUntil(v)
-	}
-	if _, ok := bc.mutation.LastPull(); !ok {
-		v := bouncer.DefaultLastPull()
-		bc.mutation.SetLastPull(v)
-	}
 	if _, ok := bc.mutation.AuthType(); !ok {
 		v := bouncer.DefaultAuthType
 		bc.mutation.SetAuthType(v)
@@ -255,6 +233,12 @@
 }
 
 // check runs all checks and user-defined validators on the builder.
 func (bc *BouncerCreate) check() error {
+	if _, ok := bc.mutation.CreatedAt(); !ok {
+		return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Bouncer.created_at"`)}
+	}
+	if _, ok := bc.mutation.UpdatedAt(); !ok {
+		return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Bouncer.updated_at"`)}
+	}
 	if _, ok := bc.mutation.Name(); !ok {
 		return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Bouncer.name"`)}
 	}
@@ -264,9 +248,6 @@
 	if _, ok := bc.mutation.Revoked(); !ok {
 		return &ValidationError{Name: "revoked", err: errors.New(`ent: missing required field "Bouncer.revoked"`)}
 	}
-	if _, ok := bc.mutation.LastPull(); !ok {
-		return &ValidationError{Name: "last_pull", err: errors.New(`ent: missing required field "Bouncer.last_pull"`)}
-	}
 	if _, ok := bc.mutation.AuthType(); !ok {
 		return &ValidationError{Name: "auth_type", err: errors.New(`ent: missing required field "Bouncer.auth_type"`)}
@@ -274,6 +255,9 @@
 }
 
 func (bc *BouncerCreate) sqlSave(ctx context.Context) (*Bouncer, error) {
+	if err := bc.check(); err != nil {
+		return nil, err
+	}
 	_node, _spec := bc.createSpec()
 	if err := sqlgraph.CreateNode(ctx, bc.driver, _spec); err != nil {
 		if sqlgraph.IsConstraintError(err) {
@@ -283,119 +267,83 @@
 	}
 	id := _spec.ID.Value.(int64)
 	_node.ID = int(id)
+	bc.mutation.id = &_node.ID
+	bc.mutation.done = true
 	return _node, nil
 }
 
 func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) {
 	var (
 		_node = &Bouncer{config: bc.config}
-		_spec = &sqlgraph.CreateSpec{
-			Table: bouncer.Table,
-			ID: &sqlgraph.FieldSpec{
-				Type:   field.TypeInt,
-				Column: bouncer.FieldID,
-			},
-		}
+		_spec = sqlgraph.NewCreateSpec(bouncer.Table, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt))
 	)
 	if value, ok := bc.mutation.CreatedAt(); ok {
-		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
-			Type:   field.TypeTime,
-			Value:  value,
-			Column: bouncer.FieldCreatedAt,
-		})
-		_node.CreatedAt = &value
+		_spec.SetField(bouncer.FieldCreatedAt, field.TypeTime, value)
+		_node.CreatedAt = value
 	}
 	if value, ok := bc.mutation.UpdatedAt(); ok {
-		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
-			Type:   field.TypeTime,
-			Value:  value,
-			Column: bouncer.FieldUpdatedAt,
-		})
-		_node.UpdatedAt = &value
+		_spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value)
+		_node.UpdatedAt = value
 	}
 	if value, ok := bc.mutation.Name(); ok {
-		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
-			Type:   field.TypeString,
-			Value:  value,
-			Column: bouncer.FieldName,
-		})
+		_spec.SetField(bouncer.FieldName, field.TypeString, value)
 		_node.Name = value
 	}
 	if value, ok := bc.mutation.APIKey(); ok {
-		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
-			Type:   field.TypeString,
-			Value:  value,
-			Column: bouncer.FieldAPIKey,
-		})
+		_spec.SetField(bouncer.FieldAPIKey, field.TypeString, value)
 		_node.APIKey = value
 	}
 	if value, ok := bc.mutation.Revoked(); ok {
-		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
-			Type:   field.TypeBool,
-			Value:  value,
-			Column: bouncer.FieldRevoked,
-		})
+		_spec.SetField(bouncer.FieldRevoked, field.TypeBool, value)
 		_node.Revoked = value
 	}
 	if value, ok := bc.mutation.IPAddress(); ok {
-		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
-			Type:   field.TypeString,
-			Value:  value,
-			Column: bouncer.FieldIPAddress,
-		})
+		_spec.SetField(bouncer.FieldIPAddress, field.TypeString, value)
 		_node.IPAddress = value
 	}
 	if value, ok := bc.mutation.GetType(); ok {
-		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
-			Type:   field.TypeString,
-			Value:  value,
-			Column: bouncer.FieldType,
-		})
+		_spec.SetField(bouncer.FieldType, field.TypeString, value)
 		_node.Type = value
 	}
 	if value, ok := bc.mutation.Version(); ok {
-		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
-			Type:   field.TypeString,
-			Value:  value,
-			Column: bouncer.FieldVersion,
-		})
+		_spec.SetField(bouncer.FieldVersion, field.TypeString, value)
 		_node.Version = value
 	}
-	if value, ok := bc.mutation.Until(); ok {
-		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
-			Type:   field.TypeTime,
-			Value:  value,
-			Column: bouncer.FieldUntil,
-		})
-		_node.Until = value
-	}
 	if value, ok := bc.mutation.LastPull(); ok {
-		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
-			Type:   field.TypeTime,
-			Value:  value,
-			Column: bouncer.FieldLastPull,
-		})
-		_node.LastPull = value
+		_spec.SetField(bouncer.FieldLastPull, field.TypeTime, value)
+		_node.LastPull = &value
 	}
 	if value, ok := bc.mutation.AuthType(); ok {
-		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
-			Type:   field.TypeString,
-			Value:  value,
-			Column: bouncer.FieldAuthType,
-		})
+		_spec.SetField(bouncer.FieldAuthType, field.TypeString, value)
 		_node.AuthType = value
 	}
+	if value, ok := bc.mutation.Osname(); ok {
+		_spec.SetField(bouncer.FieldOsname, field.TypeString, value)
+		_node.Osname = value
+	}
+	if value, ok := bc.mutation.Osversion(); ok {
+		_spec.SetField(bouncer.FieldOsversion, field.TypeString, value)
+		_node.Osversion = value
+	}
+	if value, ok := bc.mutation.Featureflags(); ok {
+		_spec.SetField(bouncer.FieldFeatureflags, field.TypeString, value)
+		_node.Featureflags = value
+	}
 	return _node, _spec
 }
 
 // BouncerCreateBulk is the builder for creating many Bouncer entities in bulk.
 type BouncerCreateBulk struct {
 	config
+	err      error
 	builders []*BouncerCreate
 }
 
 // Save creates the Bouncer entities in the database.
 func (bcb *BouncerCreateBulk) Save(ctx context.Context) ([]*Bouncer, error) {
+	if bcb.err != nil {
+		return nil, bcb.err
+	}
 	specs := make([]*sqlgraph.CreateSpec, len(bcb.builders))
 	nodes := make([]*Bouncer, len(bcb.builders))
 	mutators := make([]Mutator, len(bcb.builders))
@@ -412,8 +360,8 @@ func (bcb *BouncerCreateBulk) Save(ctx context.Context) ([]*Bouncer, error) {
 				return nil, err
 			}
 			builder.mutation = mutation
-			nodes[i], specs[i] = builder.createSpec()
 			var err error
+			nodes[i], specs[i] = builder.createSpec()
 			if i < len(mutators)-1 {
 				_, err = mutators[i+1].Mutate(root, bcb.builders[i+1].mutation)
 			} else {
diff --git a/pkg/database/ent/bouncer_delete.go b/pkg/database/ent/bouncer_delete.go
index 6bfb9459190..bf459e77e28 100644
--- a/pkg/database/ent/bouncer_delete.go
+++ b/pkg/database/ent/bouncer_delete.go
@@ -4,7 +4,6 @@ package ent
 
 import (
 	"context"
-	"fmt"
 
 	"entgo.io/ent/dialect/sql"
 	"entgo.io/ent/dialect/sql/sqlgraph"
@@ -28,34 +27,7 @@ func (bd *BouncerDelete) Where(ps ...predicate.Bouncer) *BouncerDelete {
 
 // Exec executes the deletion query and returns how many vertices were deleted.
 func (bd *BouncerDelete) Exec(ctx context.Context) (int, error) {
-	var (
-		err      error
-		affected int
-	)
-	if len(bd.hooks) == 0 {
-		affected, err = bd.sqlExec(ctx)
-	} else {
-		var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
-			mutation, ok := m.(*BouncerMutation)
-			if !ok {
-				return nil, fmt.Errorf("unexpected mutation type %T", m)
-			}
-			bd.mutation = mutation
-			affected, err = bd.sqlExec(ctx)
-			mutation.done = true
-			return affected, err
-		})
-		for i := len(bd.hooks) - 1; i >= 0; i-- {
-			if bd.hooks[i] == nil {
-				return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
-			}
-			mut = bd.hooks[i](mut)
-		}
-		if _, err := mut.Mutate(ctx, bd.mutation); err != nil {
-			return 0, err
-		}
-	}
-	return affected, err
+	return withHooks(ctx, bd.sqlExec, bd.mutation, bd.hooks)
 }
 
 // ExecX is like Exec, but panics if an error occurs.
@@ -68,15 +40,7 @@ func (bd *BouncerDelete) ExecX(ctx context.Context) int {
 }
 
 func (bd *BouncerDelete) sqlExec(ctx context.Context) (int, error) {
-	_spec := &sqlgraph.DeleteSpec{
-		Node: &sqlgraph.NodeSpec{
-			Table: bouncer.Table,
-			ID: &sqlgraph.FieldSpec{
-				Type:   field.TypeInt,
-				Column: bouncer.FieldID,
-			},
-		},
-	}
+	_spec := sqlgraph.NewDeleteSpec(bouncer.Table, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt))
 	if ps := bd.mutation.predicates; len(ps) > 0 {
 		_spec.Predicate = func(selector *sql.Selector) {
 			for i := range ps {
@@ -88,6 +52,7 @@ func (bd *BouncerDelete) sqlExec(ctx context.Context) (int, error) {
 	if err != nil && sqlgraph.IsConstraintError(err) {
 		err = &ConstraintError{msg: err.Error(), wrap: err}
 	}
+	bd.mutation.done = true
 	return affected, err
 }
 
@@ -96,6 +61,12 @@ type BouncerDeleteOne struct {
 	bd *BouncerDelete
 }
 
+// Where appends a list predicates to the BouncerDelete builder.
+func (bdo *BouncerDeleteOne) Where(ps ...predicate.Bouncer) *BouncerDeleteOne {
+	bdo.bd.mutation.Where(ps...)
+	return bdo
+}
+
 // Exec executes the deletion query.
 func (bdo *BouncerDeleteOne) Exec(ctx context.Context) error {
 	n, err := bdo.bd.Exec(ctx)
@@ -111,5 +82,7 @@ func (bdo *BouncerDeleteOne) Exec(ctx context.Context) error {
 
 // ExecX is like Exec, but panics if an error occurs.
 func (bdo *BouncerDeleteOne) ExecX(ctx context.Context) {
-	bdo.bd.ExecX(ctx)
+	if err := bdo.Exec(ctx); err != nil {
+		panic(err)
+	}
 }
diff --git a/pkg/database/ent/bouncer_query.go b/pkg/database/ent/bouncer_query.go
index 2747a3e0b3a..ea2b7495733 100644
--- a/pkg/database/ent/bouncer_query.go
+++ b/pkg/database/ent/bouncer_query.go
@@ -17,11 +17,9 @@ import (
 // BouncerQuery is the builder for querying Bouncer entities.
 type BouncerQuery struct {
 	config
-	limit      *int
-	offset     *int
-	unique     *bool
-	order      []OrderFunc
-	fields     []string
+	ctx        *QueryContext
+	order      []bouncer.OrderOption
+	inters     []Interceptor
 	predicates []predicate.Bouncer
 	// intermediate query (i.e. traversal path).
 	sql  *sql.Selector
@@ -34,27 +32,27 @@ func (bq *BouncerQuery) Where(ps ...predicate.Bouncer) *BouncerQuery {
 	return bq
 }
 
-// Limit adds a limit step to the query.
+// Limit the number of records to be returned by this query.
 func (bq *BouncerQuery) Limit(limit int) *BouncerQuery {
-	bq.limit = &limit
+	bq.ctx.Limit = &limit
 	return bq
 }
 
-// Offset adds an offset step to the query.
+// Offset to start from.
 func (bq *BouncerQuery) Offset(offset int) *BouncerQuery {
-	bq.offset = &offset
+	bq.ctx.Offset = &offset
 	return bq
 }
 
 // Unique configures the query builder to filter duplicate records on query.
 // By default, unique is set to true, and can be disabled using this method.
 func (bq *BouncerQuery) Unique(unique bool) *BouncerQuery {
-	bq.unique = &unique
+	bq.ctx.Unique = &unique
 	return bq
 }
 
-// Order adds an order step to the query.
-func (bq *BouncerQuery) Order(o ...OrderFunc) *BouncerQuery {
+// Order specifies how the records should be ordered.
+func (bq *BouncerQuery) Order(o ...bouncer.OrderOption) *BouncerQuery {
 	bq.order = append(bq.order, o...)
 	return bq
 }
@@ -62,7 +60,7 @@ func (bq *BouncerQuery) Order(o ...OrderFunc) *BouncerQuery {
 // First returns the first Bouncer entity from the query.
 // Returns a *NotFoundError when no Bouncer was found.
 func (bq *BouncerQuery) First(ctx context.Context) (*Bouncer, error) {
-	nodes, err := bq.Limit(1).All(ctx)
+	nodes, err := bq.Limit(1).All(setContextOp(ctx, bq.ctx, "First"))
 	if err != nil {
 		return nil, err
 	}
@@ -85,7 +83,7 @@ func (bq *BouncerQuery) FirstX(ctx context.Context) *Bouncer {
 // Returns a *NotFoundError when no Bouncer ID was found.
 func (bq *BouncerQuery) FirstID(ctx context.Context) (id int, err error) {
 	var ids []int
-	if ids, err = bq.Limit(1).IDs(ctx); err != nil {
+	if ids, err = bq.Limit(1).IDs(setContextOp(ctx, bq.ctx, "FirstID")); err != nil {
 		return
 	}
 	if len(ids) == 0 {
@@ -108,7 +106,7 @@
 // Returns a *NotSingularError when more than one Bouncer entity is found.
 // Returns a *NotFoundError when no Bouncer entities are found.
 func (bq *BouncerQuery) Only(ctx context.Context) (*Bouncer, error) {
-	nodes, err := bq.Limit(2).All(ctx)
+	nodes, err := bq.Limit(2).All(setContextOp(ctx, bq.ctx, "Only"))
 	if err != nil {
 		return nil, err
 	}
@@ -136,7 +134,7 @@
 // Returns a *NotFoundError when no entities are found.
 func (bq *BouncerQuery) OnlyID(ctx context.Context) (id int, err error) {
 	var ids []int
-	if ids, err = bq.Limit(2).IDs(ctx); err != nil {
+	if ids, err = bq.Limit(2).IDs(setContextOp(ctx, bq.ctx, "OnlyID")); err != nil {
 		return
 	}
 	switch len(ids) {
@@ -161,10 +159,12 @@ func (bq *BouncerQuery) OnlyIDX(ctx context.Context) int {
 
 // All executes the query and returns a list of Bouncers.
 func (bq *BouncerQuery) All(ctx context.Context) ([]*Bouncer, error) {
+	ctx = setContextOp(ctx, bq.ctx, "All")
 	if err := bq.prepareQuery(ctx); err != nil {
 		return nil, err
 	}
-	return bq.sqlAll(ctx)
+	qr := querierAll[[]*Bouncer, *BouncerQuery]()
+	return withInterceptors[[]*Bouncer](ctx, bq, qr, bq.inters)
 }
 
 // AllX is like All, but panics if an error occurs.
@@ -177,9 +177,12 @@ func (bq *BouncerQuery) AllX(ctx context.Context) []*Bouncer {
 }
 
 // IDs executes the query and returns a list of Bouncer IDs.
-func (bq *BouncerQuery) IDs(ctx context.Context) ([]int, error) {
-	var ids []int
-	if err := bq.Select(bouncer.FieldID).Scan(ctx, &ids); err != nil {
+func (bq *BouncerQuery) IDs(ctx context.Context) (ids []int, err error) {
+	if bq.ctx.Unique == nil && bq.path != nil {
+		bq.Unique(true)
+	}
+	ctx = setContextOp(ctx, bq.ctx, "IDs")
+	if err = bq.Select(bouncer.FieldID).Scan(ctx, &ids); err != nil {
 		return nil, err
 	}
 	return ids, nil
@@ -196,10 +199,11 @@ func (bq *BouncerQuery) IDsX(ctx context.Context) []int {
 
 // Count returns the count of the given query.
 func (bq *BouncerQuery) Count(ctx context.Context) (int, error) {
+	ctx = setContextOp(ctx, bq.ctx, "Count")
 	if err := bq.prepareQuery(ctx); err != nil {
 		return 0, err
 	}
-	return bq.sqlCount(ctx)
+	return withInterceptors[int](ctx, bq, querierCount[*BouncerQuery](), bq.inters)
 }
 
 // CountX is like Count, but panics if an error occurs.
@@ -213,10 +217,15 @@ func (bq *BouncerQuery) CountX(ctx context.Context) int {
 
 // Exist returns true if the query has elements in the graph.
 func (bq *BouncerQuery) Exist(ctx context.Context) (bool, error) {
-	if err := bq.prepareQuery(ctx); err != nil {
-		return false, err
+	ctx = setContextOp(ctx, bq.ctx, "Exist")
+	switch _, err := bq.FirstID(ctx); {
+	case IsNotFound(err):
+		return false, nil
+	case err != nil:
+		return false, fmt.Errorf("ent: check existence: %w", err)
+	default:
+		return true, nil
 	}
-	return bq.sqlExist(ctx)
 }
 
 // ExistX is like Exist, but panics if an error occurs.
@@ -236,14 +245,13 @@ func (bq *BouncerQuery) Clone() *BouncerQuery {
 	}
 	return &BouncerQuery{
 		config:     bq.config,
-		limit:      bq.limit,
-		offset:     bq.offset,
-		order:      append([]OrderFunc{}, bq.order...),
+		ctx:        bq.ctx.Clone(),
+		order:      append([]bouncer.OrderOption{}, bq.order...),
+		inters:     append([]Interceptor{}, bq.inters...),
 		predicates: append([]predicate.Bouncer{}, bq.predicates...),
 		// clone intermediate query.
-		sql:    bq.sql.Clone(),
-		path:   bq.path,
-		unique: bq.unique,
+		sql:  bq.sql.Clone(),
+		path: bq.path,
 	}
 }
 
@@ -262,16 +270,11 @@ func (bq *BouncerQuery) Clone() *BouncerQuery {
 //		Aggregate(ent.Count()).
 //		Scan(ctx, &v)
 func (bq *BouncerQuery) GroupBy(field string, fields ...string) *BouncerGroupBy {
-	grbuild := &BouncerGroupBy{config: bq.config}
-	grbuild.fields = append([]string{field}, fields...)
-	grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) {
-		if err := bq.prepareQuery(ctx); err != nil {
-			return nil, err
-		}
-		return bq.sqlQuery(ctx), nil
-	}
+	bq.ctx.Fields = append([]string{field}, fields...)
+	grbuild := &BouncerGroupBy{build: bq}
+	grbuild.flds = &bq.ctx.Fields
 	grbuild.label = bouncer.Label
-	grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan
+	grbuild.scan = grbuild.Scan
 	return grbuild
 }
 
@@ -288,15 +291,30 @@ func (bq *BouncerQuery) GroupBy(field string, fields ...string) *BouncerGroupBy
 //		Select(bouncer.FieldCreatedAt).
 //		Scan(ctx, &v)
 func (bq *BouncerQuery) Select(fields ...string) *BouncerSelect {
-	bq.fields = append(bq.fields, fields...)
-	selbuild := &BouncerSelect{BouncerQuery: bq}
-	selbuild.label = bouncer.Label
-	selbuild.flds, selbuild.scan = &bq.fields, selbuild.Scan
-	return selbuild
+	bq.ctx.Fields = append(bq.ctx.Fields, fields...)
+	sbuild := &BouncerSelect{BouncerQuery: bq}
+	sbuild.label = bouncer.Label
+	sbuild.flds, sbuild.scan = &bq.ctx.Fields, sbuild.Scan
+	return sbuild
+}
+
+// Aggregate returns a BouncerSelect configured with the given aggregations.
+func (bq *BouncerQuery) Aggregate(fns ...AggregateFunc) *BouncerSelect {
+	return bq.Select().Aggregate(fns...)
 }
 
 func (bq *BouncerQuery) prepareQuery(ctx context.Context) error {
-	for _, f := range bq.fields {
+	for _, inter := range bq.inters {
+		if inter == nil {
+			return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+		}
+		if trv, ok := inter.(Traverser); ok {
+			if err := trv.Traverse(ctx, bq); err != nil {
+				return err
+			}
+		}
+	}
+	for _, f := range bq.ctx.Fields {
 		if !bouncer.ValidColumn(f) {
 			return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
 		}
@@ -338,41 +356,22 @@ func (bq *BouncerQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Boun
 
 func (bq *BouncerQuery) sqlCount(ctx context.Context) (int, error) {
 	_spec := bq.querySpec()
-	_spec.Node.Columns = bq.fields
-	if len(bq.fields) > 0 {
-		_spec.Unique = bq.unique != nil && *bq.unique
+	_spec.Node.Columns = bq.ctx.Fields
+	if len(bq.ctx.Fields) > 0 {
+		_spec.Unique = bq.ctx.Unique != nil && *bq.ctx.Unique
 	}
 	return sqlgraph.CountNodes(ctx, bq.driver, _spec)
 }
 
-func (bq *BouncerQuery) sqlExist(ctx context.Context) (bool, error) {
-	switch _, err := bq.FirstID(ctx); {
-	case IsNotFound(err):
-		return false, nil
-	case err != nil:
-		return false, fmt.Errorf("ent: check existence: %w", err)
-	default:
-		return true, nil
-	}
-}
-
 func (bq *BouncerQuery) querySpec() *sqlgraph.QuerySpec {
-	_spec := &sqlgraph.QuerySpec{
-		Node: &sqlgraph.NodeSpec{
-			Table:   bouncer.Table,
-			Columns: bouncer.Columns,
-			ID: &sqlgraph.FieldSpec{
-				Type:   field.TypeInt,
-				Column: bouncer.FieldID,
-			},
-		},
-		From:   bq.sql,
-		Unique: true,
-	}
-	if unique := bq.unique; unique != nil {
+	_spec := sqlgraph.NewQuerySpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt))
+	_spec.From = bq.sql
+	if unique := bq.ctx.Unique; unique != nil {
 		_spec.Unique = *unique
+	} else if bq.path != nil {
+		_spec.Unique = true
 	}
-	if fields := bq.fields; len(fields) > 0 {
+	if fields := bq.ctx.Fields; len(fields) > 0 {
 		_spec.Node.Columns = make([]string, 0, len(fields))
 		_spec.Node.Columns = append(_spec.Node.Columns, bouncer.FieldID)
 		for i := range fields {
@@ -388,10 +387,10 @@
 			}
 		}
 	}
-	if limit := bq.limit; limit != nil {
+	if limit := bq.ctx.Limit; limit != nil {
 		_spec.Limit = *limit
 	}
-	if offset := bq.offset; offset != nil {
+	if offset := bq.ctx.Offset; offset != nil {
 		_spec.Offset = *offset
 	}
 	if ps := bq.order; len(ps) > 0 {
@@ -407,7 +406,7 @@
 func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector {
 	builder := sql.Dialect(bq.driver.Dialect())
 	t1 := builder.Table(bouncer.Table)
-	columns := bq.fields
+	columns := bq.ctx.Fields
 	if len(columns) == 0 {
 		columns = bouncer.Columns
 	}
@@ -416,7 +415,7 @@
 		selector = bq.sql
 		selector.Select(selector.Columns(columns...)...)
 	}
-	if bq.unique != nil && *bq.unique {
+	if bq.ctx.Unique != nil && *bq.ctx.Unique {
 		selector.Distinct()
 	}
 	for _, p := range bq.predicates {
@@ -425,12 +424,12 @@
 	for _, p := range bq.order {
 		p(selector)
 	}
-	if offset := bq.offset; offset != nil {
+	if offset := bq.ctx.Offset; offset != nil {
 		// limit is mandatory for offset clause. We start
 		// with default value, and override it below if needed.
 		selector.Offset(*offset).Limit(math.MaxInt32)
 	}
-	if limit := bq.limit; limit != nil {
+	if limit := bq.ctx.Limit; limit != nil {
 		selector.Limit(*limit)
 	}
 	return selector
@@ -438,13 +437,8 @@
 
 // BouncerGroupBy is the group-by builder for Bouncer entities.
 type BouncerGroupBy struct {
-	config
 	selector
-	fields []string
-	fns    []AggregateFunc
-	// intermediate query (i.e. traversal path).
-	sql  *sql.Selector
-	path func(context.Context) (*sql.Selector, error)
+	build *BouncerQuery
 }
 
 // Aggregate adds the given aggregation functions to the group-by query.
@@ -453,74 +447,77 @@ func (bgb *BouncerGroupBy) Aggregate(fns ...AggregateFunc) *BouncerGroupBy {
 	return bgb
 }
 
-// Scan applies the group-by query and scans the result into the given value.
+// Scan applies the selector query and scans the result into the given value.
 func (bgb *BouncerGroupBy) Scan(ctx context.Context, v any) error {
-	query, err := bgb.path(ctx)
-	if err != nil {
+	ctx = setContextOp(ctx, bgb.build.ctx, "GroupBy")
+	if err := bgb.build.prepareQuery(ctx); err != nil {
 		return err
 	}
-	bgb.sql = query
-	return bgb.sqlScan(ctx, v)
+	return scanWithInterceptors[*BouncerQuery, *BouncerGroupBy](ctx, bgb.build, bgb, bgb.build.inters, v)
 }
 
-func (bgb *BouncerGroupBy) sqlScan(ctx context.Context, v any) error {
-	for _, f := range bgb.fields {
-		if !bouncer.ValidColumn(f) {
-			return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)}
+func (bgb *BouncerGroupBy) sqlScan(ctx context.Context, root *BouncerQuery, v any) error {
+	selector := root.sqlQuery(ctx).Select()
+	aggregation := make([]string, 0, len(bgb.fns))
+	for _, fn := range bgb.fns {
+		aggregation = append(aggregation, fn(selector))
+	}
+	if len(selector.SelectedColumns()) == 0 {
+		columns := make([]string, 0, len(*bgb.flds)+len(bgb.fns))
+		for _, f := range *bgb.flds {
+			columns = append(columns, selector.C(f))
 		}
+		columns = append(columns, aggregation...)
+		selector.Select(columns...)
 	}
-	selector := bgb.sqlQuery()
+	selector.GroupBy(selector.Columns(*bgb.flds...)...)
 	if err := selector.Err(); err != nil {
 		return err
 	}
 	rows := &sql.Rows{}
 	query, args := selector.Query()
-	if err := bgb.driver.Query(ctx, query, args, rows); err != nil {
+	if err := bgb.build.driver.Query(ctx, query, args, rows); err != nil {
 		return err
 	}
 	defer rows.Close()
 	return sql.ScanSlice(rows, v)
 }
 
-func (bgb *BouncerGroupBy) sqlQuery() *sql.Selector {
-	selector := bgb.sql.Select()
-	aggregation := make([]string, 0, len(bgb.fns))
-	for _, fn := range bgb.fns {
-		aggregation = append(aggregation, fn(selector))
-	}
-	// If no columns were selected in a custom aggregation function, the default
-	// selection is the fields used for "group-by", and the aggregation functions.
-	if len(selector.SelectedColumns()) == 0 {
-		columns := make([]string, 0, len(bgb.fields)+len(bgb.fns))
-		for _, f := range bgb.fields {
-			columns = append(columns, selector.C(f))
-		}
-		columns = append(columns, aggregation...)
-		selector.Select(columns...)
-	}
-	return selector.GroupBy(selector.Columns(bgb.fields...)...)
-}
-
 // BouncerSelect is the builder for selecting fields of Bouncer entities.
 type BouncerSelect struct {
 	*BouncerQuery
 	selector
-	// intermediate query (i.e. traversal path).
-	sql *sql.Selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (bs *BouncerSelect) Aggregate(fns ...AggregateFunc) *BouncerSelect {
+	bs.fns = append(bs.fns, fns...)
+	return bs
 }
 
 // Scan applies the selector query and scans the result into the given value.
 func (bs *BouncerSelect) Scan(ctx context.Context, v any) error {
+	ctx = setContextOp(ctx, bs.ctx, "Select")
 	if err := bs.prepareQuery(ctx); err != nil {
 		return err
 	}
-	bs.sql = bs.BouncerQuery.sqlQuery(ctx)
-	return bs.sqlScan(ctx, v)
+	return scanWithInterceptors[*BouncerQuery, *BouncerSelect](ctx, bs.BouncerQuery, bs, bs.inters, v)
 }
 
-func (bs *BouncerSelect) sqlScan(ctx context.Context, v any) error {
+func (bs *BouncerSelect) sqlScan(ctx context.Context, root *BouncerQuery, v any) error {
+	selector := root.sqlQuery(ctx)
+	aggregation := make([]string, 0, len(bs.fns))
+	for _, fn := range bs.fns {
+		aggregation = append(aggregation, fn(selector))
+	}
+	switch n := len(*bs.selector.flds); {
+	case n == 0 && len(aggregation) > 0:
+		selector.Select(aggregation...)
+	case n != 0 && len(aggregation) > 0:
+		selector.AppendSelect(aggregation...)
+	}
 	rows := &sql.Rows{}
-	query, args := bs.sql.Query()
+	query, args := selector.Query()
 	if err := bs.driver.Query(ctx, query, args, rows); err != nil {
 		return err
 	}
diff --git a/pkg/database/ent/bouncer_update.go b/pkg/database/ent/bouncer_update.go
index acf48dedeec..620b006a49a 100644
--- a/pkg/database/ent/bouncer_update.go
+++ b/pkg/database/ent/bouncer_update.go
@@ -28,48 +28,40 @@ func (bu *BouncerUpdate) Where(ps ...predicate.Bouncer) *BouncerUpdate {
 	return bu
 }
 
-// SetCreatedAt sets the "created_at" field.
-func (bu *BouncerUpdate) SetCreatedAt(t time.Time) *BouncerUpdate {
-	bu.mutation.SetCreatedAt(t)
-	return bu
-}
-
-// ClearCreatedAt clears the value of the "created_at" field.
-func (bu *BouncerUpdate) ClearCreatedAt() *BouncerUpdate {
-	bu.mutation.ClearCreatedAt()
-	return bu
-}
-
 // SetUpdatedAt sets the "updated_at" field.
 func (bu *BouncerUpdate) SetUpdatedAt(t time.Time) *BouncerUpdate {
 	bu.mutation.SetUpdatedAt(t)
 	return bu
 }
 
-// ClearUpdatedAt clears the value of the "updated_at" field.
-func (bu *BouncerUpdate) ClearUpdatedAt() *BouncerUpdate {
-	bu.mutation.ClearUpdatedAt()
-	return bu
-}
-
-// SetName sets the "name" field.
-func (bu *BouncerUpdate) SetName(s string) *BouncerUpdate {
-	bu.mutation.SetName(s)
-	return bu
-}
-
 // SetAPIKey sets the "api_key" field.
 func (bu *BouncerUpdate) SetAPIKey(s string) *BouncerUpdate {
 	bu.mutation.SetAPIKey(s)
 	return bu
 }
 
+// SetNillableAPIKey sets the "api_key" field if the given value is not nil.
+func (bu *BouncerUpdate) SetNillableAPIKey(s *string) *BouncerUpdate {
+	if s != nil {
+		bu.SetAPIKey(*s)
+	}
+	return bu
+}
+
 // SetRevoked sets the "revoked" field.
 func (bu *BouncerUpdate) SetRevoked(b bool) *BouncerUpdate {
 	bu.mutation.SetRevoked(b)
 	return bu
 }
 
+// SetNillableRevoked sets the "revoked" field if the given value is not nil.
+func (bu *BouncerUpdate) SetNillableRevoked(b *bool) *BouncerUpdate {
+	if b != nil {
+		bu.SetRevoked(*b)
+	}
+	return bu
+}
+
 // SetIPAddress sets the "ip_address" field.
 func (bu *BouncerUpdate) SetIPAddress(s string) *BouncerUpdate {
 	bu.mutation.SetIPAddress(s)
@@ -130,26 +122,6 @@ func (bu *BouncerUpdate) ClearVersion() *BouncerUpdate {
 	return bu
 }
 
-// SetUntil sets the "until" field.
-func (bu *BouncerUpdate) SetUntil(t time.Time) *BouncerUpdate {
-	bu.mutation.SetUntil(t)
-	return bu
-}
-
-// SetNillableUntil sets the "until" field if the given value is not nil.
-func (bu *BouncerUpdate) SetNillableUntil(t *time.Time) *BouncerUpdate {
-	if t != nil {
-		bu.SetUntil(*t)
-	}
-	return bu
-}
-
-// ClearUntil clears the value of the "until" field.
-func (bu *BouncerUpdate) ClearUntil() *BouncerUpdate {
-	bu.mutation.ClearUntil()
-	return bu
-}
-
 // SetLastPull sets the "last_pull" field.
 func (bu *BouncerUpdate) SetLastPull(t time.Time) *BouncerUpdate {
 	bu.mutation.SetLastPull(t)
@@ -164,6 +136,12 @@ func (bu *BouncerUpdate) SetNillableLastPull(t *time.Time) *BouncerUpdate {
 	return bu
 }
 
+// ClearLastPull clears the value of the "last_pull" field.
+func (bu *BouncerUpdate) ClearLastPull() *BouncerUpdate {
+	bu.mutation.ClearLastPull()
+	return bu
+}
+
 // SetAuthType sets the "auth_type" field.
 func (bu *BouncerUpdate) SetAuthType(s string) *BouncerUpdate {
 	bu.mutation.SetAuthType(s)
@@ -178,6 +156,66 @@ func (bu *BouncerUpdate) SetNillableAuthType(s *string) *BouncerUpdate {
 	return bu
 }
 
+// SetOsname sets the "osname" field.
+func (bu *BouncerUpdate) SetOsname(s string) *BouncerUpdate {
+	bu.mutation.SetOsname(s)
+	return bu
+}
+
+// SetNillableOsname sets the "osname" field if the given value is not nil.
+func (bu *BouncerUpdate) SetNillableOsname(s *string) *BouncerUpdate {
+	if s != nil {
+		bu.SetOsname(*s)
+	}
+	return bu
+}
+
+// ClearOsname clears the value of the "osname" field.
+func (bu *BouncerUpdate) ClearOsname() *BouncerUpdate {
+	bu.mutation.ClearOsname()
+	return bu
+}
+
+// SetOsversion sets the "osversion" field.
+func (bu *BouncerUpdate) SetOsversion(s string) *BouncerUpdate {
+	bu.mutation.SetOsversion(s)
+	return bu
+}
+
+// SetNillableOsversion sets the "osversion" field if the given value is not nil.
+func (bu *BouncerUpdate) SetNillableOsversion(s *string) *BouncerUpdate {
+	if s != nil {
+		bu.SetOsversion(*s)
+	}
+	return bu
+}
+
+// ClearOsversion clears the value of the "osversion" field.
+func (bu *BouncerUpdate) ClearOsversion() *BouncerUpdate {
+	bu.mutation.ClearOsversion()
+	return bu
+}
+
+// SetFeatureflags sets the "featureflags" field.
+func (bu *BouncerUpdate) SetFeatureflags(s string) *BouncerUpdate {
+	bu.mutation.SetFeatureflags(s)
+	return bu
+}
+
+// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil.
+func (bu *BouncerUpdate) SetNillableFeatureflags(s *string) *BouncerUpdate {
+	if s != nil {
+		bu.SetFeatureflags(*s)
+	}
+	return bu
+}
+
+// ClearFeatureflags clears the value of the "featureflags" field.
+func (bu *BouncerUpdate) ClearFeatureflags() *BouncerUpdate {
+	bu.mutation.ClearFeatureflags()
+	return bu
+}
+
 // Mutation returns the BouncerMutation object of the builder.
 func (bu *BouncerUpdate) Mutation() *BouncerMutation {
 	return bu.mutation
@@ -185,35 +223,8 @@ func (bu *BouncerUpdate) Mutation() *BouncerMutation {
 
 // Save executes the query and returns the number of nodes affected by the update operation.
func (bu *BouncerUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) bu.defaults() - if len(bu.hooks) == 0 { - affected, err = bu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - bu.mutation = mutation - affected, err = bu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(bu.hooks) - 1; i >= 0; i-- { - if bu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = bu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, bu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, bu.sqlSave, bu.mutation, bu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -240,27 +251,14 @@ func (bu *BouncerUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (bu *BouncerUpdate) defaults() { - if _, ok := bu.mutation.CreatedAt(); !ok && !bu.mutation.CreatedAtCleared() { - v := bouncer.UpdateDefaultCreatedAt() - bu.mutation.SetCreatedAt(v) - } - if _, ok := bu.mutation.UpdatedAt(); !ok && !bu.mutation.UpdatedAtCleared() { + if _, ok := bu.mutation.UpdatedAt(); !ok { v := bouncer.UpdateDefaultUpdatedAt() bu.mutation.SetUpdatedAt(v) } } func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: bouncer.Table, - Columns: bouncer.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) if ps := bu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -268,118 +266,59 @@ func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := bu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldCreatedAt, - }) - } - if bu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldCreatedAt, - }) - } if value, ok := bu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUpdatedAt, - }) - } - if bu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUpdatedAt, - }) - } - if value, ok := bu.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldName, - }) + _spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value) } if value, ok := bu.mutation.APIKey(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAPIKey, - }) + _spec.SetField(bouncer.FieldAPIKey, field.TypeString, value) } if value, ok := bu.mutation.Revoked(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: bouncer.FieldRevoked, - }) + _spec.SetField(bouncer.FieldRevoked, field.TypeBool, value) } 
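Save collapses the hand-rolled mutator loop into the shared withHooks helper; hooks registered through Use still wrap sqlSave in the same order as before. A sketch of registering such a hook, assuming the generated hook package that ships alongside this code:

package main

import (
	"context"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/hook"
)

// auditBouncerMutations logs every bouncer mutation before withHooks
// executes it; illustrative only, not part of this PR.
func auditBouncerMutations(client *ent.Client) {
	client.Bouncer.Use(func(next ent.Mutator) ent.Mutator {
		return hook.BouncerFunc(func(ctx context.Context, m *ent.BouncerMutation) (ent.Value, error) {
			log.Printf("bouncer mutation: op=%v", m.Op())
			return next.Mutate(ctx, m)
		})
	})
}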
if value, ok := bu.mutation.IPAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldIPAddress, - }) + _spec.SetField(bouncer.FieldIPAddress, field.TypeString, value) } if bu.mutation.IPAddressCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldIPAddress, - }) + _spec.ClearField(bouncer.FieldIPAddress, field.TypeString) } if value, ok := bu.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldType, - }) + _spec.SetField(bouncer.FieldType, field.TypeString, value) } if bu.mutation.TypeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldType, - }) + _spec.ClearField(bouncer.FieldType, field.TypeString) } if value, ok := bu.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldVersion, - }) + _spec.SetField(bouncer.FieldVersion, field.TypeString, value) } if bu.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldVersion, - }) - } - if value, ok := bu.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUntil, - }) - } - if bu.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUntil, - }) + _spec.ClearField(bouncer.FieldVersion, field.TypeString) } if value, ok := bu.mutation.LastPull(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldLastPull, - }) + _spec.SetField(bouncer.FieldLastPull, field.TypeTime, value) + } + if bu.mutation.LastPullCleared() { + _spec.ClearField(bouncer.FieldLastPull, field.TypeTime) } if value, ok := bu.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAuthType, - }) + _spec.SetField(bouncer.FieldAuthType, field.TypeString, value) + } + if value, ok := bu.mutation.Osname(); ok { + _spec.SetField(bouncer.FieldOsname, field.TypeString, value) + } + if bu.mutation.OsnameCleared() { + _spec.ClearField(bouncer.FieldOsname, field.TypeString) + } + if value, ok := bu.mutation.Osversion(); ok { + _spec.SetField(bouncer.FieldOsversion, field.TypeString, value) + } + if bu.mutation.OsversionCleared() { + _spec.ClearField(bouncer.FieldOsversion, field.TypeString) + } + if value, ok := bu.mutation.Featureflags(); ok { + _spec.SetField(bouncer.FieldFeatureflags, field.TypeString, value) + } + if bu.mutation.FeatureflagsCleared() { + _spec.ClearField(bouncer.FieldFeatureflags, field.TypeString) } if n, err = sqlgraph.UpdateNodes(ctx, bu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -389,6 +328,7 @@ func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + bu.mutation.done = true return n, nil } @@ -400,48 +340,40 @@ type BouncerUpdateOne struct { mutation *BouncerMutation } -// SetCreatedAt sets the "created_at" field. 
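The hand-built sqlgraph.FieldSpec literals give way to the NewUpdateSpec, SetField, and ClearField helpers, producing the same spec with far less code. A standalone sketch of the helper-based construction, under the caveat that direct sqlgraph use is normally left to generated code:

package main

import (
	"entgo.io/ent/dialect/sql/sqlgraph"
	"entgo.io/ent/schema/field"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
)

// newRevokeSpec mirrors the generated pattern above: one helper call builds
// the node spec, and SetField/ClearField replace the FieldSpec slices.
func newRevokeSpec() *sqlgraph.UpdateSpec {
	spec := sqlgraph.NewUpdateSpec(bouncer.Table, bouncer.Columns,
		sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt))
	spec.SetField(bouncer.FieldRevoked, field.TypeBool, true)
	spec.ClearField(bouncer.FieldVersion, field.TypeString)
	return spec
}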
-func (buo *BouncerUpdateOne) SetCreatedAt(t time.Time) *BouncerUpdateOne { - buo.mutation.SetCreatedAt(t) - return buo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (buo *BouncerUpdateOne) ClearCreatedAt() *BouncerUpdateOne { - buo.mutation.ClearCreatedAt() - return buo -} - // SetUpdatedAt sets the "updated_at" field. func (buo *BouncerUpdateOne) SetUpdatedAt(t time.Time) *BouncerUpdateOne { buo.mutation.SetUpdatedAt(t) return buo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (buo *BouncerUpdateOne) ClearUpdatedAt() *BouncerUpdateOne { - buo.mutation.ClearUpdatedAt() - return buo -} - -// SetName sets the "name" field. -func (buo *BouncerUpdateOne) SetName(s string) *BouncerUpdateOne { - buo.mutation.SetName(s) - return buo -} - // SetAPIKey sets the "api_key" field. func (buo *BouncerUpdateOne) SetAPIKey(s string) *BouncerUpdateOne { buo.mutation.SetAPIKey(s) return buo } +// SetNillableAPIKey sets the "api_key" field if the given value is not nil. +func (buo *BouncerUpdateOne) SetNillableAPIKey(s *string) *BouncerUpdateOne { + if s != nil { + buo.SetAPIKey(*s) + } + return buo +} + // SetRevoked sets the "revoked" field. func (buo *BouncerUpdateOne) SetRevoked(b bool) *BouncerUpdateOne { buo.mutation.SetRevoked(b) return buo } +// SetNillableRevoked sets the "revoked" field if the given value is not nil. +func (buo *BouncerUpdateOne) SetNillableRevoked(b *bool) *BouncerUpdateOne { + if b != nil { + buo.SetRevoked(*b) + } + return buo +} + // SetIPAddress sets the "ip_address" field. func (buo *BouncerUpdateOne) SetIPAddress(s string) *BouncerUpdateOne { buo.mutation.SetIPAddress(s) @@ -502,26 +434,6 @@ func (buo *BouncerUpdateOne) ClearVersion() *BouncerUpdateOne { return buo } -// SetUntil sets the "until" field. -func (buo *BouncerUpdateOne) SetUntil(t time.Time) *BouncerUpdateOne { - buo.mutation.SetUntil(t) - return buo -} - -// SetNillableUntil sets the "until" field if the given value is not nil. -func (buo *BouncerUpdateOne) SetNillableUntil(t *time.Time) *BouncerUpdateOne { - if t != nil { - buo.SetUntil(*t) - } - return buo -} - -// ClearUntil clears the value of the "until" field. -func (buo *BouncerUpdateOne) ClearUntil() *BouncerUpdateOne { - buo.mutation.ClearUntil() - return buo -} - // SetLastPull sets the "last_pull" field. func (buo *BouncerUpdateOne) SetLastPull(t time.Time) *BouncerUpdateOne { buo.mutation.SetLastPull(t) @@ -536,6 +448,12 @@ func (buo *BouncerUpdateOne) SetNillableLastPull(t *time.Time) *BouncerUpdateOne return buo } +// ClearLastPull clears the value of the "last_pull" field. +func (buo *BouncerUpdateOne) ClearLastPull() *BouncerUpdateOne { + buo.mutation.ClearLastPull() + return buo +} + // SetAuthType sets the "auth_type" field. func (buo *BouncerUpdateOne) SetAuthType(s string) *BouncerUpdateOne { buo.mutation.SetAuthType(s) @@ -550,11 +468,77 @@ func (buo *BouncerUpdateOne) SetNillableAuthType(s *string) *BouncerUpdateOne { return buo } +// SetOsname sets the "osname" field. +func (buo *BouncerUpdateOne) SetOsname(s string) *BouncerUpdateOne { + buo.mutation.SetOsname(s) + return buo +} + +// SetNillableOsname sets the "osname" field if the given value is not nil. +func (buo *BouncerUpdateOne) SetNillableOsname(s *string) *BouncerUpdateOne { + if s != nil { + buo.SetOsname(*s) + } + return buo +} + +// ClearOsname clears the value of the "osname" field. 
+func (buo *BouncerUpdateOne) ClearOsname() *BouncerUpdateOne { + buo.mutation.ClearOsname() + return buo +} + +// SetOsversion sets the "osversion" field. +func (buo *BouncerUpdateOne) SetOsversion(s string) *BouncerUpdateOne { + buo.mutation.SetOsversion(s) + return buo +} + +// SetNillableOsversion sets the "osversion" field if the given value is not nil. +func (buo *BouncerUpdateOne) SetNillableOsversion(s *string) *BouncerUpdateOne { + if s != nil { + buo.SetOsversion(*s) + } + return buo +} + +// ClearOsversion clears the value of the "osversion" field. +func (buo *BouncerUpdateOne) ClearOsversion() *BouncerUpdateOne { + buo.mutation.ClearOsversion() + return buo +} + +// SetFeatureflags sets the "featureflags" field. +func (buo *BouncerUpdateOne) SetFeatureflags(s string) *BouncerUpdateOne { + buo.mutation.SetFeatureflags(s) + return buo +} + +// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil. +func (buo *BouncerUpdateOne) SetNillableFeatureflags(s *string) *BouncerUpdateOne { + if s != nil { + buo.SetFeatureflags(*s) + } + return buo +} + +// ClearFeatureflags clears the value of the "featureflags" field. +func (buo *BouncerUpdateOne) ClearFeatureflags() *BouncerUpdateOne { + buo.mutation.ClearFeatureflags() + return buo +} + // Mutation returns the BouncerMutation object of the builder. func (buo *BouncerUpdateOne) Mutation() *BouncerMutation { return buo.mutation } +// Where appends a list predicates to the BouncerUpdate builder. +func (buo *BouncerUpdateOne) Where(ps ...predicate.Bouncer) *BouncerUpdateOne { + buo.mutation.Where(ps...) + return buo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (buo *BouncerUpdateOne) Select(field string, fields ...string) *BouncerUpdateOne { @@ -564,41 +548,8 @@ func (buo *BouncerUpdateOne) Select(field string, fields ...string) *BouncerUpda // Save executes the query and returns the updated Bouncer entity. func (buo *BouncerUpdateOne) Save(ctx context.Context) (*Bouncer, error) { - var ( - err error - node *Bouncer - ) buo.defaults() - if len(buo.hooks) == 0 { - node, err = buo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - buo.mutation = mutation - node, err = buo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(buo.hooks) - 1; i >= 0; i-- { - if buo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = buo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, buo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Bouncer) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from BouncerMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, buo.sqlSave, buo.mutation, buo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -625,27 +576,14 @@ func (buo *BouncerUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. 
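BouncerUpdateOne also gains a Where method, so a single-row update can carry extra predicates; when the predicate filters the row out, Save reports a not-found error rather than silently updating nothing. A sketch:

package main

import (
	"context"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
)

// touchLastPull updates last_pull only while the bouncer is not revoked,
// using the new Where on the update-one builder as an optimistic guard.
func touchLastPull(ctx context.Context, client *ent.Client, id int) (*ent.Bouncer, error) {
	return client.Bouncer.UpdateOneID(id).
		Where(bouncer.RevokedEQ(false)).
		SetLastPull(time.Now()).
		Save(ctx)
}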
func (buo *BouncerUpdateOne) defaults() { - if _, ok := buo.mutation.CreatedAt(); !ok && !buo.mutation.CreatedAtCleared() { - v := bouncer.UpdateDefaultCreatedAt() - buo.mutation.SetCreatedAt(v) - } - if _, ok := buo.mutation.UpdatedAt(); !ok && !buo.mutation.UpdatedAtCleared() { + if _, ok := buo.mutation.UpdatedAt(); !ok { v := bouncer.UpdateDefaultUpdatedAt() buo.mutation.SetUpdatedAt(v) } } func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: bouncer.Table, - Columns: bouncer.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) id, ok := buo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Bouncer.id" for update`)} @@ -670,118 +608,59 @@ func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err e } } } - if value, ok := buo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldCreatedAt, - }) - } - if buo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldCreatedAt, - }) - } if value, ok := buo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUpdatedAt, - }) - } - if buo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUpdatedAt, - }) - } - if value, ok := buo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldName, - }) + _spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value) } if value, ok := buo.mutation.APIKey(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAPIKey, - }) + _spec.SetField(bouncer.FieldAPIKey, field.TypeString, value) } if value, ok := buo.mutation.Revoked(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: bouncer.FieldRevoked, - }) + _spec.SetField(bouncer.FieldRevoked, field.TypeBool, value) } if value, ok := buo.mutation.IPAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldIPAddress, - }) + _spec.SetField(bouncer.FieldIPAddress, field.TypeString, value) } if buo.mutation.IPAddressCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldIPAddress, - }) + _spec.ClearField(bouncer.FieldIPAddress, field.TypeString) } if value, ok := buo.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldType, - }) + _spec.SetField(bouncer.FieldType, field.TypeString, value) } if buo.mutation.TypeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldType, - }) + _spec.ClearField(bouncer.FieldType, field.TypeString) } if value, ok := 
buo.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldVersion, - }) + _spec.SetField(bouncer.FieldVersion, field.TypeString, value) } if buo.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldVersion, - }) - } - if value, ok := buo.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUntil, - }) - } - if buo.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUntil, - }) + _spec.ClearField(bouncer.FieldVersion, field.TypeString) } if value, ok := buo.mutation.LastPull(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldLastPull, - }) + _spec.SetField(bouncer.FieldLastPull, field.TypeTime, value) + } + if buo.mutation.LastPullCleared() { + _spec.ClearField(bouncer.FieldLastPull, field.TypeTime) } if value, ok := buo.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAuthType, - }) + _spec.SetField(bouncer.FieldAuthType, field.TypeString, value) + } + if value, ok := buo.mutation.Osname(); ok { + _spec.SetField(bouncer.FieldOsname, field.TypeString, value) + } + if buo.mutation.OsnameCleared() { + _spec.ClearField(bouncer.FieldOsname, field.TypeString) + } + if value, ok := buo.mutation.Osversion(); ok { + _spec.SetField(bouncer.FieldOsversion, field.TypeString, value) + } + if buo.mutation.OsversionCleared() { + _spec.ClearField(bouncer.FieldOsversion, field.TypeString) + } + if value, ok := buo.mutation.Featureflags(); ok { + _spec.SetField(bouncer.FieldFeatureflags, field.TypeString, value) + } + if buo.mutation.FeatureflagsCleared() { + _spec.ClearField(bouncer.FieldFeatureflags, field.TypeString) } _node = &Bouncer{config: buo.config} _spec.Assign = _node.assignValues @@ -794,5 +673,6 @@ func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err e } return nil, err } + buo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/client.go b/pkg/database/ent/client.go index 815b1df6d16..59686102ebe 100644 --- a/pkg/database/ent/client.go +++ b/pkg/database/ent/client.go @@ -7,20 +7,23 @@ import ( "errors" "fmt" "log" + "reflect" "github.com/crowdsecurity/crowdsec/pkg/database/ent/migrate" + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" - - "entgo.io/ent/dialect" - "entgo.io/ent/dialect/sql" - "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" ) // Client is the client that holds all ent builders. 
@@ -38,17 +41,19 @@ type Client struct { Decision *DecisionClient // Event is the client for interacting with the Event builders. Event *EventClient + // Lock is the client for interacting with the Lock builders. + Lock *LockClient // Machine is the client for interacting with the Machine builders. Machine *MachineClient // Meta is the client for interacting with the Meta builders. Meta *MetaClient + // Metric is the client for interacting with the Metric builders. + Metric *MetricClient } // NewClient creates a new client configured with the given options. func NewClient(opts ...Option) *Client { - cfg := config{log: log.Println, hooks: &hooks{}} - cfg.options(opts...) - client := &Client{config: cfg} + client := &Client{config: newConfig(opts...)} client.init() return client } @@ -60,8 +65,66 @@ func (c *Client) init() { c.ConfigItem = NewConfigItemClient(c.config) c.Decision = NewDecisionClient(c.config) c.Event = NewEventClient(c.config) + c.Lock = NewLockClient(c.config) c.Machine = NewMachineClient(c.config) c.Meta = NewMetaClient(c.config) + c.Metric = NewMetricClient(c.config) +} + +type ( + // config is the configuration for the client and its builder. + config struct { + // driver used for executing database requests. + driver dialect.Driver + // debug enable a debug logging. + debug bool + // log used for logging on debug mode. + log func(...any) + // hooks to execute on mutations. + hooks *hooks + // interceptors to execute on queries. + inters *inters + } + // Option function to configure the client. + Option func(*config) +) + +// newConfig creates a new config for the client. +func newConfig(opts ...Option) config { + cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}} + cfg.options(opts...) + return cfg +} + +// options applies the options on the config object. +func (c *config) options(opts ...Option) { + for _, opt := range opts { + opt(c) + } + if c.debug { + c.driver = dialect.Debug(c.driver, c.log) + } +} + +// Debug enables debug logging on the ent.Driver. +func Debug() Option { + return func(c *config) { + c.debug = true + } +} + +// Log sets the logging function for debug mode. +func Log(fn func(...any)) Option { + return func(c *config) { + c.log = fn + } +} + +// Driver configures the client driver. +func Driver(driver dialect.Driver) Option { + return func(c *config) { + c.driver = driver + } } // Open opens a database/sql.DB specified by the driver name and @@ -80,11 +143,14 @@ func Open(driverName, dataSourceName string, options ...Option) (*Client, error) } } +// ErrTxStarted is returned when trying to start a new transaction from a transactional client. +var ErrTxStarted = errors.New("ent: cannot start a transaction within a transaction") + // Tx returns a new transactional client. The provided context // is used until the transaction is committed or rolled back. 
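The former config.go (deleted further down in this diff) moves into client.go, with newConfig centralizing option handling; Debug, Log, and Driver behave as before. A sketch of the options in use — the sqlite DSN is a placeholder, not taken from this PR:

package main

import (
	"log"

	_ "github.com/mattn/go-sqlite3"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// openDebugClient opens a client with the relocated options: Debug wraps the
// driver with dialect.Debug, and Log sets the debug output destination.
func openDebugClient() (*ent.Client, error) {
	return ent.Open("sqlite3", "file:ent.db?_fk=1",
		ent.Debug(),
		ent.Log(log.Println),
	)
}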
func (c *Client) Tx(ctx context.Context) (*Tx, error) { if _, ok := c.driver.(*txDriver); ok { - return nil, errors.New("ent: cannot start a transaction within a transaction") + return nil, ErrTxStarted } tx, err := newTx(ctx, c.driver) if err != nil { @@ -100,8 +166,10 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { ConfigItem: NewConfigItemClient(cfg), Decision: NewDecisionClient(cfg), Event: NewEventClient(cfg), + Lock: NewLockClient(cfg), Machine: NewMachineClient(cfg), Meta: NewMetaClient(cfg), + Metric: NewMetricClient(cfg), }, nil } @@ -126,8 +194,10 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) ConfigItem: NewConfigItemClient(cfg), Decision: NewDecisionClient(cfg), Event: NewEventClient(cfg), + Lock: NewLockClient(cfg), Machine: NewMachineClient(cfg), Meta: NewMetaClient(cfg), + Metric: NewMetricClient(cfg), }, nil } @@ -156,13 +226,49 @@ func (c *Client) Close() error { // Use adds the mutation hooks to all the entity clients. // In order to add hooks to a specific client, call: `client.Node.Use(...)`. func (c *Client) Use(hooks ...Hook) { - c.Alert.Use(hooks...) - c.Bouncer.Use(hooks...) - c.ConfigItem.Use(hooks...) - c.Decision.Use(hooks...) - c.Event.Use(hooks...) - c.Machine.Use(hooks...) - c.Meta.Use(hooks...) + for _, n := range []interface{ Use(...Hook) }{ + c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine, + c.Meta, c.Metric, + } { + n.Use(hooks...) + } +} + +// Intercept adds the query interceptors to all the entity clients. +// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. +func (c *Client) Intercept(interceptors ...Interceptor) { + for _, n := range []interface{ Intercept(...Interceptor) }{ + c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine, + c.Meta, c.Metric, + } { + n.Intercept(interceptors...) + } +} + +// Mutate implements the ent.Mutator interface. +func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { + switch m := m.(type) { + case *AlertMutation: + return c.Alert.mutate(ctx, m) + case *BouncerMutation: + return c.Bouncer.mutate(ctx, m) + case *ConfigItemMutation: + return c.ConfigItem.mutate(ctx, m) + case *DecisionMutation: + return c.Decision.mutate(ctx, m) + case *EventMutation: + return c.Event.mutate(ctx, m) + case *LockMutation: + return c.Lock.mutate(ctx, m) + case *MachineMutation: + return c.Machine.mutate(ctx, m) + case *MetaMutation: + return c.Meta.mutate(ctx, m) + case *MetricMutation: + return c.Metric.mutate(ctx, m) + default: + return nil, fmt.Errorf("ent: unknown mutation type %T", m) + } } // AlertClient is a client for the Alert schema. @@ -181,6 +287,12 @@ func (c *AlertClient) Use(hooks ...Hook) { c.hooks.Alert = append(c.hooks.Alert, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `alert.Intercept(f(g(h())))`. +func (c *AlertClient) Intercept(interceptors ...Interceptor) { + c.inters.Alert = append(c.inters.Alert, interceptors...) +} + // Create returns a builder for creating a Alert entity. func (c *AlertClient) Create() *AlertCreate { mutation := newAlertMutation(c.config, OpCreate) @@ -192,6 +304,21 @@ func (c *AlertClient) CreateBulk(builders ...*AlertCreate) *AlertCreateBulk { return &AlertCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
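Tx now returns the exported ErrTxStarted sentinel instead of an ad-hoc error, so nesting attempts can be matched with errors.Is. The canonical ent rollback pattern, sketched against the Bouncer client:

package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// revokeInTx revokes a bouncer inside a transaction, rolling back on failure;
// the nested-transaction case is now detectable via ent.ErrTxStarted.
func revokeInTx(ctx context.Context, client *ent.Client, id int) error {
	tx, err := client.Tx(ctx)
	if errors.Is(err, ent.ErrTxStarted) {
		return fmt.Errorf("already inside a transaction: %w", err)
	}
	if err != nil {
		return err
	}
	if err := tx.Bouncer.UpdateOneID(id).SetRevoked(true).Exec(ctx); err != nil {
		// roll back, but keep the original error as the primary cause
		if rerr := tx.Rollback(); rerr != nil {
			err = fmt.Errorf("%w: rolling back: %v", err, rerr)
		}
		return err
	}
	return tx.Commit()
}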
+func (c *AlertClient) MapCreateBulk(slice any, setFunc func(*AlertCreate, int)) *AlertCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AlertCreateBulk{err: fmt.Errorf("calling to AlertClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AlertCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AlertCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Alert. func (c *AlertClient) Update() *AlertUpdate { mutation := newAlertMutation(c.config, OpUpdate) @@ -221,7 +348,7 @@ func (c *AlertClient) DeleteOne(a *Alert) *AlertDeleteOne { return c.DeleteOneID(a.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *AlertClient) DeleteOneID(id int) *AlertDeleteOne { builder := c.Delete().Where(alert.ID(id)) builder.mutation.id = &id @@ -233,6 +360,8 @@ func (c *AlertClient) DeleteOneID(id int) *AlertDeleteOne { func (c *AlertClient) Query() *AlertQuery { return &AlertQuery{ config: c.config, + ctx: &QueryContext{Type: TypeAlert}, + inters: c.Interceptors(), } } @@ -252,8 +381,8 @@ func (c *AlertClient) GetX(ctx context.Context, id int) *Alert { // QueryOwner queries the owner edge of a Alert. func (c *AlertClient) QueryOwner(a *Alert) *MachineQuery { - query := &MachineQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MachineClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -268,8 +397,8 @@ func (c *AlertClient) QueryOwner(a *Alert) *MachineQuery { // QueryDecisions queries the decisions edge of a Alert. func (c *AlertClient) QueryDecisions(a *Alert) *DecisionQuery { - query := &DecisionQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&DecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -284,8 +413,8 @@ func (c *AlertClient) QueryDecisions(a *Alert) *DecisionQuery { // QueryEvents queries the events edge of a Alert. func (c *AlertClient) QueryEvents(a *Alert) *EventQuery { - query := &EventQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&EventClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -300,8 +429,8 @@ func (c *AlertClient) QueryEvents(a *Alert) *EventQuery { // QueryMetas queries the metas edge of a Alert. func (c *AlertClient) QueryMetas(a *Alert) *MetaQuery { - query := &MetaQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MetaClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -319,6 +448,26 @@ func (c *AlertClient) Hooks() []Hook { return c.hooks.Alert } +// Interceptors returns the client interceptors. 
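MapCreateBulk, whose body opens this hunk, builds a CreateBulk from an arbitrary slice via reflection, deferring a non-slice argument to an error surfaced at Save time (via the err field on the bulk builder). A sketch with a hypothetical payload type:

package main

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// bouncerInput is a hypothetical request payload, not part of this PR.
type bouncerInput struct {
	Name   string
	APIKey string
}

// createBouncers maps a slice of inputs onto one bulk insert.
func createBouncers(ctx context.Context, client *ent.Client, in []bouncerInput) ([]*ent.Bouncer, error) {
	return client.Bouncer.MapCreateBulk(in, func(c *ent.BouncerCreate, i int) {
		c.SetName(in[i].Name).SetAPIKey(in[i].APIKey)
	}).Save(ctx)
}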
+func (c *AlertClient) Interceptors() []Interceptor { + return c.inters.Alert +} + +func (c *AlertClient) mutate(ctx context.Context, m *AlertMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AlertCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AlertUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AlertUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AlertDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Alert mutation op: %q", m.Op()) + } +} + // BouncerClient is a client for the Bouncer schema. type BouncerClient struct { config @@ -335,6 +484,12 @@ func (c *BouncerClient) Use(hooks ...Hook) { c.hooks.Bouncer = append(c.hooks.Bouncer, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `bouncer.Intercept(f(g(h())))`. +func (c *BouncerClient) Intercept(interceptors ...Interceptor) { + c.inters.Bouncer = append(c.inters.Bouncer, interceptors...) +} + // Create returns a builder for creating a Bouncer entity. func (c *BouncerClient) Create() *BouncerCreate { mutation := newBouncerMutation(c.config, OpCreate) @@ -346,6 +501,21 @@ func (c *BouncerClient) CreateBulk(builders ...*BouncerCreate) *BouncerCreateBul return &BouncerCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *BouncerClient) MapCreateBulk(slice any, setFunc func(*BouncerCreate, int)) *BouncerCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &BouncerCreateBulk{err: fmt.Errorf("calling to BouncerClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*BouncerCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &BouncerCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Bouncer. func (c *BouncerClient) Update() *BouncerUpdate { mutation := newBouncerMutation(c.config, OpUpdate) @@ -375,7 +545,7 @@ func (c *BouncerClient) DeleteOne(b *Bouncer) *BouncerDeleteOne { return c.DeleteOneID(b.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *BouncerClient) DeleteOneID(id int) *BouncerDeleteOne { builder := c.Delete().Where(bouncer.ID(id)) builder.mutation.id = &id @@ -387,6 +557,8 @@ func (c *BouncerClient) DeleteOneID(id int) *BouncerDeleteOne { func (c *BouncerClient) Query() *BouncerQuery { return &BouncerQuery{ config: c.config, + ctx: &QueryContext{Type: TypeBouncer}, + inters: c.Interceptors(), } } @@ -409,6 +581,26 @@ func (c *BouncerClient) Hooks() []Hook { return c.hooks.Bouncer } +// Interceptors returns the client interceptors. 
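Intercept is the query-side counterpart of Use; interceptors registered here are threaded through the scanWithInterceptors path seen earlier. A client-wide example following the standard entgo interceptor shape (the threshold and logging are illustrative assumptions):

package main

import (
	"context"
	"log"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// logSlowQueries registers a client-wide interceptor; per-entity variants
// would use e.g. client.Bouncer.Intercept instead.
func logSlowQueries(client *ent.Client, threshold time.Duration) {
	client.Intercept(ent.InterceptFunc(func(next ent.Querier) ent.Querier {
		return ent.QuerierFunc(func(ctx context.Context, q ent.Query) (ent.Value, error) {
			start := time.Now()
			v, err := next.Query(ctx, q)
			if d := time.Since(start); d > threshold {
				log.Printf("slow ent query (%s): %T", d, q)
			}
			return v, err
		})
	}))
}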
+func (c *BouncerClient) Interceptors() []Interceptor { + return c.inters.Bouncer +} + +func (c *BouncerClient) mutate(ctx context.Context, m *BouncerMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&BouncerCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&BouncerUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&BouncerUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&BouncerDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Bouncer mutation op: %q", m.Op()) + } +} + // ConfigItemClient is a client for the ConfigItem schema. type ConfigItemClient struct { config @@ -425,6 +617,12 @@ func (c *ConfigItemClient) Use(hooks ...Hook) { c.hooks.ConfigItem = append(c.hooks.ConfigItem, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `configitem.Intercept(f(g(h())))`. +func (c *ConfigItemClient) Intercept(interceptors ...Interceptor) { + c.inters.ConfigItem = append(c.inters.ConfigItem, interceptors...) +} + // Create returns a builder for creating a ConfigItem entity. func (c *ConfigItemClient) Create() *ConfigItemCreate { mutation := newConfigItemMutation(c.config, OpCreate) @@ -436,6 +634,21 @@ func (c *ConfigItemClient) CreateBulk(builders ...*ConfigItemCreate) *ConfigItem return &ConfigItemCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *ConfigItemClient) MapCreateBulk(slice any, setFunc func(*ConfigItemCreate, int)) *ConfigItemCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &ConfigItemCreateBulk{err: fmt.Errorf("calling to ConfigItemClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*ConfigItemCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &ConfigItemCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for ConfigItem. func (c *ConfigItemClient) Update() *ConfigItemUpdate { mutation := newConfigItemMutation(c.config, OpUpdate) @@ -465,7 +678,7 @@ func (c *ConfigItemClient) DeleteOne(ci *ConfigItem) *ConfigItemDeleteOne { return c.DeleteOneID(ci.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *ConfigItemClient) DeleteOneID(id int) *ConfigItemDeleteOne { builder := c.Delete().Where(configitem.ID(id)) builder.mutation.id = &id @@ -477,6 +690,8 @@ func (c *ConfigItemClient) DeleteOneID(id int) *ConfigItemDeleteOne { func (c *ConfigItemClient) Query() *ConfigItemQuery { return &ConfigItemQuery{ config: c.config, + ctx: &QueryContext{Type: TypeConfigItem}, + inters: c.Interceptors(), } } @@ -499,6 +714,26 @@ func (c *ConfigItemClient) Hooks() []Hook { return c.hooks.ConfigItem } +// Interceptors returns the client interceptors. 
+func (c *ConfigItemClient) Interceptors() []Interceptor { + return c.inters.ConfigItem +} + +func (c *ConfigItemClient) mutate(ctx context.Context, m *ConfigItemMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&ConfigItemCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&ConfigItemUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&ConfigItemUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&ConfigItemDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown ConfigItem mutation op: %q", m.Op()) + } +} + // DecisionClient is a client for the Decision schema. type DecisionClient struct { config @@ -515,6 +750,12 @@ func (c *DecisionClient) Use(hooks ...Hook) { c.hooks.Decision = append(c.hooks.Decision, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `decision.Intercept(f(g(h())))`. +func (c *DecisionClient) Intercept(interceptors ...Interceptor) { + c.inters.Decision = append(c.inters.Decision, interceptors...) +} + // Create returns a builder for creating a Decision entity. func (c *DecisionClient) Create() *DecisionCreate { mutation := newDecisionMutation(c.config, OpCreate) @@ -526,6 +767,21 @@ func (c *DecisionClient) CreateBulk(builders ...*DecisionCreate) *DecisionCreate return &DecisionCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *DecisionClient) MapCreateBulk(slice any, setFunc func(*DecisionCreate, int)) *DecisionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &DecisionCreateBulk{err: fmt.Errorf("calling to DecisionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*DecisionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &DecisionCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Decision. func (c *DecisionClient) Update() *DecisionUpdate { mutation := newDecisionMutation(c.config, OpUpdate) @@ -555,7 +811,7 @@ func (c *DecisionClient) DeleteOne(d *Decision) *DecisionDeleteOne { return c.DeleteOneID(d.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *DecisionClient) DeleteOneID(id int) *DecisionDeleteOne { builder := c.Delete().Where(decision.ID(id)) builder.mutation.id = &id @@ -567,6 +823,8 @@ func (c *DecisionClient) DeleteOneID(id int) *DecisionDeleteOne { func (c *DecisionClient) Query() *DecisionQuery { return &DecisionQuery{ config: c.config, + ctx: &QueryContext{Type: TypeDecision}, + inters: c.Interceptors(), } } @@ -586,8 +844,8 @@ func (c *DecisionClient) GetX(ctx context.Context, id int) *Decision { // QueryOwner queries the owner edge of a Decision. 
func (c *DecisionClient) QueryOwner(d *Decision) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := d.ID step := sqlgraph.NewStep( sqlgraph.From(decision.Table, decision.FieldID, id), @@ -605,6 +863,26 @@ func (c *DecisionClient) Hooks() []Hook { return c.hooks.Decision } +// Interceptors returns the client interceptors. +func (c *DecisionClient) Interceptors() []Interceptor { + return c.inters.Decision +} + +func (c *DecisionClient) mutate(ctx context.Context, m *DecisionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&DecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&DecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&DecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&DecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Decision mutation op: %q", m.Op()) + } +} + // EventClient is a client for the Event schema. type EventClient struct { config @@ -621,6 +899,12 @@ func (c *EventClient) Use(hooks ...Hook) { c.hooks.Event = append(c.hooks.Event, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `event.Intercept(f(g(h())))`. +func (c *EventClient) Intercept(interceptors ...Interceptor) { + c.inters.Event = append(c.inters.Event, interceptors...) +} + // Create returns a builder for creating a Event entity. func (c *EventClient) Create() *EventCreate { mutation := newEventMutation(c.config, OpCreate) @@ -632,6 +916,21 @@ func (c *EventClient) CreateBulk(builders ...*EventCreate) *EventCreateBulk { return &EventCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *EventClient) MapCreateBulk(slice any, setFunc func(*EventCreate, int)) *EventCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &EventCreateBulk{err: fmt.Errorf("calling to EventClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*EventCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &EventCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Event. func (c *EventClient) Update() *EventUpdate { mutation := newEventMutation(c.config, OpUpdate) @@ -661,7 +960,7 @@ func (c *EventClient) DeleteOne(e *Event) *EventDeleteOne { return c.DeleteOneID(e.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. 
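Edge helpers such as QueryOwner now construct their query through the owning client's Query(), so registered interceptors and the QueryContext propagate into traversals; caller-side usage is unchanged. For instance:

package main

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// ownerOfDecision walks the decision -> alert edge; interceptors registered
// on the Alert client now apply to this traversal as well.
func ownerOfDecision(ctx context.Context, client *ent.Client, d *ent.Decision) (*ent.Alert, error) {
	return client.Decision.QueryOwner(d).Only(ctx)
}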
func (c *EventClient) DeleteOneID(id int) *EventDeleteOne { builder := c.Delete().Where(event.ID(id)) builder.mutation.id = &id @@ -673,6 +972,8 @@ func (c *EventClient) DeleteOneID(id int) *EventDeleteOne { func (c *EventClient) Query() *EventQuery { return &EventQuery{ config: c.config, + ctx: &QueryContext{Type: TypeEvent}, + inters: c.Interceptors(), } } @@ -692,8 +993,8 @@ func (c *EventClient) GetX(ctx context.Context, id int) *Event { // QueryOwner queries the owner edge of a Event. func (c *EventClient) QueryOwner(e *Event) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := e.ID step := sqlgraph.NewStep( sqlgraph.From(event.Table, event.FieldID, id), @@ -711,6 +1012,159 @@ func (c *EventClient) Hooks() []Hook { return c.hooks.Event } +// Interceptors returns the client interceptors. +func (c *EventClient) Interceptors() []Interceptor { + return c.inters.Event +} + +func (c *EventClient) mutate(ctx context.Context, m *EventMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&EventCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&EventUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&EventUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&EventDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Event mutation op: %q", m.Op()) + } +} + +// LockClient is a client for the Lock schema. +type LockClient struct { + config +} + +// NewLockClient returns a client for the Lock from the given config. +func NewLockClient(c config) *LockClient { + return &LockClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `lock.Hooks(f(g(h())))`. +func (c *LockClient) Use(hooks ...Hook) { + c.hooks.Lock = append(c.hooks.Lock, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `lock.Intercept(f(g(h())))`. +func (c *LockClient) Intercept(interceptors ...Interceptor) { + c.inters.Lock = append(c.inters.Lock, interceptors...) +} + +// Create returns a builder for creating a Lock entity. +func (c *LockClient) Create() *LockCreate { + mutation := newLockMutation(c.config, OpCreate) + return &LockCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Lock entities. +func (c *LockClient) CreateBulk(builders ...*LockCreate) *LockCreateBulk { + return &LockCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
+func (c *LockClient) MapCreateBulk(slice any, setFunc func(*LockCreate, int)) *LockCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &LockCreateBulk{err: fmt.Errorf("calling to LockClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*LockCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &LockCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Lock. +func (c *LockClient) Update() *LockUpdate { + mutation := newLockMutation(c.config, OpUpdate) + return &LockUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *LockClient) UpdateOne(l *Lock) *LockUpdateOne { + mutation := newLockMutation(c.config, OpUpdateOne, withLock(l)) + return &LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *LockClient) UpdateOneID(id int) *LockUpdateOne { + mutation := newLockMutation(c.config, OpUpdateOne, withLockID(id)) + return &LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Lock. +func (c *LockClient) Delete() *LockDelete { + mutation := newLockMutation(c.config, OpDelete) + return &LockDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *LockClient) DeleteOne(l *Lock) *LockDeleteOne { + return c.DeleteOneID(l.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *LockClient) DeleteOneID(id int) *LockDeleteOne { + builder := c.Delete().Where(lock.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &LockDeleteOne{builder} +} + +// Query returns a query builder for Lock. +func (c *LockClient) Query() *LockQuery { + return &LockQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeLock}, + inters: c.Interceptors(), + } +} + +// Get returns a Lock entity by its id. +func (c *LockClient) Get(ctx context.Context, id int) (*Lock, error) { + return c.Query().Where(lock.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *LockClient) GetX(ctx context.Context, id int) *Lock { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *LockClient) Hooks() []Hook { + return c.hooks.Lock +} + +// Interceptors returns the client interceptors. +func (c *LockClient) Interceptors() []Interceptor { + return c.inters.Lock +} + +func (c *LockClient) mutate(ctx context.Context, m *LockMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&LockCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&LockUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&LockDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Lock mutation op: %q", m.Op()) + } +} + // MachineClient is a client for the Machine schema. type MachineClient struct { config @@ -727,6 +1181,12 @@ func (c *MachineClient) Use(hooks ...Hook) { c.hooks.Machine = append(c.hooks.Machine, hooks...) 
} +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `machine.Intercept(f(g(h())))`. +func (c *MachineClient) Intercept(interceptors ...Interceptor) { + c.inters.Machine = append(c.inters.Machine, interceptors...) +} + // Create returns a builder for creating a Machine entity. func (c *MachineClient) Create() *MachineCreate { mutation := newMachineMutation(c.config, OpCreate) @@ -738,6 +1198,21 @@ func (c *MachineClient) CreateBulk(builders ...*MachineCreate) *MachineCreateBul return &MachineCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *MachineClient) MapCreateBulk(slice any, setFunc func(*MachineCreate, int)) *MachineCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &MachineCreateBulk{err: fmt.Errorf("calling to MachineClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*MachineCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &MachineCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Machine. func (c *MachineClient) Update() *MachineUpdate { mutation := newMachineMutation(c.config, OpUpdate) @@ -767,7 +1242,7 @@ func (c *MachineClient) DeleteOne(m *Machine) *MachineDeleteOne { return c.DeleteOneID(m.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MachineClient) DeleteOneID(id int) *MachineDeleteOne { builder := c.Delete().Where(machine.ID(id)) builder.mutation.id = &id @@ -779,6 +1254,8 @@ func (c *MachineClient) DeleteOneID(id int) *MachineDeleteOne { func (c *MachineClient) Query() *MachineQuery { return &MachineQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMachine}, + inters: c.Interceptors(), } } @@ -798,8 +1275,8 @@ func (c *MachineClient) GetX(ctx context.Context, id int) *Machine { // QueryAlerts queries the alerts edge of a Machine. func (c *MachineClient) QueryAlerts(m *Machine) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(machine.Table, machine.FieldID, id), @@ -817,6 +1294,26 @@ func (c *MachineClient) Hooks() []Hook { return c.hooks.Machine } +// Interceptors returns the client interceptors. +func (c *MachineClient) Interceptors() []Interceptor { + return c.inters.Machine +} + +func (c *MachineClient) mutate(ctx context.Context, m *MachineMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MachineCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MachineUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MachineUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MachineDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Machine mutation op: %q", m.Op()) + } +} + // MetaClient is a client for the Meta schema. 
type MetaClient struct { config @@ -833,6 +1330,12 @@ func (c *MetaClient) Use(hooks ...Hook) { c.hooks.Meta = append(c.hooks.Meta, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `meta.Intercept(f(g(h())))`. +func (c *MetaClient) Intercept(interceptors ...Interceptor) { + c.inters.Meta = append(c.inters.Meta, interceptors...) +} + // Create returns a builder for creating a Meta entity. func (c *MetaClient) Create() *MetaCreate { mutation := newMetaMutation(c.config, OpCreate) @@ -844,6 +1347,21 @@ func (c *MetaClient) CreateBulk(builders ...*MetaCreate) *MetaCreateBulk { return &MetaCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *MetaClient) MapCreateBulk(slice any, setFunc func(*MetaCreate, int)) *MetaCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &MetaCreateBulk{err: fmt.Errorf("calling to MetaClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*MetaCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &MetaCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Meta. func (c *MetaClient) Update() *MetaUpdate { mutation := newMetaMutation(c.config, OpUpdate) @@ -873,7 +1391,7 @@ func (c *MetaClient) DeleteOne(m *Meta) *MetaDeleteOne { return c.DeleteOneID(m.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MetaClient) DeleteOneID(id int) *MetaDeleteOne { builder := c.Delete().Where(meta.ID(id)) builder.mutation.id = &id @@ -885,6 +1403,8 @@ func (c *MetaClient) DeleteOneID(id int) *MetaDeleteOne { func (c *MetaClient) Query() *MetaQuery { return &MetaQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMeta}, + inters: c.Interceptors(), } } @@ -904,8 +1424,8 @@ func (c *MetaClient) GetX(ctx context.Context, id int) *Meta { // QueryOwner queries the owner edge of a Meta. func (c *MetaClient) QueryOwner(m *Meta) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(meta.Table, meta.FieldID, id), @@ -922,3 +1442,168 @@ func (c *MetaClient) QueryOwner(m *Meta) *AlertQuery { func (c *MetaClient) Hooks() []Hook { return c.hooks.Meta } + +// Interceptors returns the client interceptors. 
+func (c *MetaClient) Interceptors() []Interceptor { + return c.inters.Meta +} + +func (c *MetaClient) mutate(ctx context.Context, m *MetaMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MetaCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MetaUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MetaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MetaDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Meta mutation op: %q", m.Op()) + } +} + +// MetricClient is a client for the Metric schema. +type MetricClient struct { + config +} + +// NewMetricClient returns a client for the Metric from the given config. +func NewMetricClient(c config) *MetricClient { + return &MetricClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `metric.Hooks(f(g(h())))`. +func (c *MetricClient) Use(hooks ...Hook) { + c.hooks.Metric = append(c.hooks.Metric, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `metric.Intercept(f(g(h())))`. +func (c *MetricClient) Intercept(interceptors ...Interceptor) { + c.inters.Metric = append(c.inters.Metric, interceptors...) +} + +// Create returns a builder for creating a Metric entity. +func (c *MetricClient) Create() *MetricCreate { + mutation := newMetricMutation(c.config, OpCreate) + return &MetricCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Metric entities. +func (c *MetricClient) CreateBulk(builders ...*MetricCreate) *MetricCreateBulk { + return &MetricCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *MetricClient) MapCreateBulk(slice any, setFunc func(*MetricCreate, int)) *MetricCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &MetricCreateBulk{err: fmt.Errorf("calling to MetricClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*MetricCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &MetricCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Metric. +func (c *MetricClient) Update() *MetricUpdate { + mutation := newMetricMutation(c.config, OpUpdate) + return &MetricUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *MetricClient) UpdateOne(m *Metric) *MetricUpdateOne { + mutation := newMetricMutation(c.config, OpUpdateOne, withMetric(m)) + return &MetricUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *MetricClient) UpdateOneID(id int) *MetricUpdateOne { + mutation := newMetricMutation(c.config, OpUpdateOne, withMetricID(id)) + return &MetricUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Metric. 
+func (c *MetricClient) Delete() *MetricDelete { + mutation := newMetricMutation(c.config, OpDelete) + return &MetricDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *MetricClient) DeleteOne(m *Metric) *MetricDeleteOne { + return c.DeleteOneID(m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *MetricClient) DeleteOneID(id int) *MetricDeleteOne { + builder := c.Delete().Where(metric.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &MetricDeleteOne{builder} +} + +// Query returns a query builder for Metric. +func (c *MetricClient) Query() *MetricQuery { + return &MetricQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeMetric}, + inters: c.Interceptors(), + } +} + +// Get returns a Metric entity by its id. +func (c *MetricClient) Get(ctx context.Context, id int) (*Metric, error) { + return c.Query().Where(metric.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *MetricClient) GetX(ctx context.Context, id int) *Metric { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *MetricClient) Hooks() []Hook { + return c.hooks.Metric +} + +// Interceptors returns the client interceptors. +func (c *MetricClient) Interceptors() []Interceptor { + return c.inters.Metric +} + +func (c *MetricClient) mutate(ctx context.Context, m *MetricMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MetricCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MetricUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MetricUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MetricDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Metric mutation op: %q", m.Op()) + } +} + +// hooks and interceptors per client, for fast access. +type ( + hooks struct { + Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine, Meta, + Metric []ent.Hook + } + inters struct { + Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine, Meta, + Metric []ent.Interceptor + } +) diff --git a/pkg/database/ent/config.go b/pkg/database/ent/config.go deleted file mode 100644 index 1a152809a32..00000000000 --- a/pkg/database/ent/config.go +++ /dev/null @@ -1,65 +0,0 @@ -// Code generated by ent, DO NOT EDIT. - -package ent - -import ( - "entgo.io/ent" - "entgo.io/ent/dialect" -) - -// Option function to configure the client. -type Option func(*config) - -// Config is the configuration for the client and its builder. -type config struct { - // driver used for executing database requests. - driver dialect.Driver - // debug enable a debug logging. - debug bool - // log used for logging on debug mode. - log func(...any) - // hooks to execute on mutations. - hooks *hooks -} - -// hooks per client, for fast access. -type hooks struct { - Alert []ent.Hook - Bouncer []ent.Hook - ConfigItem []ent.Hook - Decision []ent.Hook - Event []ent.Hook - Machine []ent.Hook - Meta []ent.Hook -} - -// Options applies the options on the config object. -func (c *config) options(opts ...Option) { - for _, opt := range opts { - opt(c) - } - if c.debug { - c.driver = dialect.Debug(c.driver, c.log) - } -} - -// Debug enables debug logging on the ent.Driver. 
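With config.go deleted, the config struct and the Option helpers (Debug, Log, Driver) move into the generated client.go, and the per-type hook registry gains its interceptor twin in the consolidated hooks/inters structs above. Hook registration itself is unchanged. A sketch for the new Metric client; hook.MetricFunc is the adapter ent normally generates in pkg/database/ent/hook for each schema type, not shown in this diff:

func logMetricMutations(client *ent.Client) {
	client.Metric.Use(func(next ent.Mutator) ent.Mutator {
		return hook.MetricFunc(func(ctx context.Context, m *ent.MetricMutation) (ent.Value, error) {
			log.Printf("metric mutation: op=%s", m.Op())
			return next.Mutate(ctx, m) // continue down the hook chain
		})
	})
}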
-func Debug() Option { - return func(c *config) { - c.debug = true - } -} - -// Log sets the logging function for debug mode. -func Log(fn func(...any)) Option { - return func(c *config) { - c.log = fn - } -} - -// Driver configures the client driver. -func Driver(driver dialect.Driver) Option { - return func(c *config) { - c.driver = driver - } -} diff --git a/pkg/database/ent/configitem.go b/pkg/database/ent/configitem.go index 615780dbacc..bdf23ef4948 100644 --- a/pkg/database/ent/configitem.go +++ b/pkg/database/ent/configitem.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" ) @@ -17,13 +18,14 @@ type ConfigItem struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at"` + CreatedAt time.Time `json:"created_at"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at"` + UpdatedAt time.Time `json:"updated_at"` // Name holds the value of the "name" field. Name string `json:"name"` // Value holds the value of the "value" field. - Value string `json:"value"` + Value string `json:"value"` + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -38,7 +40,7 @@ func (*ConfigItem) scanValues(columns []string) ([]any, error) { case configitem.FieldCreatedAt, configitem.FieldUpdatedAt: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type ConfigItem", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -62,15 +64,13 @@ func (ci *ConfigItem) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - ci.CreatedAt = new(time.Time) - *ci.CreatedAt = value.Time + ci.CreatedAt = value.Time } case configitem.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - ci.UpdatedAt = new(time.Time) - *ci.UpdatedAt = value.Time + ci.UpdatedAt = value.Time } case configitem.FieldName: if value, ok := values[i].(*sql.NullString); !ok { @@ -84,16 +84,24 @@ func (ci *ConfigItem) assignValues(columns []string, values []any) error { } else if value.Valid { ci.Value = value.String } + default: + ci.selectValues.Set(columns[i], values[i]) } } return nil } +// GetValue returns the ent.Value that was dynamically selected and assigned to the ConfigItem. +// This includes values selected through modifiers, order, etc. +func (ci *ConfigItem) GetValue(name string) (ent.Value, error) { + return ci.selectValues.Get(name) +} + // Update returns a builder for updating this ConfigItem. // Note that you need to call ConfigItem.Unwrap() before calling this method if this ConfigItem // was returned from a transaction, and the transaction was committed or rolled back. 
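Two behavioral notes on this hunk: created_at and updated_at drop their pointer indirection, so callers that nil-checked them must switch to zero-value checks, and unknown columns now land in selectValues instead of aborting the scan. A sketch of both, where "upper_name" is a hypothetical alias added through a custom selection modifier:

// Before: if ci.CreatedAt != nil { ... }   (the field was *time.Time)
// After: the zero time stands in for "unset".
if !ci.CreatedAt.IsZero() {
	fmt.Println(ci.CreatedAt.Format(time.RFC3339))
}

// Columns selected outside the schema (modifiers, order expressions)
// are kept in selectValues and fetched by column name.
if v, err := ci.GetValue("upper_name"); err == nil {
	fmt.Println(v)
}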
func (ci *ConfigItem) Update() *ConfigItemUpdateOne { - return (&ConfigItemClient{config: ci.config}).UpdateOne(ci) + return NewConfigItemClient(ci.config).UpdateOne(ci) } // Unwrap unwraps the ConfigItem entity that was returned from a transaction after it was closed, @@ -112,15 +120,11 @@ func (ci *ConfigItem) String() string { var builder strings.Builder builder.WriteString("ConfigItem(") builder.WriteString(fmt.Sprintf("id=%v, ", ci.ID)) - if v := ci.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(ci.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := ci.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(ci.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") builder.WriteString("name=") builder.WriteString(ci.Name) @@ -133,9 +137,3 @@ func (ci *ConfigItem) String() string { // ConfigItems is a parsable slice of ConfigItem. type ConfigItems []*ConfigItem - -func (ci ConfigItems) config(cfg config) { - for _i := range ci { - ci[_i].config = cfg - } -} diff --git a/pkg/database/ent/configitem/configitem.go b/pkg/database/ent/configitem/configitem.go index 80e93e4cc7e..611d81a3960 100644 --- a/pkg/database/ent/configitem/configitem.go +++ b/pkg/database/ent/configitem/configitem.go @@ -4,6 +4,8 @@ package configitem import ( "time" + + "entgo.io/ent/dialect/sql" ) const ( @@ -45,10 +47,36 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. UpdateDefaultUpdatedAt func() time.Time ) + +// OrderOption defines the ordering options for the ConfigItem queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} diff --git a/pkg/database/ent/configitem/where.go b/pkg/database/ent/configitem/where.go index 6d06938a855..48ae792fd72 100644 --- a/pkg/database/ent/configitem/where.go +++ b/pkg/database/ent/configitem/where.go @@ -11,485 +11,290 @@ import ( // ID filters vertices based on their ID field. 
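Order now takes typed configitem.OrderOption values instead of the old generic OrderFunc; sort direction comes from entgo.io/ent/dialect/sql. A sketch, assuming an initialized client and a ctx in scope:

items, err := client.ConfigItem.Query().
	Order(
		configitem.ByUpdatedAt(sql.OrderDesc()), // newest first
		configitem.ByName(),                     // ascending tiebreaker
	).
	All(ctx)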
func ID(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldUpdatedAt, v)) } // Name applies equality check predicate on the "name" field. It's identical to NameEQ. func Name(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldName, v)) } // Value applies equality check predicate on the "value" field. It's identical to ValueEQ. 
func Value(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldValue, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. 
func UpdatedAtNEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldUpdatedAt, v)) } // NameEQ applies the EQ predicate on the "name" field. func NameEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "name" field. func NameNEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "name" field. func NameIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "name" field. 
func NameNotIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "name" field. func NameGT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "name" field. func NameGTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "name" field. func NameLT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "name" field. func NameLTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "name" field. func NameContains(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "name" field. func NameHasPrefix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "name" field. func NameHasSuffix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "name" field. func NameEqualFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "name" field. func NameContainsFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldContainsFold(FieldName, v)) } // ValueEQ applies the EQ predicate on the "value" field. func ValueEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "value" field. func ValueNEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "value" field. 
func ValueIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "value" field. func ValueNotIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "value" field. func ValueGT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "value" field. func ValueGTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "value" field. func ValueLT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "value" field. func ValueLTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "value" field. func ValueContains(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "value" field. func ValueHasPrefix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "value" field. func ValueHasSuffix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "value" field. func ValueEqualFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "value" field. func ValueContainsFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldContainsFold(FieldValue, v)) } // And groups predicates with the AND operator between them. 
func And(predicates ...predicate.ConfigItem) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.ConfigItem(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.ConfigItem) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.ConfigItem(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.ConfigItem) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.ConfigItem(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/configitem_create.go b/pkg/database/ent/configitem_create.go index 736e6a50514..a2679927aee 100644 --- a/pkg/database/ent/configitem_create.go +++ b/pkg/database/ent/configitem_create.go @@ -67,50 +67,8 @@ func (cic *ConfigItemCreate) Mutation() *ConfigItemMutation { // Save creates the ConfigItem in the database. func (cic *ConfigItemCreate) Save(ctx context.Context) (*ConfigItem, error) { - var ( - err error - node *ConfigItem - ) cic.defaults() - if len(cic.hooks) == 0 { - if err = cic.check(); err != nil { - return nil, err - } - node, err = cic.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = cic.check(); err != nil { - return nil, err - } - cic.mutation = mutation - if node, err = cic.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(cic.hooks) - 1; i >= 0; i-- { - if cic.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = cic.hooks[i](mut) - } - v, err := mut.Mutate(ctx, cic.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*ConfigItem) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from ConfigItemMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, cic.sqlSave, cic.mutation, cic.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -149,6 +107,12 @@ func (cic *ConfigItemCreate) defaults() { // check runs all checks and user-defined validators on the builder. 
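The predicate rewrite above is behavior-preserving: every closure over *sql.Selector collapses into a sql.Field* helper, and And/Or/Not delegate to sql.AndPredicates and friends, so call sites compose exactly as before. Note that CreatedAtIsNil/NotNil disappear because those columns are now required. For instance, a sketch with made-up names and cutoffs:

cutoff := time.Now().Add(-24 * time.Hour)
stale, err := client.ConfigItem.Query().
	Where(configitem.Or(
		configitem.NameHasPrefix("blocklist:"), // hypothetical naming scheme
		configitem.And(
			configitem.UpdatedAtLT(cutoff),
			configitem.Not(configitem.ValueContains("pinned")),
		),
	)).
	All(ctx)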
func (cic *ConfigItemCreate) check() error { + if _, ok := cic.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ConfigItem.created_at"`)} + } + if _, ok := cic.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ConfigItem.updated_at"`)} + } if _, ok := cic.mutation.Name(); !ok { return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ConfigItem.name"`)} } @@ -159,6 +123,9 @@ func (cic *ConfigItemCreate) check() error { } func (cic *ConfigItemCreate) sqlSave(ctx context.Context) (*ConfigItem, error) { + if err := cic.check(); err != nil { + return nil, err + } _node, _spec := cic.createSpec() if err := sqlgraph.CreateNode(ctx, cic.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -168,50 +135,30 @@ func (cic *ConfigItemCreate) sqlSave(ctx context.Context) (*ConfigItem, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + cic.mutation.id = &_node.ID + cic.mutation.done = true return _node, nil } func (cic *ConfigItemCreate) createSpec() (*ConfigItem, *sqlgraph.CreateSpec) { var ( _node = &ConfigItem{config: cic.config} - _spec = &sqlgraph.CreateSpec{ - Table: configitem.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(configitem.Table, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) ) if value, ok := cic.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(configitem.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := cic.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := cic.mutation.Name(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldName, - }) + _spec.SetField(configitem.FieldName, field.TypeString, value) _node.Name = value } if value, ok := cic.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldValue, - }) + _spec.SetField(configitem.FieldValue, field.TypeString, value) _node.Value = value } return _node, _spec @@ -220,11 +167,15 @@ func (cic *ConfigItemCreate) createSpec() (*ConfigItem, *sqlgraph.CreateSpec) { // ConfigItemCreateBulk is the builder for creating many ConfigItem entities in bulk. type ConfigItemCreateBulk struct { config + err error builders []*ConfigItemCreate } // Save creates the ConfigItem entities in the database. 
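Save now funnels through the generated withHooks helper instead of the hand-rolled mutator chain, and check() enforces created_at/updated_at, which defaults() always fills, so only name and value are caller-supplied. Validation failures surface as *ValidationError. A sketch with hypothetical key and value:

ci, err := client.ConfigItem.Create().
	SetName("report_interval"). // hypothetical key
	SetValue("30s").
	Save(ctx)
if ent.IsValidationError(err) {
	log.Fatalf("missing or invalid field: %v", err) // e.g. an unset name
}
_ = ci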
func (cicb *ConfigItemCreateBulk) Save(ctx context.Context) ([]*ConfigItem, error) { + if cicb.err != nil { + return nil, cicb.err + } specs := make([]*sqlgraph.CreateSpec, len(cicb.builders)) nodes := make([]*ConfigItem, len(cicb.builders)) mutators := make([]Mutator, len(cicb.builders)) @@ -241,8 +192,8 @@ func (cicb *ConfigItemCreateBulk) Save(ctx context.Context) ([]*ConfigItem, erro return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, cicb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/configitem_delete.go b/pkg/database/ent/configitem_delete.go index 223fa9eefbf..a5dc811f60d 100644 --- a/pkg/database/ent/configitem_delete.go +++ b/pkg/database/ent/configitem_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (cid *ConfigItemDelete) Where(ps ...predicate.ConfigItem) *ConfigItemDelete // Exec executes the deletion query and returns how many vertices were deleted. func (cid *ConfigItemDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(cid.hooks) == 0 { - affected, err = cid.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - cid.mutation = mutation - affected, err = cid.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(cid.hooks) - 1; i >= 0; i-- { - if cid.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = cid.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, cid.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, cid.sqlExec, cid.mutation, cid.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (cid *ConfigItemDelete) ExecX(ctx context.Context) int { } func (cid *ConfigItemDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(configitem.Table, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) if ps := cid.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (cid *ConfigItemDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + cid.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type ConfigItemDeleteOne struct { cid *ConfigItemDelete } +// Where appends a list of predicates to the ConfigItemDelete builder. +func (cido *ConfigItemDeleteOne) Where(ps ...predicate.ConfigItem) *ConfigItemDeleteOne { + cido.cid.mutation.Where(ps...) + return cido +} + // Exec executes the deletion query. func (cido *ConfigItemDeleteOne) Exec(ctx context.Context) error { n, err := cido.cid.Exec(ctx) @@ -111,5 +82,7 @@ func (cido *ConfigItemDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs.
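The new Where on ConfigItemDeleteOne enables a guarded single-row delete: Exec still requires exactly one affected row and reports *NotFoundError otherwise. A sketch, with a hypothetical guard value:

err := client.ConfigItem.
	DeleteOneID(id).
	Where(configitem.ValueEQ("obsolete")). // hypothetical guard
	Exec(ctx)
if ent.IsNotFound(err) {
	// the row is absent, or present with a value that no longer matches
}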
func (cido *ConfigItemDeleteOne) ExecX(ctx context.Context) { - cido.cid.ExecX(ctx) + if err := cido.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/configitem_query.go b/pkg/database/ent/configitem_query.go index 6c9e6732a9b..f68b8953ddb 100644 --- a/pkg/database/ent/configitem_query.go +++ b/pkg/database/ent/configitem_query.go @@ -17,11 +17,9 @@ import ( // ConfigItemQuery is the builder for querying ConfigItem entities. type ConfigItemQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []configitem.OrderOption + inters []Interceptor predicates []predicate.ConfigItem // intermediate query (i.e. traversal path). sql *sql.Selector @@ -34,27 +32,27 @@ func (ciq *ConfigItemQuery) Where(ps ...predicate.ConfigItem) *ConfigItemQuery { return ciq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (ciq *ConfigItemQuery) Limit(limit int) *ConfigItemQuery { - ciq.limit = &limit + ciq.ctx.Limit = &limit return ciq } -// Offset adds an offset step to the query. +// Offset to start from. func (ciq *ConfigItemQuery) Offset(offset int) *ConfigItemQuery { - ciq.offset = &offset + ciq.ctx.Offset = &offset return ciq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ciq *ConfigItemQuery) Unique(unique bool) *ConfigItemQuery { - ciq.unique = &unique + ciq.ctx.Unique = &unique return ciq } -// Order adds an order step to the query. -func (ciq *ConfigItemQuery) Order(o ...OrderFunc) *ConfigItemQuery { +// Order specifies how the records should be ordered. +func (ciq *ConfigItemQuery) Order(o ...configitem.OrderOption) *ConfigItemQuery { ciq.order = append(ciq.order, o...) return ciq } @@ -62,7 +60,7 @@ func (ciq *ConfigItemQuery) Order(o ...OrderFunc) *ConfigItemQuery { // First returns the first ConfigItem entity from the query. // Returns a *NotFoundError when no ConfigItem was found. func (ciq *ConfigItemQuery) First(ctx context.Context) (*ConfigItem, error) { - nodes, err := ciq.Limit(1).All(ctx) + nodes, err := ciq.Limit(1).All(setContextOp(ctx, ciq.ctx, "First")) if err != nil { return nil, err } @@ -85,7 +83,7 @@ func (ciq *ConfigItemQuery) FirstX(ctx context.Context) *ConfigItem { // Returns a *NotFoundError when no ConfigItem ID was found. func (ciq *ConfigItemQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ciq.Limit(1).IDs(ctx); err != nil { + if ids, err = ciq.Limit(1).IDs(setContextOp(ctx, ciq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -108,7 +106,7 @@ func (ciq *ConfigItemQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one ConfigItem entity is found. // Returns a *NotFoundError when no ConfigItem entities are found. func (ciq *ConfigItemQuery) Only(ctx context.Context) (*ConfigItem, error) { - nodes, err := ciq.Limit(2).All(ctx) + nodes, err := ciq.Limit(2).All(setContextOp(ctx, ciq.ctx, "Only")) if err != nil { return nil, err } @@ -136,7 +134,7 @@ func (ciq *ConfigItemQuery) OnlyX(ctx context.Context) *ConfigItem { // Returns a *NotFoundError when no entities are found. 
func (ciq *ConfigItemQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ciq.Limit(2).IDs(ctx); err != nil { + if ids, err = ciq.Limit(2).IDs(setContextOp(ctx, ciq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -161,10 +159,12 @@ func (ciq *ConfigItemQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of ConfigItems. func (ciq *ConfigItemQuery) All(ctx context.Context) ([]*ConfigItem, error) { + ctx = setContextOp(ctx, ciq.ctx, "All") if err := ciq.prepareQuery(ctx); err != nil { return nil, err } - return ciq.sqlAll(ctx) + qr := querierAll[[]*ConfigItem, *ConfigItemQuery]() + return withInterceptors[[]*ConfigItem](ctx, ciq, qr, ciq.inters) } // AllX is like All, but panics if an error occurs. @@ -177,9 +177,12 @@ func (ciq *ConfigItemQuery) AllX(ctx context.Context) []*ConfigItem { } // IDs executes the query and returns a list of ConfigItem IDs. -func (ciq *ConfigItemQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := ciq.Select(configitem.FieldID).Scan(ctx, &ids); err != nil { +func (ciq *ConfigItemQuery) IDs(ctx context.Context) (ids []int, err error) { + if ciq.ctx.Unique == nil && ciq.path != nil { + ciq.Unique(true) + } + ctx = setContextOp(ctx, ciq.ctx, "IDs") + if err = ciq.Select(configitem.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -196,10 +199,11 @@ func (ciq *ConfigItemQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (ciq *ConfigItemQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, ciq.ctx, "Count") if err := ciq.prepareQuery(ctx); err != nil { return 0, err } - return ciq.sqlCount(ctx) + return withInterceptors[int](ctx, ciq, querierCount[*ConfigItemQuery](), ciq.inters) } // CountX is like Count, but panics if an error occurs. @@ -213,10 +217,15 @@ func (ciq *ConfigItemQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ciq *ConfigItemQuery) Exist(ctx context.Context) (bool, error) { - if err := ciq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, ciq.ctx, "Exist") + switch _, err := ciq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return ciq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -236,14 +245,13 @@ func (ciq *ConfigItemQuery) Clone() *ConfigItemQuery { } return &ConfigItemQuery{ config: ciq.config, - limit: ciq.limit, - offset: ciq.offset, - order: append([]OrderFunc{}, ciq.order...), + ctx: ciq.ctx.Clone(), + order: append([]configitem.OrderOption{}, ciq.order...), + inters: append([]Interceptor{}, ciq.inters...), predicates: append([]predicate.ConfigItem{}, ciq.predicates...), // clone intermediate query. - sql: ciq.sql.Clone(), - path: ciq.path, - unique: ciq.unique, + sql: ciq.sql.Clone(), + path: ciq.path, } } @@ -262,16 +270,11 @@ func (ciq *ConfigItemQuery) Clone() *ConfigItemQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (ciq *ConfigItemQuery) GroupBy(field string, fields ...string) *ConfigItemGroupBy { - grbuild := &ConfigItemGroupBy{config: ciq.config} - grbuild.fields = append([]string{field}, fields...) 
- grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := ciq.prepareQuery(ctx); err != nil { - return nil, err - } - return ciq.sqlQuery(ctx), nil - } + ciq.ctx.Fields = append([]string{field}, fields...) + grbuild := &ConfigItemGroupBy{build: ciq} + grbuild.flds = &ciq.ctx.Fields grbuild.label = configitem.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -288,15 +291,30 @@ func (ciq *ConfigItemQuery) GroupBy(field string, fields ...string) *ConfigItemG // Select(configitem.FieldCreatedAt). // Scan(ctx, &v) func (ciq *ConfigItemQuery) Select(fields ...string) *ConfigItemSelect { - ciq.fields = append(ciq.fields, fields...) - selbuild := &ConfigItemSelect{ConfigItemQuery: ciq} - selbuild.label = configitem.Label - selbuild.flds, selbuild.scan = &ciq.fields, selbuild.Scan - return selbuild + ciq.ctx.Fields = append(ciq.ctx.Fields, fields...) + sbuild := &ConfigItemSelect{ConfigItemQuery: ciq} + sbuild.label = configitem.Label + sbuild.flds, sbuild.scan = &ciq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a ConfigItemSelect configured with the given aggregations. +func (ciq *ConfigItemQuery) Aggregate(fns ...AggregateFunc) *ConfigItemSelect { + return ciq.Select().Aggregate(fns...) } func (ciq *ConfigItemQuery) prepareQuery(ctx context.Context) error { - for _, f := range ciq.fields { + for _, inter := range ciq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, ciq); err != nil { + return err + } + } + } + for _, f := range ciq.ctx.Fields { if !configitem.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -338,41 +356,22 @@ func (ciq *ConfigItemQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]* func (ciq *ConfigItemQuery) sqlCount(ctx context.Context) (int, error) { _spec := ciq.querySpec() - _spec.Node.Columns = ciq.fields - if len(ciq.fields) > 0 { - _spec.Unique = ciq.unique != nil && *ciq.unique + _spec.Node.Columns = ciq.ctx.Fields + if len(ciq.ctx.Fields) > 0 { + _spec.Unique = ciq.ctx.Unique != nil && *ciq.ctx.Unique } return sqlgraph.CountNodes(ctx, ciq.driver, _spec) } -func (ciq *ConfigItemQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := ciq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (ciq *ConfigItemQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - Columns: configitem.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - From: ciq.sql, - Unique: true, - } - if unique := ciq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) + _spec.From = ciq.sql + if unique := ciq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if ciq.path != nil { + _spec.Unique = true } - if fields := ciq.fields; len(fields) > 0 { + if fields := ciq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, configitem.FieldID) for i := range fields { @@ -388,10 +387,10 @@ func (ciq *ConfigItemQuery) querySpec() 
*sqlgraph.QuerySpec { } } } - if limit := ciq.limit; limit != nil { + if limit := ciq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := ciq.offset; offset != nil { + if offset := ciq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := ciq.order; len(ps) > 0 { @@ -407,7 +406,7 @@ func (ciq *ConfigItemQuery) querySpec() *sqlgraph.QuerySpec { func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(ciq.driver.Dialect()) t1 := builder.Table(configitem.Table) - columns := ciq.fields + columns := ciq.ctx.Fields if len(columns) == 0 { columns = configitem.Columns } @@ -416,7 +415,7 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = ciq.sql selector.Select(selector.Columns(columns...)...) } - if ciq.unique != nil && *ciq.unique { + if ciq.ctx.Unique != nil && *ciq.ctx.Unique { selector.Distinct() } for _, p := range ciq.predicates { @@ -425,12 +424,12 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range ciq.order { p(selector) } - if offset := ciq.offset; offset != nil { + if offset := ciq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := ciq.limit; limit != nil { + if limit := ciq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -438,13 +437,8 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { // ConfigItemGroupBy is the group-by builder for ConfigItem entities. type ConfigItemGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *ConfigItemQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -453,74 +447,77 @@ func (cigb *ConfigItemGroupBy) Aggregate(fns ...AggregateFunc) *ConfigItemGroupB return cigb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (cigb *ConfigItemGroupBy) Scan(ctx context.Context, v any) error { - query, err := cigb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, cigb.build.ctx, "GroupBy") + if err := cigb.build.prepareQuery(ctx); err != nil { return err } - cigb.sql = query - return cigb.sqlScan(ctx, v) + return scanWithInterceptors[*ConfigItemQuery, *ConfigItemGroupBy](ctx, cigb.build, cigb, cigb.build.inters, v) } -func (cigb *ConfigItemGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range cigb.fields { - if !configitem.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (cigb *ConfigItemGroupBy) sqlScan(ctx context.Context, root *ConfigItemQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(cigb.fns)) + for _, fn := range cigb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*cigb.flds)+len(cigb.fns)) + for _, f := range *cigb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := cigb.sqlQuery() + selector.GroupBy(selector.Columns(*cigb.flds...)...) 
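prepareQuery now runs the registered interceptors, and any that implement Traverser get to mutate the query before execution, which is how query-scoped defaults can be injected. A sketch using the ent.TraverseFunc alias that ent generates alongside Interceptor (the alias itself is not visible in this diff):

client.ConfigItem.Intercept(ent.TraverseFunc(func(ctx context.Context, q ent.Query) error {
	// Invoked from prepareQuery for every ConfigItem read; narrow the
	// generic ent.Query down to the typed builder before filtering.
	if ciq, ok := q.(*ent.ConfigItemQuery); ok {
		ciq.Where(configitem.NameNEQ("hidden")) // hypothetical default filter
	}
	return nil
}))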
if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := cigb.driver.Query(ctx, query, args, rows); err != nil { + if err := cigb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (cigb *ConfigItemGroupBy) sqlQuery() *sql.Selector { - selector := cigb.sql.Select() - aggregation := make([]string, 0, len(cigb.fns)) - for _, fn := range cigb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(cigb.fields)+len(cigb.fns)) - for _, f := range cigb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(cigb.fields...)...) -} - // ConfigItemSelect is the builder for selecting fields of ConfigItem entities. type ConfigItemSelect struct { *ConfigItemQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (cis *ConfigItemSelect) Aggregate(fns ...AggregateFunc) *ConfigItemSelect { + cis.fns = append(cis.fns, fns...) + return cis } // Scan applies the selector query and scans the result into the given value. func (cis *ConfigItemSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, cis.ctx, "Select") if err := cis.prepareQuery(ctx); err != nil { return err } - cis.sql = cis.ConfigItemQuery.sqlQuery(ctx) - return cis.sqlScan(ctx, v) + return scanWithInterceptors[*ConfigItemQuery, *ConfigItemSelect](ctx, cis.ConfigItemQuery, cis, cis.inters, v) } -func (cis *ConfigItemSelect) sqlScan(ctx context.Context, v any) error { +func (cis *ConfigItemSelect) sqlScan(ctx context.Context, root *ConfigItemQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(cis.fns)) + for _, fn := range cis.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*cis.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := cis.sql.Query() + query, args := selector.Query() if err := cis.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/configitem_update.go b/pkg/database/ent/configitem_update.go index e591347a0c3..82309459e76 100644 --- a/pkg/database/ent/configitem_update.go +++ b/pkg/database/ent/configitem_update.go @@ -28,42 +28,26 @@ func (ciu *ConfigItemUpdate) Where(ps ...predicate.ConfigItem) *ConfigItemUpdate return ciu } -// SetCreatedAt sets the "created_at" field. -func (ciu *ConfigItemUpdate) SetCreatedAt(t time.Time) *ConfigItemUpdate { - ciu.mutation.SetCreatedAt(t) - return ciu -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (ciu *ConfigItemUpdate) ClearCreatedAt() *ConfigItemUpdate { - ciu.mutation.ClearCreatedAt() - return ciu -} - // SetUpdatedAt sets the "updated_at" field. func (ciu *ConfigItemUpdate) SetUpdatedAt(t time.Time) *ConfigItemUpdate { ciu.mutation.SetUpdatedAt(t) return ciu } -// ClearUpdatedAt clears the value of the "updated_at" field. 
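Aggregate on ConfigItemSelect, plus the Aggregate shortcut on the query, is new here; together with the reworked GroupBy it routes through the same interceptor-aware scan path. A sketch of both forms, assuming an initialized client:

// Row count per name:
var perName []struct {
	Name  string `json:"name"`
	Count int    `json:"count"`
}
err := client.ConfigItem.Query().
	GroupBy(configitem.FieldName).
	Aggregate(ent.Count()).
	Scan(ctx, &perName)

// A single aggregate over the whole table:
total, err := client.ConfigItem.Query().Aggregate(ent.Count()).Int(ctx)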
-func (ciu *ConfigItemUpdate) ClearUpdatedAt() *ConfigItemUpdate { - ciu.mutation.ClearUpdatedAt() - return ciu -} - -// SetName sets the "name" field. -func (ciu *ConfigItemUpdate) SetName(s string) *ConfigItemUpdate { - ciu.mutation.SetName(s) - return ciu -} - // SetValue sets the "value" field. func (ciu *ConfigItemUpdate) SetValue(s string) *ConfigItemUpdate { ciu.mutation.SetValue(s) return ciu } +// SetNillableValue sets the "value" field if the given value is not nil. +func (ciu *ConfigItemUpdate) SetNillableValue(s *string) *ConfigItemUpdate { + if s != nil { + ciu.SetValue(*s) + } + return ciu +} + // Mutation returns the ConfigItemMutation object of the builder. func (ciu *ConfigItemUpdate) Mutation() *ConfigItemMutation { return ciu.mutation @@ -71,35 +55,8 @@ func (ciu *ConfigItemUpdate) Mutation() *ConfigItemMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (ciu *ConfigItemUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) ciu.defaults() - if len(ciu.hooks) == 0 { - affected, err = ciu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ciu.mutation = mutation - affected, err = ciu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(ciu.hooks) - 1; i >= 0; i-- { - if ciu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ciu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ciu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, ciu.sqlSave, ciu.mutation, ciu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -126,27 +83,14 @@ func (ciu *ConfigItemUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. 
func (ciu *ConfigItemUpdate) defaults() { - if _, ok := ciu.mutation.CreatedAt(); !ok && !ciu.mutation.CreatedAtCleared() { - v := configitem.UpdateDefaultCreatedAt() - ciu.mutation.SetCreatedAt(v) - } - if _, ok := ciu.mutation.UpdatedAt(); !ok && !ciu.mutation.UpdatedAtCleared() { + if _, ok := ciu.mutation.UpdatedAt(); !ok { v := configitem.UpdateDefaultUpdatedAt() ciu.mutation.SetUpdatedAt(v) } } func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - Columns: configitem.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) if ps := ciu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -154,45 +98,11 @@ func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := ciu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldCreatedAt, - }) - } - if ciu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldCreatedAt, - }) - } if value, ok := ciu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldUpdatedAt, - }) - } - if ciu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldUpdatedAt, - }) - } - if value, ok := ciu.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldName, - }) + _spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value) } if value, ok := ciu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldValue, - }) + _spec.SetField(configitem.FieldValue, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, ciu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -202,6 +112,7 @@ func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + ciu.mutation.done = true return n, nil } @@ -213,47 +124,37 @@ type ConfigItemUpdateOne struct { mutation *ConfigItemMutation } -// SetCreatedAt sets the "created_at" field. -func (ciuo *ConfigItemUpdateOne) SetCreatedAt(t time.Time) *ConfigItemUpdateOne { - ciuo.mutation.SetCreatedAt(t) - return ciuo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (ciuo *ConfigItemUpdateOne) ClearCreatedAt() *ConfigItemUpdateOne { - ciuo.mutation.ClearCreatedAt() - return ciuo -} - // SetUpdatedAt sets the "updated_at" field. func (ciuo *ConfigItemUpdateOne) SetUpdatedAt(t time.Time) *ConfigItemUpdateOne { ciuo.mutation.SetUpdatedAt(t) return ciuo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (ciuo *ConfigItemUpdateOne) ClearUpdatedAt() *ConfigItemUpdateOne { - ciuo.mutation.ClearUpdatedAt() - return ciuo -} - -// SetName sets the "name" field. 
-func (ciuo *ConfigItemUpdateOne) SetName(s string) *ConfigItemUpdateOne { - ciuo.mutation.SetName(s) - return ciuo -} - // SetValue sets the "value" field. func (ciuo *ConfigItemUpdateOne) SetValue(s string) *ConfigItemUpdateOne { ciuo.mutation.SetValue(s) return ciuo } +// SetNillableValue sets the "value" field if the given value is not nil. +func (ciuo *ConfigItemUpdateOne) SetNillableValue(s *string) *ConfigItemUpdateOne { + if s != nil { + ciuo.SetValue(*s) + } + return ciuo +} + // Mutation returns the ConfigItemMutation object of the builder. func (ciuo *ConfigItemUpdateOne) Mutation() *ConfigItemMutation { return ciuo.mutation } +// Where appends a list predicates to the ConfigItemUpdate builder. +func (ciuo *ConfigItemUpdateOne) Where(ps ...predicate.ConfigItem) *ConfigItemUpdateOne { + ciuo.mutation.Where(ps...) + return ciuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (ciuo *ConfigItemUpdateOne) Select(field string, fields ...string) *ConfigItemUpdateOne { @@ -263,41 +164,8 @@ func (ciuo *ConfigItemUpdateOne) Select(field string, fields ...string) *ConfigI // Save executes the query and returns the updated ConfigItem entity. func (ciuo *ConfigItemUpdateOne) Save(ctx context.Context) (*ConfigItem, error) { - var ( - err error - node *ConfigItem - ) ciuo.defaults() - if len(ciuo.hooks) == 0 { - node, err = ciuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ciuo.mutation = mutation - node, err = ciuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(ciuo.hooks) - 1; i >= 0; i-- { - if ciuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ciuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, ciuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*ConfigItem) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from ConfigItemMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, ciuo.sqlSave, ciuo.mutation, ciuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -324,27 +192,14 @@ func (ciuo *ConfigItemUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. 
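// Example sketch, assuming the same client setup as above: the Where()
// method added to the UpdateOne builder makes an update conditional; when
// the predicate no longer matches, ent reports the row as not found rather
// than updating it. The config item name below is hypothetical.
// (imports: context, .../ent, .../ent/configitem)
func updateIfNamed(ctx context.Context, client *ent.Client, id int) error {
	return client.ConfigItem.UpdateOneID(id).
		Where(configitem.NameEQ("bouncer_cache")). // hypothetical item name
		SetValue("enabled").
		Exec(ctx)
}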
func (ciuo *ConfigItemUpdateOne) defaults() { - if _, ok := ciuo.mutation.CreatedAt(); !ok && !ciuo.mutation.CreatedAtCleared() { - v := configitem.UpdateDefaultCreatedAt() - ciuo.mutation.SetCreatedAt(v) - } - if _, ok := ciuo.mutation.UpdatedAt(); !ok && !ciuo.mutation.UpdatedAtCleared() { + if _, ok := ciuo.mutation.UpdatedAt(); !ok { v := configitem.UpdateDefaultUpdatedAt() ciuo.mutation.SetUpdatedAt(v) } } func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - Columns: configitem.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) id, ok := ciuo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ConfigItem.id" for update`)} @@ -369,45 +224,11 @@ func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem } } } - if value, ok := ciuo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldCreatedAt, - }) - } - if ciuo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldCreatedAt, - }) - } if value, ok := ciuo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldUpdatedAt, - }) - } - if ciuo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldUpdatedAt, - }) - } - if value, ok := ciuo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldName, - }) + _spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value) } if value, ok := ciuo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldValue, - }) + _spec.SetField(configitem.FieldValue, field.TypeString, value) } _node = &ConfigItem{config: ciuo.config} _spec.Assign = _node.assignValues @@ -420,5 +241,6 @@ func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem } return nil, err } + ciuo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/context.go b/pkg/database/ent/context.go deleted file mode 100644 index 7811bfa2349..00000000000 --- a/pkg/database/ent/context.go +++ /dev/null @@ -1,33 +0,0 @@ -// Code generated by ent, DO NOT EDIT. - -package ent - -import ( - "context" -) - -type clientCtxKey struct{} - -// FromContext returns a Client stored inside a context, or nil if there isn't one. -func FromContext(ctx context.Context) *Client { - c, _ := ctx.Value(clientCtxKey{}).(*Client) - return c -} - -// NewContext returns a new context with the given Client attached. -func NewContext(parent context.Context, c *Client) context.Context { - return context.WithValue(parent, clientCtxKey{}, c) -} - -type txCtxKey struct{} - -// TxFromContext returns a Tx stored inside a context, or nil if there isn't one. 
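// Example sketch: Save is now routed through the generated withHooks helper,
// but hook registration is unchanged. Assuming the package's ent.Mutator and
// ent.MutateFunc aliases, a logging hook can still be attached like this:
// (imports: context, log, .../ent)
func registerAuditHook(client *ent.Client) {
	client.ConfigItem.Use(func(next ent.Mutator) ent.Mutator {
		return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
			log.Printf("mutation on %s", m.Type()) // e.g. "ConfigItem"
			return next.Mutate(ctx, m)
		})
	})
}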
-func TxFromContext(ctx context.Context) *Tx { - tx, _ := ctx.Value(txCtxKey{}).(*Tx) - return tx -} - -// NewTxContext returns a new context with the given Tx attached. -func NewTxContext(parent context.Context, tx *Tx) context.Context { - return context.WithValue(parent, txCtxKey{}, tx) -} diff --git a/pkg/database/ent/decision.go b/pkg/database/ent/decision.go index c969e576724..4a6dc728509 100644 --- a/pkg/database/ent/decision.go +++ b/pkg/database/ent/decision.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" @@ -18,9 +19,9 @@ type Decision struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` // Until holds the value of the "until" field. Until *time.Time `json:"until,omitempty"` // Scenario holds the value of the "scenario" field. @@ -51,7 +52,8 @@ type Decision struct { AlertDecisions int `json:"alert_decisions,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the DecisionQuery when eager-loading is set. - Edges DecisionEdges `json:"edges"` + Edges DecisionEdges `json:"edges"` + selectValues sql.SelectValues } // DecisionEdges holds the relations/edges for other nodes in the graph. @@ -66,12 +68,10 @@ type DecisionEdges struct { // OwnerOrErr returns the Owner value or an error if the edge // was not loaded in eager-loading, or loaded but was not found. func (e DecisionEdges) OwnerOrErr() (*Alert, error) { - if e.loadedTypes[0] { - if e.Owner == nil { - // Edge was loaded but was not found. - return nil, &NotFoundError{label: alert.Label} - } + if e.Owner != nil { return e.Owner, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: alert.Label} } return nil, &NotLoadedError{edge: "owner"} } @@ -90,7 +90,7 @@ func (*Decision) scanValues(columns []string) ([]any, error) { case decision.FieldCreatedAt, decision.FieldUpdatedAt, decision.FieldUntil: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Decision", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -114,15 +114,13 @@ func (d *Decision) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - d.CreatedAt = new(time.Time) - *d.CreatedAt = value.Time + d.CreatedAt = value.Time } case decision.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - d.UpdatedAt = new(time.Time) - *d.UpdatedAt = value.Time + d.UpdatedAt = value.Time } case decision.FieldUntil: if value, ok := values[i].(*sql.NullTime); !ok { @@ -209,21 +207,29 @@ func (d *Decision) assignValues(columns []string, values []any) error { } else if value.Valid { d.AlertDecisions = int(value.Int64) } + default: + d.selectValues.Set(columns[i], values[i]) } } return nil } +// GetValue returns the ent.Value that was dynamically selected and assigned to the Decision. 
+// This includes values selected through modifiers, order, etc. +func (d *Decision) GetValue(name string) (ent.Value, error) { + return d.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Decision entity. func (d *Decision) QueryOwner() *AlertQuery { - return (&DecisionClient{config: d.config}).QueryOwner(d) + return NewDecisionClient(d.config).QueryOwner(d) } // Update returns a builder for updating this Decision. // Note that you need to call Decision.Unwrap() before calling this method if this Decision // was returned from a transaction, and the transaction was committed or rolled back. func (d *Decision) Update() *DecisionUpdateOne { - return (&DecisionClient{config: d.config}).UpdateOne(d) + return NewDecisionClient(d.config).UpdateOne(d) } // Unwrap unwraps the Decision entity that was returned from a transaction after it was closed, @@ -242,15 +248,11 @@ func (d *Decision) String() string { var builder strings.Builder builder.WriteString("Decision(") builder.WriteString(fmt.Sprintf("id=%v, ", d.ID)) - if v := d.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(d.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := d.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(d.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") if v := d.Until; v != nil { builder.WriteString("until=") @@ -301,9 +303,3 @@ func (d *Decision) String() string { // Decisions is a parsable slice of Decision. type Decisions []*Decision - -func (d Decisions) config(cfg config) { - for _i := range d { - d[_i].config = cfg - } -} diff --git a/pkg/database/ent/decision/decision.go b/pkg/database/ent/decision/decision.go index a0012d940a8..38c9721db48 100644 --- a/pkg/database/ent/decision/decision.go +++ b/pkg/database/ent/decision/decision.go @@ -4,6 +4,9 @@ package decision import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -90,8 +93,6 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. @@ -99,3 +100,105 @@ var ( // DefaultSimulated holds the default value on creation for the "simulated" field. DefaultSimulated bool ) + +// OrderOption defines the ordering options for the Decision queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByUntil orders the results by the until field. 
+func ByUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUntil, opts...).ToFunc() +} + +// ByScenario orders the results by the scenario field. +func ByScenario(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenario, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByStartIP orders the results by the start_ip field. +func ByStartIP(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartIP, opts...).ToFunc() +} + +// ByEndIP orders the results by the end_ip field. +func ByEndIP(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEndIP, opts...).ToFunc() +} + +// ByStartSuffix orders the results by the start_suffix field. +func ByStartSuffix(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartSuffix, opts...).ToFunc() +} + +// ByEndSuffix orders the results by the end_suffix field. +func ByEndSuffix(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEndSuffix, opts...).ToFunc() +} + +// ByIPSize orders the results by the ip_size field. +func ByIPSize(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIPSize, opts...).ToFunc() +} + +// ByScope orders the results by the scope field. +func ByScope(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScope, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} + +// ByOrigin orders the results by the origin field. +func ByOrigin(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOrigin, opts...).ToFunc() +} + +// BySimulated orders the results by the simulated field. +func BySimulated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSimulated, opts...).ToFunc() +} + +// ByUUID orders the results by the uuid field. +func ByUUID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUUID, opts...).ToFunc() +} + +// ByAlertDecisions orders the results by the alert_decisions field. +func ByAlertDecisions(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAlertDecisions, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} diff --git a/pkg/database/ent/decision/where.go b/pkg/database/ent/decision/where.go index 18716a4a7c1..99a1889e63e 100644 --- a/pkg/database/ent/decision/where.go +++ b/pkg/database/ent/decision/where.go @@ -12,1481 +12,947 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. 
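// Example sketch: the generated OrderOption helpers replace raw column names
// in Order() calls; sql.OrderDesc comes from entgo.io/ent/dialect/sql.
// Assumes an initialized *ent.Client.
// (imports: context, entgo.io/ent/dialect/sql, .../ent, .../ent/decision)
func latestDecisions(ctx context.Context, client *ent.Client) ([]*ent.Decision, error) {
	return client.Decision.Query().
		Order(decision.ByCreatedAt(sql.OrderDesc()), decision.ByID()).
		Limit(10).
		All(ctx)
}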
func IDEQ(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUpdatedAt, v)) } // Until applies equality check predicate on the "until" field. It's identical to UntilEQ. func Until(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUntil, v)) } // Scenario applies equality check predicate on the "scenario" field. It's identical to ScenarioEQ. func Scenario(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScenario, v)) } // Type applies equality check predicate on the "type" field. It's identical to TypeEQ. 
func Type(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldType, v)) } // StartIP applies equality check predicate on the "start_ip" field. It's identical to StartIPEQ. func StartIP(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartIP, v)) } // EndIP applies equality check predicate on the "end_ip" field. It's identical to EndIPEQ. func EndIP(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndIP, v)) } // StartSuffix applies equality check predicate on the "start_suffix" field. It's identical to StartSuffixEQ. func StartSuffix(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartSuffix, v)) } // EndSuffix applies equality check predicate on the "end_suffix" field. It's identical to EndSuffixEQ. func EndSuffix(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndSuffix, v)) } // IPSize applies equality check predicate on the "ip_size" field. It's identical to IPSizeEQ. func IPSize(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldIPSize, v)) } // Scope applies equality check predicate on the "scope" field. It's identical to ScopeEQ. func Scope(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScope, v)) } // Value applies equality check predicate on the "value" field. It's identical to ValueEQ. func Value(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldValue, v)) } // Origin applies equality check predicate on the "origin" field. It's identical to OriginEQ. func Origin(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldOrigin, v)) } // Simulated applies equality check predicate on the "simulated" field. It's identical to SimulatedEQ. func Simulated(v bool) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldSimulated, v)) } // UUID applies equality check predicate on the "uuid" field. It's identical to UUIDEQ. func UUID(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUUID, v)) } // AlertDecisions applies equality check predicate on the "alert_decisions" field. It's identical to AlertDecisionsEQ. 
func AlertDecisions(v int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertDecisions), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldAlertDecisions, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Decision(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. 
func UpdatedAtNEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Decision(sql.FieldLTE(FieldUpdatedAt, v)) } // UntilEQ applies the EQ predicate on the "until" field. func UntilEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUntil, v)) } // UntilNEQ applies the NEQ predicate on the "until" field. func UntilNEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldUntil, v)) } // UntilIn applies the In predicate on the "until" field. func UntilIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUntil), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldUntil, vs...)) } // UntilNotIn applies the NotIn predicate on the "until" field. 
func UntilNotIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUntil), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldUntil, vs...)) } // UntilGT applies the GT predicate on the "until" field. func UntilGT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldGT(FieldUntil, v)) } // UntilGTE applies the GTE predicate on the "until" field. func UntilGTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldUntil, v)) } // UntilLT applies the LT predicate on the "until" field. func UntilLT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldLT(FieldUntil, v)) } // UntilLTE applies the LTE predicate on the "until" field. func UntilLTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldUntil, v)) } // UntilIsNil applies the IsNil predicate on the "until" field. func UntilIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUntil))) - }) + return predicate.Decision(sql.FieldIsNull(FieldUntil)) } // UntilNotNil applies the NotNil predicate on the "until" field. func UntilNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUntil))) - }) + return predicate.Decision(sql.FieldNotNull(FieldUntil)) } // ScenarioEQ applies the EQ predicate on the "scenario" field. func ScenarioEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScenario, v)) } // ScenarioNEQ applies the NEQ predicate on the "scenario" field. func ScenarioNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldScenario, v)) } // ScenarioIn applies the In predicate on the "scenario" field. func ScenarioIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenario), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldScenario, vs...)) } // ScenarioNotIn applies the NotIn predicate on the "scenario" field. func ScenarioNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenario), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldScenario, vs...)) } // ScenarioGT applies the GT predicate on the "scenario" field. func ScenarioGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldGT(FieldScenario, v)) } // ScenarioGTE applies the GTE predicate on the "scenario" field. 
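// Example sketch: the rewritten predicates are one-line wrappers over the
// sql.Field* helpers, so caller queries are unchanged; e.g. counting the
// decisions still in force. Assumes an initialized *ent.Client.
// (imports: context, time, .../ent, .../ent/decision)
func countActiveDecisions(ctx context.Context, client *ent.Client) (int, error) {
	return client.Decision.Query().
		Where(decision.UntilNotNil(), decision.UntilGT(time.Now())).
		Count(ctx)
}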
func ScenarioGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldScenario, v)) } // ScenarioLT applies the LT predicate on the "scenario" field. func ScenarioLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldLT(FieldScenario, v)) } // ScenarioLTE applies the LTE predicate on the "scenario" field. func ScenarioLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldScenario, v)) } // ScenarioContains applies the Contains predicate on the "scenario" field. func ScenarioContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldContains(FieldScenario, v)) } // ScenarioHasPrefix applies the HasPrefix predicate on the "scenario" field. func ScenarioHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldScenario, v)) } // ScenarioHasSuffix applies the HasSuffix predicate on the "scenario" field. func ScenarioHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldScenario, v)) } // ScenarioEqualFold applies the EqualFold predicate on the "scenario" field. func ScenarioEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldScenario, v)) } // ScenarioContainsFold applies the ContainsFold predicate on the "scenario" field. func ScenarioContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldScenario, v)) } // TypeEQ applies the EQ predicate on the "type" field. func TypeEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldType, v)) } // TypeNEQ applies the NEQ predicate on the "type" field. func TypeNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldType, v)) } // TypeIn applies the In predicate on the "type" field. func TypeIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldType), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldType, vs...)) } // TypeNotIn applies the NotIn predicate on the "type" field. func TypeNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldType), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldType, vs...)) } // TypeGT applies the GT predicate on the "type" field. 
func TypeGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldGT(FieldType, v)) } // TypeGTE applies the GTE predicate on the "type" field. func TypeGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldType, v)) } // TypeLT applies the LT predicate on the "type" field. func TypeLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldLT(FieldType, v)) } // TypeLTE applies the LTE predicate on the "type" field. func TypeLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldType, v)) } // TypeContains applies the Contains predicate on the "type" field. func TypeContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldContains(FieldType, v)) } // TypeHasPrefix applies the HasPrefix predicate on the "type" field. func TypeHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldType, v)) } // TypeHasSuffix applies the HasSuffix predicate on the "type" field. func TypeHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldType, v)) } // TypeEqualFold applies the EqualFold predicate on the "type" field. func TypeEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldType, v)) } // TypeContainsFold applies the ContainsFold predicate on the "type" field. func TypeContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldType, v)) } // StartIPEQ applies the EQ predicate on the "start_ip" field. func StartIPEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartIP, v)) } // StartIPNEQ applies the NEQ predicate on the "start_ip" field. func StartIPNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldStartIP, v)) } // StartIPIn applies the In predicate on the "start_ip" field. func StartIPIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStartIP), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldStartIP, vs...)) } // StartIPNotIn applies the NotIn predicate on the "start_ip" field. 
func StartIPNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStartIP), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldStartIP, vs...)) } // StartIPGT applies the GT predicate on the "start_ip" field. func StartIPGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldGT(FieldStartIP, v)) } // StartIPGTE applies the GTE predicate on the "start_ip" field. func StartIPGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldStartIP, v)) } // StartIPLT applies the LT predicate on the "start_ip" field. func StartIPLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldLT(FieldStartIP, v)) } // StartIPLTE applies the LTE predicate on the "start_ip" field. func StartIPLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldStartIP, v)) } // StartIPIsNil applies the IsNil predicate on the "start_ip" field. func StartIPIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStartIP))) - }) + return predicate.Decision(sql.FieldIsNull(FieldStartIP)) } // StartIPNotNil applies the NotNil predicate on the "start_ip" field. func StartIPNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStartIP))) - }) + return predicate.Decision(sql.FieldNotNull(FieldStartIP)) } // EndIPEQ applies the EQ predicate on the "end_ip" field. func EndIPEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndIP, v)) } // EndIPNEQ applies the NEQ predicate on the "end_ip" field. func EndIPNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldEndIP, v)) } // EndIPIn applies the In predicate on the "end_ip" field. func EndIPIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldEndIP), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldEndIP, vs...)) } // EndIPNotIn applies the NotIn predicate on the "end_ip" field. func EndIPNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldEndIP), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldEndIP, vs...)) } // EndIPGT applies the GT predicate on the "end_ip" field. func EndIPGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldGT(FieldEndIP, v)) } // EndIPGTE applies the GTE predicate on the "end_ip" field. 
func EndIPGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldEndIP, v)) } // EndIPLT applies the LT predicate on the "end_ip" field. func EndIPLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldLT(FieldEndIP, v)) } // EndIPLTE applies the LTE predicate on the "end_ip" field. func EndIPLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldEndIP, v)) } // EndIPIsNil applies the IsNil predicate on the "end_ip" field. func EndIPIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldEndIP))) - }) + return predicate.Decision(sql.FieldIsNull(FieldEndIP)) } // EndIPNotNil applies the NotNil predicate on the "end_ip" field. func EndIPNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldEndIP))) - }) + return predicate.Decision(sql.FieldNotNull(FieldEndIP)) } // StartSuffixEQ applies the EQ predicate on the "start_suffix" field. func StartSuffixEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartSuffix, v)) } // StartSuffixNEQ applies the NEQ predicate on the "start_suffix" field. func StartSuffixNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldStartSuffix, v)) } // StartSuffixIn applies the In predicate on the "start_suffix" field. func StartSuffixIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStartSuffix), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldStartSuffix, vs...)) } // StartSuffixNotIn applies the NotIn predicate on the "start_suffix" field. func StartSuffixNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStartSuffix), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldStartSuffix, vs...)) } // StartSuffixGT applies the GT predicate on the "start_suffix" field. func StartSuffixGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldGT(FieldStartSuffix, v)) } // StartSuffixGTE applies the GTE predicate on the "start_suffix" field. func StartSuffixGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldStartSuffix, v)) } // StartSuffixLT applies the LT predicate on the "start_suffix" field. func StartSuffixLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldLT(FieldStartSuffix, v)) } // StartSuffixLTE applies the LTE predicate on the "start_suffix" field. 
func StartSuffixLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldStartSuffix, v)) } // StartSuffixIsNil applies the IsNil predicate on the "start_suffix" field. func StartSuffixIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStartSuffix))) - }) + return predicate.Decision(sql.FieldIsNull(FieldStartSuffix)) } // StartSuffixNotNil applies the NotNil predicate on the "start_suffix" field. func StartSuffixNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStartSuffix))) - }) + return predicate.Decision(sql.FieldNotNull(FieldStartSuffix)) } // EndSuffixEQ applies the EQ predicate on the "end_suffix" field. func EndSuffixEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndSuffix, v)) } // EndSuffixNEQ applies the NEQ predicate on the "end_suffix" field. func EndSuffixNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldEndSuffix, v)) } // EndSuffixIn applies the In predicate on the "end_suffix" field. func EndSuffixIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldEndSuffix), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldEndSuffix, vs...)) } // EndSuffixNotIn applies the NotIn predicate on the "end_suffix" field. func EndSuffixNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldEndSuffix), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldEndSuffix, vs...)) } // EndSuffixGT applies the GT predicate on the "end_suffix" field. func EndSuffixGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldGT(FieldEndSuffix, v)) } // EndSuffixGTE applies the GTE predicate on the "end_suffix" field. func EndSuffixGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldEndSuffix, v)) } // EndSuffixLT applies the LT predicate on the "end_suffix" field. func EndSuffixLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldLT(FieldEndSuffix, v)) } // EndSuffixLTE applies the LTE predicate on the "end_suffix" field. func EndSuffixLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldEndSuffix, v)) } // EndSuffixIsNil applies the IsNil predicate on the "end_suffix" field. func EndSuffixIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldEndSuffix))) - }) + return predicate.Decision(sql.FieldIsNull(FieldEndSuffix)) } // EndSuffixNotNil applies the NotNil predicate on the "end_suffix" field. 
func EndSuffixNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldEndSuffix))) - }) + return predicate.Decision(sql.FieldNotNull(FieldEndSuffix)) } // IPSizeEQ applies the EQ predicate on the "ip_size" field. func IPSizeEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldIPSize, v)) } // IPSizeNEQ applies the NEQ predicate on the "ip_size" field. func IPSizeNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldIPSize, v)) } // IPSizeIn applies the In predicate on the "ip_size" field. func IPSizeIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldIPSize), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldIPSize, vs...)) } // IPSizeNotIn applies the NotIn predicate on the "ip_size" field. func IPSizeNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldIPSize), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldIPSize, vs...)) } // IPSizeGT applies the GT predicate on the "ip_size" field. func IPSizeGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldGT(FieldIPSize, v)) } // IPSizeGTE applies the GTE predicate on the "ip_size" field. func IPSizeGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldIPSize, v)) } // IPSizeLT applies the LT predicate on the "ip_size" field. func IPSizeLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldLT(FieldIPSize, v)) } // IPSizeLTE applies the LTE predicate on the "ip_size" field. func IPSizeLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldIPSize, v)) } // IPSizeIsNil applies the IsNil predicate on the "ip_size" field. func IPSizeIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldIPSize))) - }) + return predicate.Decision(sql.FieldIsNull(FieldIPSize)) } // IPSizeNotNil applies the NotNil predicate on the "ip_size" field. func IPSizeNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldIPSize))) - }) + return predicate.Decision(sql.FieldNotNull(FieldIPSize)) } // ScopeEQ applies the EQ predicate on the "scope" field. func ScopeEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScope, v)) } // ScopeNEQ applies the NEQ predicate on the "scope" field. 
func ScopeNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldScope, v)) } // ScopeIn applies the In predicate on the "scope" field. func ScopeIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScope), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldScope, vs...)) } // ScopeNotIn applies the NotIn predicate on the "scope" field. func ScopeNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScope), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldScope, vs...)) } // ScopeGT applies the GT predicate on the "scope" field. func ScopeGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldGT(FieldScope, v)) } // ScopeGTE applies the GTE predicate on the "scope" field. func ScopeGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldScope, v)) } // ScopeLT applies the LT predicate on the "scope" field. func ScopeLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldLT(FieldScope, v)) } // ScopeLTE applies the LTE predicate on the "scope" field. func ScopeLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldScope, v)) } // ScopeContains applies the Contains predicate on the "scope" field. func ScopeContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldContains(FieldScope, v)) } // ScopeHasPrefix applies the HasPrefix predicate on the "scope" field. func ScopeHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldScope, v)) } // ScopeHasSuffix applies the HasSuffix predicate on the "scope" field. func ScopeHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldScope, v)) } // ScopeEqualFold applies the EqualFold predicate on the "scope" field. func ScopeEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldScope, v)) } // ScopeContainsFold applies the ContainsFold predicate on the "scope" field. func ScopeContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldScope, v)) } // ValueEQ applies the EQ predicate on the "value" field. 
func ValueEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "value" field. func ValueNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "value" field. func ValueIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "value" field. func ValueNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "value" field. func ValueGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "value" field. func ValueGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "value" field. func ValueLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "value" field. func ValueLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "value" field. func ValueContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "value" field. func ValueHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "value" field. func ValueHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "value" field. func ValueEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "value" field. 
func ValueContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldValue, v)) } // OriginEQ applies the EQ predicate on the "origin" field. func OriginEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldOrigin, v)) } // OriginNEQ applies the NEQ predicate on the "origin" field. func OriginNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldOrigin, v)) } // OriginIn applies the In predicate on the "origin" field. func OriginIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldOrigin), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldOrigin, vs...)) } // OriginNotIn applies the NotIn predicate on the "origin" field. func OriginNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldOrigin), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldOrigin, vs...)) } // OriginGT applies the GT predicate on the "origin" field. func OriginGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldGT(FieldOrigin, v)) } // OriginGTE applies the GTE predicate on the "origin" field. func OriginGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldOrigin, v)) } // OriginLT applies the LT predicate on the "origin" field. func OriginLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldLT(FieldOrigin, v)) } // OriginLTE applies the LTE predicate on the "origin" field. func OriginLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldOrigin, v)) } // OriginContains applies the Contains predicate on the "origin" field. func OriginContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldContains(FieldOrigin, v)) } // OriginHasPrefix applies the HasPrefix predicate on the "origin" field. func OriginHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldOrigin, v)) } // OriginHasSuffix applies the HasSuffix predicate on the "origin" field. func OriginHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldOrigin, v)) } // OriginEqualFold applies the EqualFold predicate on the "origin" field. 
func OriginEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldOrigin, v)) } // OriginContainsFold applies the ContainsFold predicate on the "origin" field. func OriginContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldOrigin, v)) } // SimulatedEQ applies the EQ predicate on the "simulated" field. func SimulatedEQ(v bool) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldSimulated, v)) } // SimulatedNEQ applies the NEQ predicate on the "simulated" field. func SimulatedNEQ(v bool) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSimulated), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldSimulated, v)) } // UUIDEQ applies the EQ predicate on the "uuid" field. func UUIDEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUUID, v)) } // UUIDNEQ applies the NEQ predicate on the "uuid" field. func UUIDNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldUUID, v)) } // UUIDIn applies the In predicate on the "uuid" field. func UUIDIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUUID), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldUUID, vs...)) } // UUIDNotIn applies the NotIn predicate on the "uuid" field. func UUIDNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUUID), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldUUID, vs...)) } // UUIDGT applies the GT predicate on the "uuid" field. func UUIDGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldGT(FieldUUID, v)) } // UUIDGTE applies the GTE predicate on the "uuid" field. func UUIDGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldUUID, v)) } // UUIDLT applies the LT predicate on the "uuid" field. func UUIDLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldLT(FieldUUID, v)) } // UUIDLTE applies the LTE predicate on the "uuid" field. func UUIDLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldUUID, v)) } // UUIDContains applies the Contains predicate on the "uuid" field. 
func UUIDContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldContains(FieldUUID, v)) } // UUIDHasPrefix applies the HasPrefix predicate on the "uuid" field. func UUIDHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldUUID, v)) } // UUIDHasSuffix applies the HasSuffix predicate on the "uuid" field. func UUIDHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldUUID, v)) } // UUIDIsNil applies the IsNil predicate on the "uuid" field. func UUIDIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUUID))) - }) + return predicate.Decision(sql.FieldIsNull(FieldUUID)) } // UUIDNotNil applies the NotNil predicate on the "uuid" field. func UUIDNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUUID))) - }) + return predicate.Decision(sql.FieldNotNull(FieldUUID)) } // UUIDEqualFold applies the EqualFold predicate on the "uuid" field. func UUIDEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldUUID, v)) } // UUIDContainsFold applies the ContainsFold predicate on the "uuid" field. func UUIDContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldUUID, v)) } // AlertDecisionsEQ applies the EQ predicate on the "alert_decisions" field. func AlertDecisionsEQ(v int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertDecisions), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldAlertDecisions, v)) } // AlertDecisionsNEQ applies the NEQ predicate on the "alert_decisions" field. func AlertDecisionsNEQ(v int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAlertDecisions), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldAlertDecisions, v)) } // AlertDecisionsIn applies the In predicate on the "alert_decisions" field. func AlertDecisionsIn(vs ...int) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAlertDecisions), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldAlertDecisions, vs...)) } // AlertDecisionsNotIn applies the NotIn predicate on the "alert_decisions" field. func AlertDecisionsNotIn(vs ...int) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAlertDecisions), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldAlertDecisions, vs...)) } // AlertDecisionsIsNil applies the IsNil predicate on the "alert_decisions" field. 
func AlertDecisionsIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldAlertDecisions))) - }) + return predicate.Decision(sql.FieldIsNull(FieldAlertDecisions)) } // AlertDecisionsNotNil applies the NotNil predicate on the "alert_decisions" field. func AlertDecisionsNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldAlertDecisions))) - }) + return predicate.Decision(sql.FieldNotNull(FieldAlertDecisions)) } // HasOwner applies the HasEdge predicate on the "owner" edge. @@ -1494,7 +960,6 @@ func HasOwner() predicate.Decision { return predicate.Decision(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -1504,11 +969,7 @@ func HasOwner() predicate.Decision { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). func HasOwnerWith(preds ...predicate.Alert) predicate.Decision { return predicate.Decision(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -1519,32 +980,15 @@ func HasOwnerWith(preds ...predicate.Alert) predicate.Decision { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Decision) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Decision(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Decision) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Decision(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Decision) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Decision(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/decision_create.go b/pkg/database/ent/decision_create.go index 64238cb7003..f30d5452120 100644 --- a/pkg/database/ent/decision_create.go +++ b/pkg/database/ent/decision_create.go @@ -231,50 +231,8 @@ func (dc *DecisionCreate) Mutation() *DecisionMutation { // Save creates the Decision in the database. 
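The combinators And, Or and Not above get the same flattening via sql.AndPredicates, sql.OrPredicates and sql.NotPredicates, and the decision_create.go hunks that follow show the other recurring pattern of this regeneration: Save collapses its long inline hook chain into the generated withHooks helper, check() moves into sqlSave, and createSpec swaps hand-built sqlgraph.FieldSpec literals for sqlgraph.NewCreateSpec plus _spec.SetField. A sketch of a hook-aware create under the new code, assuming the generated Use/Hook plumbing that ent clients expose (the hook body, scenario name and timestamps are illustrative only, not from this patch):

package main

import (
	"context"
	"log"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// auditedCreate registers a hypothetical logging hook and creates one
// Decision; withHooks runs the same chain the old inline code built by hand.
func auditedCreate(ctx context.Context, client *ent.Client) (*ent.Decision, error) {
	client.Decision.Use(func(next ent.Mutator) ent.Mutator {
		return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
			log.Printf("mutation op=%v type=%s", m.Op(), m.Type())
			return next.Mutate(ctx, m) // check() now runs inside sqlSave
		})
	})
	return client.Decision.Create().
		SetScenario("crowdsecurity/ssh-bf"). // required fields enforced by check()
		SetType("ban").
		SetScope("ip").
		SetValue("192.0.2.1").
		SetOrigin("cscli").
		SetSimulated(false).
		SetUntil(time.Now().Add(4 * time.Hour)).
		Save(ctx)
}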
func (dc *DecisionCreate) Save(ctx context.Context) (*Decision, error) { - var ( - err error - node *Decision - ) dc.defaults() - if len(dc.hooks) == 0 { - if err = dc.check(); err != nil { - return nil, err - } - node, err = dc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = dc.check(); err != nil { - return nil, err - } - dc.mutation = mutation - if node, err = dc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(dc.hooks) - 1; i >= 0; i-- { - if dc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, dc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Decision) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DecisionMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, dc.sqlSave, dc.mutation, dc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -317,6 +275,12 @@ func (dc *DecisionCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (dc *DecisionCreate) check() error { + if _, ok := dc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Decision.created_at"`)} + } + if _, ok := dc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Decision.updated_at"`)} + } if _, ok := dc.mutation.Scenario(); !ok { return &ValidationError{Name: "scenario", err: errors.New(`ent: missing required field "Decision.scenario"`)} } @@ -339,6 +303,9 @@ func (dc *DecisionCreate) check() error { } func (dc *DecisionCreate) sqlSave(ctx context.Context) (*Decision, error) { + if err := dc.check(); err != nil { + return nil, err + } _node, _spec := dc.createSpec() if err := sqlgraph.CreateNode(ctx, dc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -348,138 +315,74 @@ func (dc *DecisionCreate) sqlSave(ctx context.Context) (*Decision, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + dc.mutation.id = &_node.ID + dc.mutation.done = true return _node, nil } func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) { var ( _node = &Decision{config: dc.config} - _spec = &sqlgraph.CreateSpec{ - Table: decision.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(decision.Table, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) ) if value, ok := dc.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(decision.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := dc.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := dc.mutation.Until(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: 
value, - Column: decision.FieldUntil, - }) + _spec.SetField(decision.FieldUntil, field.TypeTime, value) _node.Until = &value } if value, ok := dc.mutation.Scenario(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScenario, - }) + _spec.SetField(decision.FieldScenario, field.TypeString, value) _node.Scenario = value } if value, ok := dc.mutation.GetType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldType, - }) + _spec.SetField(decision.FieldType, field.TypeString, value) _node.Type = value } if value, ok := dc.mutation.StartIP(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) + _spec.SetField(decision.FieldStartIP, field.TypeInt64, value) _node.StartIP = value } if value, ok := dc.mutation.EndIP(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) + _spec.SetField(decision.FieldEndIP, field.TypeInt64, value) _node.EndIP = value } if value, ok := dc.mutation.StartSuffix(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) + _spec.SetField(decision.FieldStartSuffix, field.TypeInt64, value) _node.StartSuffix = value } if value, ok := dc.mutation.EndSuffix(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) + _spec.SetField(decision.FieldEndSuffix, field.TypeInt64, value) _node.EndSuffix = value } if value, ok := dc.mutation.IPSize(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) + _spec.SetField(decision.FieldIPSize, field.TypeInt64, value) _node.IPSize = value } if value, ok := dc.mutation.Scope(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScope, - }) + _spec.SetField(decision.FieldScope, field.TypeString, value) _node.Scope = value } if value, ok := dc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldValue, - }) + _spec.SetField(decision.FieldValue, field.TypeString, value) _node.Value = value } if value, ok := dc.mutation.Origin(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldOrigin, - }) + _spec.SetField(decision.FieldOrigin, field.TypeString, value) _node.Origin = value } if value, ok := dc.mutation.Simulated(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: decision.FieldSimulated, - }) + _spec.SetField(decision.FieldSimulated, field.TypeBool, value) _node.Simulated = value } if value, ok := dc.mutation.UUID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldUUID, - }) + _spec.SetField(decision.FieldUUID, field.TypeString, value) _node.UUID = value } if nodes := dc.mutation.OwnerIDs(); len(nodes) > 0 { @@ -490,10 +393,7 @@ func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) { Columns: []string{decision.OwnerColumn}, 
Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -508,11 +408,15 @@ func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) { // DecisionCreateBulk is the builder for creating many Decision entities in bulk. type DecisionCreateBulk struct { config + err error builders []*DecisionCreate } // Save creates the Decision entities in the database. func (dcb *DecisionCreateBulk) Save(ctx context.Context) ([]*Decision, error) { + if dcb.err != nil { + return nil, dcb.err + } specs := make([]*sqlgraph.CreateSpec, len(dcb.builders)) nodes := make([]*Decision, len(dcb.builders)) mutators := make([]Mutator, len(dcb.builders)) @@ -529,8 +433,8 @@ func (dcb *DecisionCreateBulk) Save(ctx context.Context) ([]*Decision, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, dcb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/decision_delete.go b/pkg/database/ent/decision_delete.go index 24b494b113e..35bb8767283 100644 --- a/pkg/database/ent/decision_delete.go +++ b/pkg/database/ent/decision_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (dd *DecisionDelete) Where(ps ...predicate.Decision) *DecisionDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (dd *DecisionDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(dd.hooks) == 0 { - affected, err = dd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - dd.mutation = mutation - affected, err = dd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(dd.hooks) - 1; i >= 0; i-- { - if dd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, dd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, dd.sqlExec, dd.mutation, dd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (dd *DecisionDelete) ExecX(ctx context.Context) int { } func (dd *DecisionDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(decision.Table, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) if ps := dd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (dd *DecisionDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + dd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type DecisionDeleteOne struct { dd *DecisionDelete } +// Where appends a list predicates to the DecisionDelete builder. 
+func (ddo *DecisionDeleteOne) Where(ps ...predicate.Decision) *DecisionDeleteOne { + ddo.dd.mutation.Where(ps...) + return ddo +} + // Exec executes the deletion query. func (ddo *DecisionDeleteOne) Exec(ctx context.Context) error { n, err := ddo.dd.Exec(ctx) @@ -111,5 +82,7 @@ func (ddo *DecisionDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (ddo *DecisionDeleteOne) ExecX(ctx context.Context) { - ddo.dd.ExecX(ctx) + if err := ddo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/decision_query.go b/pkg/database/ent/decision_query.go index 91aebded968..b050a4d9649 100644 --- a/pkg/database/ent/decision_query.go +++ b/pkg/database/ent/decision_query.go @@ -18,11 +18,9 @@ import ( // DecisionQuery is the builder for querying Decision entities. type DecisionQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []decision.OrderOption + inters []Interceptor predicates []predicate.Decision withOwner *AlertQuery // intermediate query (i.e. traversal path). @@ -36,34 +34,34 @@ func (dq *DecisionQuery) Where(ps ...predicate.Decision) *DecisionQuery { return dq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (dq *DecisionQuery) Limit(limit int) *DecisionQuery { - dq.limit = &limit + dq.ctx.Limit = &limit return dq } -// Offset adds an offset step to the query. +// Offset to start from. func (dq *DecisionQuery) Offset(offset int) *DecisionQuery { - dq.offset = &offset + dq.ctx.Offset = &offset return dq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (dq *DecisionQuery) Unique(unique bool) *DecisionQuery { - dq.unique = &unique + dq.ctx.Unique = &unique return dq } -// Order adds an order step to the query. -func (dq *DecisionQuery) Order(o ...OrderFunc) *DecisionQuery { +// Order specifies how the records should be ordered. +func (dq *DecisionQuery) Order(o ...decision.OrderOption) *DecisionQuery { dq.order = append(dq.order, o...) return dq } // QueryOwner chains the current query on the "owner" edge. func (dq *DecisionQuery) QueryOwner() *AlertQuery { - query := &AlertQuery{config: dq.config} + query := (&AlertClient{config: dq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := dq.prepareQuery(ctx); err != nil { return nil, err @@ -86,7 +84,7 @@ func (dq *DecisionQuery) QueryOwner() *AlertQuery { // First returns the first Decision entity from the query. // Returns a *NotFoundError when no Decision was found. func (dq *DecisionQuery) First(ctx context.Context) (*Decision, error) { - nodes, err := dq.Limit(1).All(ctx) + nodes, err := dq.Limit(1).All(setContextOp(ctx, dq.ctx, "First")) if err != nil { return nil, err } @@ -109,7 +107,7 @@ func (dq *DecisionQuery) FirstX(ctx context.Context) *Decision { // Returns a *NotFoundError when no Decision ID was found. func (dq *DecisionQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dq.Limit(1).IDs(ctx); err != nil { + if ids, err = dq.Limit(1).IDs(setContextOp(ctx, dq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -132,7 +130,7 @@ func (dq *DecisionQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Decision entity is found. 
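The decision_query.go hunks above and below move the builder's limit, offset, unique flag and field selection onto a shared QueryContext (dq.ctx), stamp every terminal operation through setContextOp so interceptors can observe it, and reimplement Exist on top of FirstID instead of a dedicated sqlExist. None of this changes the public surface, as in this sketch (client wiring assumed as before; the predicate values are placeholders):

package main

import (
	"context"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
)

// hasActiveBan shows the unchanged caller-side API: Limit/Offset now live on
// the QueryContext, and Exist issues a FirstID probe under the hood.
func hasActiveBan(ctx context.Context, client *ent.Client, ip string) bool {
	ok, err := client.Decision.Query().
		Where(
			decision.ValueEQ(ip),
			decision.SimulatedEQ(false),
		).
		Exist(ctx) // now FirstID + IsNotFound instead of sqlExist
	if err != nil {
		log.Printf("existence check: %v", err)
		return false
	}
	return ok
}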
// Returns a *NotFoundError when no Decision entities are found. func (dq *DecisionQuery) Only(ctx context.Context) (*Decision, error) { - nodes, err := dq.Limit(2).All(ctx) + nodes, err := dq.Limit(2).All(setContextOp(ctx, dq.ctx, "Only")) if err != nil { return nil, err } @@ -160,7 +158,7 @@ func (dq *DecisionQuery) OnlyX(ctx context.Context) *Decision { // Returns a *NotFoundError when no entities are found. func (dq *DecisionQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dq.Limit(2).IDs(ctx); err != nil { + if ids, err = dq.Limit(2).IDs(setContextOp(ctx, dq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -185,10 +183,12 @@ func (dq *DecisionQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Decisions. func (dq *DecisionQuery) All(ctx context.Context) ([]*Decision, error) { + ctx = setContextOp(ctx, dq.ctx, "All") if err := dq.prepareQuery(ctx); err != nil { return nil, err } - return dq.sqlAll(ctx) + qr := querierAll[[]*Decision, *DecisionQuery]() + return withInterceptors[[]*Decision](ctx, dq, qr, dq.inters) } // AllX is like All, but panics if an error occurs. @@ -201,9 +201,12 @@ func (dq *DecisionQuery) AllX(ctx context.Context) []*Decision { } // IDs executes the query and returns a list of Decision IDs. -func (dq *DecisionQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := dq.Select(decision.FieldID).Scan(ctx, &ids); err != nil { +func (dq *DecisionQuery) IDs(ctx context.Context) (ids []int, err error) { + if dq.ctx.Unique == nil && dq.path != nil { + dq.Unique(true) + } + ctx = setContextOp(ctx, dq.ctx, "IDs") + if err = dq.Select(decision.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -220,10 +223,11 @@ func (dq *DecisionQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (dq *DecisionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, dq.ctx, "Count") if err := dq.prepareQuery(ctx); err != nil { return 0, err } - return dq.sqlCount(ctx) + return withInterceptors[int](ctx, dq, querierCount[*DecisionQuery](), dq.inters) } // CountX is like Count, but panics if an error occurs. @@ -237,10 +241,15 @@ func (dq *DecisionQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (dq *DecisionQuery) Exist(ctx context.Context) (bool, error) { - if err := dq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, dq.ctx, "Exist") + switch _, err := dq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return dq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -260,22 +269,21 @@ func (dq *DecisionQuery) Clone() *DecisionQuery { } return &DecisionQuery{ config: dq.config, - limit: dq.limit, - offset: dq.offset, - order: append([]OrderFunc{}, dq.order...), + ctx: dq.ctx.Clone(), + order: append([]decision.OrderOption{}, dq.order...), + inters: append([]Interceptor{}, dq.inters...), predicates: append([]predicate.Decision{}, dq.predicates...), withOwner: dq.withOwner.Clone(), // clone intermediate query. - sql: dq.sql.Clone(), - path: dq.path, - unique: dq.unique, + sql: dq.sql.Clone(), + path: dq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. 
The optional arguments are used to configure the query builder of the edge. func (dq *DecisionQuery) WithOwner(opts ...func(*AlertQuery)) *DecisionQuery { - query := &AlertQuery{config: dq.config} + query := (&AlertClient{config: dq.config}).Query() for _, opt := range opts { opt(query) } @@ -298,16 +306,11 @@ func (dq *DecisionQuery) WithOwner(opts ...func(*AlertQuery)) *DecisionQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (dq *DecisionQuery) GroupBy(field string, fields ...string) *DecisionGroupBy { - grbuild := &DecisionGroupBy{config: dq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := dq.prepareQuery(ctx); err != nil { - return nil, err - } - return dq.sqlQuery(ctx), nil - } + dq.ctx.Fields = append([]string{field}, fields...) + grbuild := &DecisionGroupBy{build: dq} + grbuild.flds = &dq.ctx.Fields grbuild.label = decision.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -324,15 +327,30 @@ func (dq *DecisionQuery) GroupBy(field string, fields ...string) *DecisionGroupB // Select(decision.FieldCreatedAt). // Scan(ctx, &v) func (dq *DecisionQuery) Select(fields ...string) *DecisionSelect { - dq.fields = append(dq.fields, fields...) - selbuild := &DecisionSelect{DecisionQuery: dq} - selbuild.label = decision.Label - selbuild.flds, selbuild.scan = &dq.fields, selbuild.Scan - return selbuild + dq.ctx.Fields = append(dq.ctx.Fields, fields...) + sbuild := &DecisionSelect{DecisionQuery: dq} + sbuild.label = decision.Label + sbuild.flds, sbuild.scan = &dq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a DecisionSelect configured with the given aggregations. +func (dq *DecisionQuery) Aggregate(fns ...AggregateFunc) *DecisionSelect { + return dq.Select().Aggregate(fns...) 
} func (dq *DecisionQuery) prepareQuery(ctx context.Context) error { - for _, f := range dq.fields { + for _, inter := range dq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, dq); err != nil { + return err + } + } + } + for _, f := range dq.ctx.Fields { if !decision.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -392,6 +410,9 @@ func (dq *DecisionQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(alert.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -411,41 +432,22 @@ func (dq *DecisionQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes func (dq *DecisionQuery) sqlCount(ctx context.Context) (int, error) { _spec := dq.querySpec() - _spec.Node.Columns = dq.fields - if len(dq.fields) > 0 { - _spec.Unique = dq.unique != nil && *dq.unique + _spec.Node.Columns = dq.ctx.Fields + if len(dq.ctx.Fields) > 0 { + _spec.Unique = dq.ctx.Unique != nil && *dq.ctx.Unique } return sqlgraph.CountNodes(ctx, dq.driver, _spec) } -func (dq *DecisionQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := dq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - Columns: decision.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - From: dq.sql, - Unique: true, - } - if unique := dq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) + _spec.From = dq.sql + if unique := dq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if dq.path != nil { + _spec.Unique = true } - if fields := dq.fields; len(fields) > 0 { + if fields := dq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, decision.FieldID) for i := range fields { @@ -453,6 +455,9 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) } } + if dq.withOwner != nil { + _spec.Node.AddColumnOnce(decision.FieldAlertDecisions) + } } if ps := dq.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { @@ -461,10 +466,10 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := dq.order; len(ps) > 0 { @@ -480,7 +485,7 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(dq.driver.Dialect()) t1 := builder.Table(decision.Table) - columns := dq.fields + columns := dq.ctx.Fields if len(columns) == 0 { columns = decision.Columns } @@ -489,7 +494,7 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = dq.sql selector.Select(selector.Columns(columns...)...) 
} - if dq.unique != nil && *dq.unique { + if dq.ctx.Unique != nil && *dq.ctx.Unique { selector.Distinct() } for _, p := range dq.predicates { @@ -498,12 +503,12 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range dq.order { p(selector) } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -511,13 +516,8 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { // DecisionGroupBy is the group-by builder for Decision entities. type DecisionGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *DecisionQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -526,74 +526,77 @@ func (dgb *DecisionGroupBy) Aggregate(fns ...AggregateFunc) *DecisionGroupBy { return dgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (dgb *DecisionGroupBy) Scan(ctx context.Context, v any) error { - query, err := dgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, dgb.build.ctx, "GroupBy") + if err := dgb.build.prepareQuery(ctx); err != nil { return err } - dgb.sql = query - return dgb.sqlScan(ctx, v) + return scanWithInterceptors[*DecisionQuery, *DecisionGroupBy](ctx, dgb.build, dgb, dgb.build.inters, v) } -func (dgb *DecisionGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range dgb.fields { - if !decision.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (dgb *DecisionGroupBy) sqlScan(ctx context.Context, root *DecisionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(dgb.fns)) + for _, fn := range dgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*dgb.flds)+len(dgb.fns)) + for _, f := range *dgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := dgb.sqlQuery() + selector.GroupBy(selector.Columns(*dgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := dgb.driver.Query(ctx, query, args, rows); err != nil { + if err := dgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (dgb *DecisionGroupBy) sqlQuery() *sql.Selector { - selector := dgb.sql.Select() - aggregation := make([]string, 0, len(dgb.fns)) - for _, fn := range dgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(dgb.fields)+len(dgb.fns)) - for _, f := range dgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) 
- } - return selector.GroupBy(selector.Columns(dgb.fields...)...) -} - // DecisionSelect is the builder for selecting fields of Decision entities. type DecisionSelect struct { *DecisionQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ds *DecisionSelect) Aggregate(fns ...AggregateFunc) *DecisionSelect { + ds.fns = append(ds.fns, fns...) + return ds } // Scan applies the selector query and scans the result into the given value. func (ds *DecisionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ds.ctx, "Select") if err := ds.prepareQuery(ctx); err != nil { return err } - ds.sql = ds.DecisionQuery.sqlQuery(ctx) - return ds.sqlScan(ctx, v) + return scanWithInterceptors[*DecisionQuery, *DecisionSelect](ctx, ds.DecisionQuery, ds, ds.inters, v) } -func (ds *DecisionSelect) sqlScan(ctx context.Context, v any) error { +func (ds *DecisionSelect) sqlScan(ctx context.Context, root *DecisionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ds.fns)) + for _, fn := range ds.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ds.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ds.sql.Query() + query, args := selector.Query() if err := ds.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/decision_update.go b/pkg/database/ent/decision_update.go index 64b40871eca..68d0eb4ace7 100644 --- a/pkg/database/ent/decision_update.go +++ b/pkg/database/ent/decision_update.go @@ -29,30 +29,12 @@ func (du *DecisionUpdate) Where(ps ...predicate.Decision) *DecisionUpdate { return du } -// SetCreatedAt sets the "created_at" field. -func (du *DecisionUpdate) SetCreatedAt(t time.Time) *DecisionUpdate { - du.mutation.SetCreatedAt(t) - return du -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (du *DecisionUpdate) ClearCreatedAt() *DecisionUpdate { - du.mutation.ClearCreatedAt() - return du -} - // SetUpdatedAt sets the "updated_at" field. func (du *DecisionUpdate) SetUpdatedAt(t time.Time) *DecisionUpdate { du.mutation.SetUpdatedAt(t) return du } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (du *DecisionUpdate) ClearUpdatedAt() *DecisionUpdate { - du.mutation.ClearUpdatedAt() - return du -} - // SetUntil sets the "until" field. func (du *DecisionUpdate) SetUntil(t time.Time) *DecisionUpdate { du.mutation.SetUntil(t) @@ -73,205 +55,6 @@ func (du *DecisionUpdate) ClearUntil() *DecisionUpdate { return du } -// SetScenario sets the "scenario" field. -func (du *DecisionUpdate) SetScenario(s string) *DecisionUpdate { - du.mutation.SetScenario(s) - return du -} - -// SetType sets the "type" field. -func (du *DecisionUpdate) SetType(s string) *DecisionUpdate { - du.mutation.SetType(s) - return du -} - -// SetStartIP sets the "start_ip" field. -func (du *DecisionUpdate) SetStartIP(i int64) *DecisionUpdate { - du.mutation.ResetStartIP() - du.mutation.SetStartIP(i) - return du -} - -// SetNillableStartIP sets the "start_ip" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableStartIP(i *int64) *DecisionUpdate { - if i != nil { - du.SetStartIP(*i) - } - return du -} - -// AddStartIP adds i to the "start_ip" field. 
-func (du *DecisionUpdate) AddStartIP(i int64) *DecisionUpdate { - du.mutation.AddStartIP(i) - return du -} - -// ClearStartIP clears the value of the "start_ip" field. -func (du *DecisionUpdate) ClearStartIP() *DecisionUpdate { - du.mutation.ClearStartIP() - return du -} - -// SetEndIP sets the "end_ip" field. -func (du *DecisionUpdate) SetEndIP(i int64) *DecisionUpdate { - du.mutation.ResetEndIP() - du.mutation.SetEndIP(i) - return du -} - -// SetNillableEndIP sets the "end_ip" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableEndIP(i *int64) *DecisionUpdate { - if i != nil { - du.SetEndIP(*i) - } - return du -} - -// AddEndIP adds i to the "end_ip" field. -func (du *DecisionUpdate) AddEndIP(i int64) *DecisionUpdate { - du.mutation.AddEndIP(i) - return du -} - -// ClearEndIP clears the value of the "end_ip" field. -func (du *DecisionUpdate) ClearEndIP() *DecisionUpdate { - du.mutation.ClearEndIP() - return du -} - -// SetStartSuffix sets the "start_suffix" field. -func (du *DecisionUpdate) SetStartSuffix(i int64) *DecisionUpdate { - du.mutation.ResetStartSuffix() - du.mutation.SetStartSuffix(i) - return du -} - -// SetNillableStartSuffix sets the "start_suffix" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableStartSuffix(i *int64) *DecisionUpdate { - if i != nil { - du.SetStartSuffix(*i) - } - return du -} - -// AddStartSuffix adds i to the "start_suffix" field. -func (du *DecisionUpdate) AddStartSuffix(i int64) *DecisionUpdate { - du.mutation.AddStartSuffix(i) - return du -} - -// ClearStartSuffix clears the value of the "start_suffix" field. -func (du *DecisionUpdate) ClearStartSuffix() *DecisionUpdate { - du.mutation.ClearStartSuffix() - return du -} - -// SetEndSuffix sets the "end_suffix" field. -func (du *DecisionUpdate) SetEndSuffix(i int64) *DecisionUpdate { - du.mutation.ResetEndSuffix() - du.mutation.SetEndSuffix(i) - return du -} - -// SetNillableEndSuffix sets the "end_suffix" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableEndSuffix(i *int64) *DecisionUpdate { - if i != nil { - du.SetEndSuffix(*i) - } - return du -} - -// AddEndSuffix adds i to the "end_suffix" field. -func (du *DecisionUpdate) AddEndSuffix(i int64) *DecisionUpdate { - du.mutation.AddEndSuffix(i) - return du -} - -// ClearEndSuffix clears the value of the "end_suffix" field. -func (du *DecisionUpdate) ClearEndSuffix() *DecisionUpdate { - du.mutation.ClearEndSuffix() - return du -} - -// SetIPSize sets the "ip_size" field. -func (du *DecisionUpdate) SetIPSize(i int64) *DecisionUpdate { - du.mutation.ResetIPSize() - du.mutation.SetIPSize(i) - return du -} - -// SetNillableIPSize sets the "ip_size" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableIPSize(i *int64) *DecisionUpdate { - if i != nil { - du.SetIPSize(*i) - } - return du -} - -// AddIPSize adds i to the "ip_size" field. -func (du *DecisionUpdate) AddIPSize(i int64) *DecisionUpdate { - du.mutation.AddIPSize(i) - return du -} - -// ClearIPSize clears the value of the "ip_size" field. -func (du *DecisionUpdate) ClearIPSize() *DecisionUpdate { - du.mutation.ClearIPSize() - return du -} - -// SetScope sets the "scope" field. -func (du *DecisionUpdate) SetScope(s string) *DecisionUpdate { - du.mutation.SetScope(s) - return du -} - -// SetValue sets the "value" field. -func (du *DecisionUpdate) SetValue(s string) *DecisionUpdate { - du.mutation.SetValue(s) - return du -} - -// SetOrigin sets the "origin" field. 
-func (du *DecisionUpdate) SetOrigin(s string) *DecisionUpdate { - du.mutation.SetOrigin(s) - return du -} - -// SetSimulated sets the "simulated" field. -func (du *DecisionUpdate) SetSimulated(b bool) *DecisionUpdate { - du.mutation.SetSimulated(b) - return du -} - -// SetNillableSimulated sets the "simulated" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableSimulated(b *bool) *DecisionUpdate { - if b != nil { - du.SetSimulated(*b) - } - return du -} - -// SetUUID sets the "uuid" field. -func (du *DecisionUpdate) SetUUID(s string) *DecisionUpdate { - du.mutation.SetUUID(s) - return du -} - -// SetNillableUUID sets the "uuid" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableUUID(s *string) *DecisionUpdate { - if s != nil { - du.SetUUID(*s) - } - return du -} - -// ClearUUID clears the value of the "uuid" field. -func (du *DecisionUpdate) ClearUUID() *DecisionUpdate { - du.mutation.ClearUUID() - return du -} - // SetAlertDecisions sets the "alert_decisions" field. func (du *DecisionUpdate) SetAlertDecisions(i int) *DecisionUpdate { du.mutation.SetAlertDecisions(i) @@ -324,35 +107,8 @@ func (du *DecisionUpdate) ClearOwner() *DecisionUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (du *DecisionUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) du.defaults() - if len(du.hooks) == 0 { - affected, err = du.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - du.mutation = mutation - affected, err = du.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(du.hooks) - 1; i >= 0; i-- { - if du.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = du.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, du.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, du.sqlSave, du.mutation, du.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -379,27 +135,14 @@ func (du *DecisionUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. 
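The long runs of deletions above strip the DecisionUpdate builder of setters for fields the schema now treats as immutable (scenario, type, the IP range columns, scope, value, origin, simulated, uuid) and of the ability to set or clear created_at; defaults(), just below, accordingly refreshes only updated_at. A sketch of the surviving mutable surface, using only methods this patch keeps (the helper name and the immediate-expiry choice are illustrative):

package main

import (
	"context"
	"log"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
)

// expireDecision is a hypothetical helper: after this patch an update can
// only touch the mutable subset (until, alert_decisions, updated_at);
// everything else must be fixed at creation time.
func expireDecision(ctx context.Context, client *ent.Client, uuid string) {
	n, err := client.Decision.Update().
		Where(decision.UUIDEQ(uuid)).
		SetUntil(time.Now()). // expire immediately
		Save(ctx)             // defaults() bumps updated_at for us
	if err != nil {
		log.Printf("expire decision: %v", err)
		return
	}
	log.Printf("expired %d decisions", n)
}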
func (du *DecisionUpdate) defaults() { - if _, ok := du.mutation.CreatedAt(); !ok && !du.mutation.CreatedAtCleared() { - v := decision.UpdateDefaultCreatedAt() - du.mutation.SetCreatedAt(v) - } - if _, ok := du.mutation.UpdatedAt(); !ok && !du.mutation.UpdatedAtCleared() { + if _, ok := du.mutation.UpdatedAt(); !ok { v := decision.UpdateDefaultUpdatedAt() du.mutation.SetUpdatedAt(v) } } func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - Columns: decision.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) if ps := du.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -407,199 +150,32 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := du.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldCreatedAt, - }) - } - if du.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldCreatedAt, - }) - } if value, ok := du.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUpdatedAt, - }) - } - if du.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUpdatedAt, - }) + _spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value) } if value, ok := du.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUntil, - }) + _spec.SetField(decision.FieldUntil, field.TypeTime, value) } if du.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUntil, - }) - } - if value, ok := du.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScenario, - }) - } - if value, ok := du.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldType, - }) - } - if value, ok := du.mutation.StartIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) - } - if value, ok := du.mutation.AddedStartIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) + _spec.ClearField(decision.FieldUntil, field.TypeTime) } if du.mutation.StartIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartIP, - }) - } - if value, ok := du.mutation.EndIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) - } - if value, ok := du.mutation.AddedEndIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, 
&sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) + _spec.ClearField(decision.FieldStartIP, field.TypeInt64) } if du.mutation.EndIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndIP, - }) - } - if value, ok := du.mutation.StartSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) - } - if value, ok := du.mutation.AddedStartSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) + _spec.ClearField(decision.FieldEndIP, field.TypeInt64) } if du.mutation.StartSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartSuffix, - }) - } - if value, ok := du.mutation.EndSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) - } - if value, ok := du.mutation.AddedEndSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) + _spec.ClearField(decision.FieldStartSuffix, field.TypeInt64) } if du.mutation.EndSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndSuffix, - }) - } - if value, ok := du.mutation.IPSize(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) - } - if value, ok := du.mutation.AddedIPSize(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) + _spec.ClearField(decision.FieldEndSuffix, field.TypeInt64) } if du.mutation.IPSizeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldIPSize, - }) - } - if value, ok := du.mutation.Scope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScope, - }) - } - if value, ok := du.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldValue, - }) - } - if value, ok := du.mutation.Origin(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldOrigin, - }) - } - if value, ok := du.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: decision.FieldSimulated, - }) - } - if value, ok := du.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldUUID, - }) + _spec.ClearField(decision.FieldIPSize, field.TypeInt64) } if du.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: decision.FieldUUID, - }) + _spec.ClearField(decision.FieldUUID, field.TypeString) } if du.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ 
-609,10 +185,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -625,10 +198,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -644,6 +214,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + du.mutation.done = true return n, nil } @@ -655,30 +226,12 @@ type DecisionUpdateOne struct { mutation *DecisionMutation } -// SetCreatedAt sets the "created_at" field. -func (duo *DecisionUpdateOne) SetCreatedAt(t time.Time) *DecisionUpdateOne { - duo.mutation.SetCreatedAt(t) - return duo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (duo *DecisionUpdateOne) ClearCreatedAt() *DecisionUpdateOne { - duo.mutation.ClearCreatedAt() - return duo -} - // SetUpdatedAt sets the "updated_at" field. func (duo *DecisionUpdateOne) SetUpdatedAt(t time.Time) *DecisionUpdateOne { duo.mutation.SetUpdatedAt(t) return duo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (duo *DecisionUpdateOne) ClearUpdatedAt() *DecisionUpdateOne { - duo.mutation.ClearUpdatedAt() - return duo -} - // SetUntil sets the "until" field. func (duo *DecisionUpdateOne) SetUntil(t time.Time) *DecisionUpdateOne { duo.mutation.SetUntil(t) @@ -699,205 +252,6 @@ func (duo *DecisionUpdateOne) ClearUntil() *DecisionUpdateOne { return duo } -// SetScenario sets the "scenario" field. -func (duo *DecisionUpdateOne) SetScenario(s string) *DecisionUpdateOne { - duo.mutation.SetScenario(s) - return duo -} - -// SetType sets the "type" field. -func (duo *DecisionUpdateOne) SetType(s string) *DecisionUpdateOne { - duo.mutation.SetType(s) - return duo -} - -// SetStartIP sets the "start_ip" field. -func (duo *DecisionUpdateOne) SetStartIP(i int64) *DecisionUpdateOne { - duo.mutation.ResetStartIP() - duo.mutation.SetStartIP(i) - return duo -} - -// SetNillableStartIP sets the "start_ip" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableStartIP(i *int64) *DecisionUpdateOne { - if i != nil { - duo.SetStartIP(*i) - } - return duo -} - -// AddStartIP adds i to the "start_ip" field. -func (duo *DecisionUpdateOne) AddStartIP(i int64) *DecisionUpdateOne { - duo.mutation.AddStartIP(i) - return duo -} - -// ClearStartIP clears the value of the "start_ip" field. -func (duo *DecisionUpdateOne) ClearStartIP() *DecisionUpdateOne { - duo.mutation.ClearStartIP() - return duo -} - -// SetEndIP sets the "end_ip" field. -func (duo *DecisionUpdateOne) SetEndIP(i int64) *DecisionUpdateOne { - duo.mutation.ResetEndIP() - duo.mutation.SetEndIP(i) - return duo -} - -// SetNillableEndIP sets the "end_ip" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableEndIP(i *int64) *DecisionUpdateOne { - if i != nil { - duo.SetEndIP(*i) - } - return duo -} - -// AddEndIP adds i to the "end_ip" field. 
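Note: the Set*/Add*/Clear* methods removed above are not lost functionality; entc stops generating update setters for fields the schema marks Immutable(). A minimal, hypothetical sketch of that kind of schema (the real one lives in pkg/database/ent/schema and differs in detail, e.g. in its default-time helper):

```go
package schema

import (
	"time"

	"entgo.io/ent"
	"entgo.io/ent/schema/field"
)

// Decision loosely mirrors the entity behind the builders in this diff.
type Decision struct {
	ent.Schema
}

func (Decision) Fields() []ent.Field {
	return []ent.Field{
		// Immutable() is what makes entc drop Set*/Clear* for a field on
		// the Update/UpdateOne builders, as seen above.
		field.Time("created_at").
			Default(time.Now).
			Immutable(),
		field.String("scenario").
			Immutable(),
	}
}
```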
-func (duo *DecisionUpdateOne) AddEndIP(i int64) *DecisionUpdateOne { - duo.mutation.AddEndIP(i) - return duo -} - -// ClearEndIP clears the value of the "end_ip" field. -func (duo *DecisionUpdateOne) ClearEndIP() *DecisionUpdateOne { - duo.mutation.ClearEndIP() - return duo -} - -// SetStartSuffix sets the "start_suffix" field. -func (duo *DecisionUpdateOne) SetStartSuffix(i int64) *DecisionUpdateOne { - duo.mutation.ResetStartSuffix() - duo.mutation.SetStartSuffix(i) - return duo -} - -// SetNillableStartSuffix sets the "start_suffix" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableStartSuffix(i *int64) *DecisionUpdateOne { - if i != nil { - duo.SetStartSuffix(*i) - } - return duo -} - -// AddStartSuffix adds i to the "start_suffix" field. -func (duo *DecisionUpdateOne) AddStartSuffix(i int64) *DecisionUpdateOne { - duo.mutation.AddStartSuffix(i) - return duo -} - -// ClearStartSuffix clears the value of the "start_suffix" field. -func (duo *DecisionUpdateOne) ClearStartSuffix() *DecisionUpdateOne { - duo.mutation.ClearStartSuffix() - return duo -} - -// SetEndSuffix sets the "end_suffix" field. -func (duo *DecisionUpdateOne) SetEndSuffix(i int64) *DecisionUpdateOne { - duo.mutation.ResetEndSuffix() - duo.mutation.SetEndSuffix(i) - return duo -} - -// SetNillableEndSuffix sets the "end_suffix" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableEndSuffix(i *int64) *DecisionUpdateOne { - if i != nil { - duo.SetEndSuffix(*i) - } - return duo -} - -// AddEndSuffix adds i to the "end_suffix" field. -func (duo *DecisionUpdateOne) AddEndSuffix(i int64) *DecisionUpdateOne { - duo.mutation.AddEndSuffix(i) - return duo -} - -// ClearEndSuffix clears the value of the "end_suffix" field. -func (duo *DecisionUpdateOne) ClearEndSuffix() *DecisionUpdateOne { - duo.mutation.ClearEndSuffix() - return duo -} - -// SetIPSize sets the "ip_size" field. -func (duo *DecisionUpdateOne) SetIPSize(i int64) *DecisionUpdateOne { - duo.mutation.ResetIPSize() - duo.mutation.SetIPSize(i) - return duo -} - -// SetNillableIPSize sets the "ip_size" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableIPSize(i *int64) *DecisionUpdateOne { - if i != nil { - duo.SetIPSize(*i) - } - return duo -} - -// AddIPSize adds i to the "ip_size" field. -func (duo *DecisionUpdateOne) AddIPSize(i int64) *DecisionUpdateOne { - duo.mutation.AddIPSize(i) - return duo -} - -// ClearIPSize clears the value of the "ip_size" field. -func (duo *DecisionUpdateOne) ClearIPSize() *DecisionUpdateOne { - duo.mutation.ClearIPSize() - return duo -} - -// SetScope sets the "scope" field. -func (duo *DecisionUpdateOne) SetScope(s string) *DecisionUpdateOne { - duo.mutation.SetScope(s) - return duo -} - -// SetValue sets the "value" field. -func (duo *DecisionUpdateOne) SetValue(s string) *DecisionUpdateOne { - duo.mutation.SetValue(s) - return duo -} - -// SetOrigin sets the "origin" field. -func (duo *DecisionUpdateOne) SetOrigin(s string) *DecisionUpdateOne { - duo.mutation.SetOrigin(s) - return duo -} - -// SetSimulated sets the "simulated" field. -func (duo *DecisionUpdateOne) SetSimulated(b bool) *DecisionUpdateOne { - duo.mutation.SetSimulated(b) - return duo -} - -// SetNillableSimulated sets the "simulated" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableSimulated(b *bool) *DecisionUpdateOne { - if b != nil { - duo.SetSimulated(*b) - } - return duo -} - -// SetUUID sets the "uuid" field. 
-func (duo *DecisionUpdateOne) SetUUID(s string) *DecisionUpdateOne { - duo.mutation.SetUUID(s) - return duo -} - -// SetNillableUUID sets the "uuid" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableUUID(s *string) *DecisionUpdateOne { - if s != nil { - duo.SetUUID(*s) - } - return duo -} - -// ClearUUID clears the value of the "uuid" field. -func (duo *DecisionUpdateOne) ClearUUID() *DecisionUpdateOne { - duo.mutation.ClearUUID() - return duo -} - // SetAlertDecisions sets the "alert_decisions" field. func (duo *DecisionUpdateOne) SetAlertDecisions(i int) *DecisionUpdateOne { duo.mutation.SetAlertDecisions(i) @@ -948,6 +302,12 @@ func (duo *DecisionUpdateOne) ClearOwner() *DecisionUpdateOne { return duo } +// Where appends a list predicates to the DecisionUpdate builder. +func (duo *DecisionUpdateOne) Where(ps ...predicate.Decision) *DecisionUpdateOne { + duo.mutation.Where(ps...) + return duo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (duo *DecisionUpdateOne) Select(field string, fields ...string) *DecisionUpdateOne { @@ -957,41 +317,8 @@ func (duo *DecisionUpdateOne) Select(field string, fields ...string) *DecisionUp // Save executes the query and returns the updated Decision entity. func (duo *DecisionUpdateOne) Save(ctx context.Context) (*Decision, error) { - var ( - err error - node *Decision - ) duo.defaults() - if len(duo.hooks) == 0 { - node, err = duo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - duo.mutation = mutation - node, err = duo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(duo.hooks) - 1; i >= 0; i-- { - if duo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = duo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, duo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Decision) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DecisionMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, duo.sqlSave, duo.mutation, duo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -1018,27 +345,14 @@ func (duo *DecisionUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. 
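Note: two behavioral points in the rewritten builder are worth calling out. Save now delegates its hook plumbing to the generic withHooks helper, and UpdateOne gains a Where method, so a single-row update can carry predicates and fail with a *NotFoundError when they do not match. A hedged call-site sketch; the guard predicate is hypothetical:

```go
import (
	"context"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
)

// expireIfSimulated expires a decision only while it is still marked
// simulated; if the predicate filters the row out, Save returns a
// *ent.NotFoundError instead of silently updating nothing.
func expireIfSimulated(ctx context.Context, c *ent.Client, id int) error {
	_, err := c.Decision.UpdateOneID(id).
		Where(decision.SimulatedEQ(true)).
		SetUntil(time.Now()).
		Save(ctx)
	return err
}
```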
func (duo *DecisionUpdateOne) defaults() { - if _, ok := duo.mutation.CreatedAt(); !ok && !duo.mutation.CreatedAtCleared() { - v := decision.UpdateDefaultCreatedAt() - duo.mutation.SetCreatedAt(v) - } - if _, ok := duo.mutation.UpdatedAt(); !ok && !duo.mutation.UpdatedAtCleared() { + if _, ok := duo.mutation.UpdatedAt(); !ok { v := decision.UpdateDefaultUpdatedAt() duo.mutation.SetUpdatedAt(v) } } func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - Columns: decision.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) id, ok := duo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Decision.id" for update`)} @@ -1063,199 +377,32 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err } } } - if value, ok := duo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldCreatedAt, - }) - } - if duo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldCreatedAt, - }) - } if value, ok := duo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUpdatedAt, - }) - } - if duo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUpdatedAt, - }) + _spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value) } if value, ok := duo.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUntil, - }) + _spec.SetField(decision.FieldUntil, field.TypeTime, value) } if duo.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUntil, - }) - } - if value, ok := duo.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScenario, - }) - } - if value, ok := duo.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldType, - }) - } - if value, ok := duo.mutation.StartIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) - } - if value, ok := duo.mutation.AddedStartIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) + _spec.ClearField(decision.FieldUntil, field.TypeTime) } if duo.mutation.StartIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartIP, - }) - } - if value, ok := duo.mutation.EndIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) - } - if value, ok := 
duo.mutation.AddedEndIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) + _spec.ClearField(decision.FieldStartIP, field.TypeInt64) } if duo.mutation.EndIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndIP, - }) - } - if value, ok := duo.mutation.StartSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) - } - if value, ok := duo.mutation.AddedStartSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) + _spec.ClearField(decision.FieldEndIP, field.TypeInt64) } if duo.mutation.StartSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartSuffix, - }) - } - if value, ok := duo.mutation.EndSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) - } - if value, ok := duo.mutation.AddedEndSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) + _spec.ClearField(decision.FieldStartSuffix, field.TypeInt64) } if duo.mutation.EndSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndSuffix, - }) - } - if value, ok := duo.mutation.IPSize(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) - } - if value, ok := duo.mutation.AddedIPSize(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) + _spec.ClearField(decision.FieldEndSuffix, field.TypeInt64) } if duo.mutation.IPSizeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldIPSize, - }) - } - if value, ok := duo.mutation.Scope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScope, - }) - } - if value, ok := duo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldValue, - }) - } - if value, ok := duo.mutation.Origin(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldOrigin, - }) - } - if value, ok := duo.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: decision.FieldSimulated, - }) - } - if value, ok := duo.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldUUID, - }) + _spec.ClearField(decision.FieldIPSize, field.TypeInt64) } if duo.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: decision.FieldUUID, - }) + 
_spec.ClearField(decision.FieldUUID, field.TypeString) } if duo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -1265,10 +412,7 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1281,10 +425,7 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1303,5 +444,6 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err } return nil, err } + duo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/ent.go b/pkg/database/ent/ent.go index 0455af444d2..2a5ad188197 100644 --- a/pkg/database/ent/ent.go +++ b/pkg/database/ent/ent.go @@ -6,6 +6,8 @@ import ( "context" "errors" "fmt" + "reflect" + "sync" "entgo.io/ent" "entgo.io/ent/dialect/sql" @@ -15,56 +17,89 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" ) // ent aliases to avoid import conflicts in user's code. type ( - Op = ent.Op - Hook = ent.Hook - Value = ent.Value - Query = ent.Query - Policy = ent.Policy - Mutator = ent.Mutator - Mutation = ent.Mutation - MutateFunc = ent.MutateFunc + Op = ent.Op + Hook = ent.Hook + Value = ent.Value + Query = ent.Query + QueryContext = ent.QueryContext + Querier = ent.Querier + QuerierFunc = ent.QuerierFunc + Interceptor = ent.Interceptor + InterceptFunc = ent.InterceptFunc + Traverser = ent.Traverser + TraverseFunc = ent.TraverseFunc + Policy = ent.Policy + Mutator = ent.Mutator + Mutation = ent.Mutation + MutateFunc = ent.MutateFunc ) +type clientCtxKey struct{} + +// FromContext returns a Client stored inside a context, or nil if there isn't one. +func FromContext(ctx context.Context) *Client { + c, _ := ctx.Value(clientCtxKey{}).(*Client) + return c +} + +// NewContext returns a new context with the given Client attached. +func NewContext(parent context.Context, c *Client) context.Context { + return context.WithValue(parent, clientCtxKey{}, c) +} + +type txCtxKey struct{} + +// TxFromContext returns a Tx stored inside a context, or nil if there isn't one. +func TxFromContext(ctx context.Context) *Tx { + tx, _ := ctx.Value(txCtxKey{}).(*Tx) + return tx +} + +// NewTxContext returns a new context with the given Tx attached. +func NewTxContext(parent context.Context, tx *Tx) context.Context { + return context.WithValue(parent, txCtxKey{}, tx) +} + // OrderFunc applies an ordering on the sql selector. +// Deprecated: Use Asc/Desc functions or the package builders instead. type OrderFunc func(*sql.Selector) -// columnChecker returns a function indicates if the column exists in the given column. 
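Note: the new context helpers added to ent.go (NewContext/FromContext and their Tx variants) let code deep inside hooks or interceptors reach the client that started the operation without threading it through every signature. A minimal usage sketch:

```go
import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// Attach the client once, near the top of a request or job.
func withClient(ctx context.Context, c *ent.Client) context.Context {
	return ent.NewContext(ctx, c)
}

// Deep inside a hook, recover it without changing any signatures.
func clientFromHook(ctx context.Context) *ent.Client {
	return ent.FromContext(ctx) // nil if no client was attached
}
```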
-func columnChecker(table string) func(string) error {
-	checks := map[string]func(string) bool{
-		alert.Table:      alert.ValidColumn,
-		bouncer.Table:    bouncer.ValidColumn,
-		configitem.Table: configitem.ValidColumn,
-		decision.Table:   decision.ValidColumn,
-		event.Table:      event.ValidColumn,
-		machine.Table:    machine.ValidColumn,
-		meta.Table:       meta.ValidColumn,
-	}
-	check, ok := checks[table]
-	if !ok {
-		return func(string) error {
-			return fmt.Errorf("unknown table %q", table)
-		}
-	}
-	return func(column string) error {
-		if !check(column) {
-			return fmt.Errorf("unknown column %q for table %q", column, table)
-		}
-		return nil
-	}
+var (
+	initCheck   sync.Once
+	columnCheck sql.ColumnCheck
+)
+
+// checkColumn checks if the column exists in the given table.
+func checkColumn(table, column string) error {
+	initCheck.Do(func() {
+		columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
+			alert.Table:      alert.ValidColumn,
+			bouncer.Table:    bouncer.ValidColumn,
+			configitem.Table: configitem.ValidColumn,
+			decision.Table:   decision.ValidColumn,
+			event.Table:      event.ValidColumn,
+			lock.Table:       lock.ValidColumn,
+			machine.Table:    machine.ValidColumn,
+			meta.Table:       meta.ValidColumn,
+			metric.Table:     metric.ValidColumn,
+		})
+	})
+	return columnCheck(table, column)
 }
 
 // Asc applies the given fields in ASC order.
-func Asc(fields ...string) OrderFunc {
+func Asc(fields ...string) func(*sql.Selector) {
 	return func(s *sql.Selector) {
-		check := columnChecker(s.TableName())
 		for _, f := range fields {
-			if err := check(f); err != nil {
+			if err := checkColumn(s.TableName(), f); err != nil {
 				s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)})
 			}
 			s.OrderBy(sql.Asc(s.C(f)))
@@ -73,11 +108,10 @@ func Asc(fields ...string) OrderFunc {
 }
 
 // Desc applies the given fields in DESC order.
-func Desc(fields ...string) OrderFunc {
+func Desc(fields ...string) func(*sql.Selector) {
 	return func(s *sql.Selector) {
-		check := columnChecker(s.TableName())
 		for _, f := range fields {
-			if err := check(f); err != nil {
+			if err := checkColumn(s.TableName(), f); err != nil {
 				s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)})
 			}
 			s.OrderBy(sql.Desc(s.C(f)))
@@ -109,8 +143,7 @@ func Count() AggregateFunc {
 // Max applies the "max" aggregation function on the given field of each group.
 func Max(field string) AggregateFunc {
 	return func(s *sql.Selector) string {
-		check := columnChecker(s.TableName())
-		if err := check(field); err != nil {
+		if err := checkColumn(s.TableName(), field); err != nil {
 			s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
 			return ""
 		}
@@ -121,8 +154,7 @@ func Max(field string) AggregateFunc {
 // Mean applies the "mean" aggregation function on the given field of each group.
 func Mean(field string) AggregateFunc {
 	return func(s *sql.Selector) string {
-		check := columnChecker(s.TableName())
-		if err := check(field); err != nil {
+		if err := checkColumn(s.TableName(), field); err != nil {
 			s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
 			return ""
 		}
@@ -133,8 +165,7 @@ func Mean(field string) AggregateFunc {
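Note: checkColumn now builds its table-to-validator map once via sync.Once and shares it across every Asc/Desc/aggregate call, instead of rebuilding it per call. The legacy helpers keep working on top of it; a quick, illustrative call site:

```go
import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
)

func latestEvents(ctx context.Context, client *ent.Client) ([]*ent.Event, error) {
	// A bad column name is caught by checkColumn and surfaces as a
	// *ent.ValidationError on the query, not as a raw SQL error.
	return client.Event.Query().
		Order(ent.Desc(event.FieldTime)).
		All(ctx)
}
```

 // Min applies the "min" aggregation function on the given field of each group.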
func Min(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -145,8 +176,7 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -275,6 +305,7 @@ func IsConstraintError(err error) bool { type selector struct { label string flds *[]string + fns []AggregateFunc scan func(context.Context, any) error } @@ -473,5 +504,121 @@ func (s *selector) BoolX(ctx context.Context) bool { return v } +// withHooks invokes the builder operation with the given hooks, if any. +func withHooks[V Value, M any, PM interface { + *M + Mutation +}](ctx context.Context, exec func(context.Context) (V, error), mutation PM, hooks []Hook) (value V, err error) { + if len(hooks) == 0 { + return exec(ctx) + } + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutationT, ok := any(m).(PM) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + // Set the mutation to the builder. + *mutation = *mutationT + return exec(ctx) + }) + for i := len(hooks) - 1; i >= 0; i-- { + if hooks[i] == nil { + return value, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") + } + mut = hooks[i](mut) + } + v, err := mut.Mutate(ctx, mutation) + if err != nil { + return value, err + } + nv, ok := v.(V) + if !ok { + return value, fmt.Errorf("unexpected node type %T returned from %T", v, mutation) + } + return nv, nil +} + +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { + if ent.QueryFromContext(ctx) == nil { + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) + } + return ctx +} + +func querierAll[V Value, Q interface { + sqlAll(context.Context, ...queryHook) (V, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlAll(ctx) + }) +} + +func querierCount[Q interface { + sqlCount(context.Context) (int, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlCount(ctx) + }) +} + +func withInterceptors[V Value](ctx context.Context, q Query, qr Querier, inters []Interceptor) (v V, err error) { + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + rv, err := qr.Query(ctx, q) + if err != nil { + return v, err + } + vt, ok := rv.(V) + if !ok { + return v, fmt.Errorf("unexpected type %T returned from %T. 
expected type: %T", vt, q, v) + } + return vt, nil +} + +func scanWithInterceptors[Q1 ent.Query, Q2 interface { + sqlScan(context.Context, Q1, any) error +}](ctx context.Context, rootQuery Q1, selectOrGroup Q2, inters []Interceptor, v any) error { + rv := reflect.ValueOf(v) + var qr Querier = QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q1) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + if err := selectOrGroup.sqlScan(ctx, query, v); err != nil { + return nil, err + } + if k := rv.Kind(); k == reflect.Pointer && rv.Elem().CanInterface() { + return rv.Elem().Interface(), nil + } + return v, nil + }) + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + vv, err := qr.Query(ctx, rootQuery) + if err != nil { + return err + } + switch rv2 := reflect.ValueOf(vv); { + case rv.IsNil(), rv2.IsNil(), rv.Kind() != reflect.Pointer: + case rv.Type() == rv2.Type(): + rv.Elem().Set(rv2.Elem()) + case rv.Elem().Type() == rv2.Type(): + rv.Elem().Set(rv2) + } + return nil +} + // queryHook describes an internal hook for the different sqlAll methods. type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/pkg/database/ent/event.go b/pkg/database/ent/event.go index 4754107fddc..b57f1f34ac9 100644 --- a/pkg/database/ent/event.go +++ b/pkg/database/ent/event.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" @@ -18,9 +19,9 @@ type Event struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` // Time holds the value of the "time" field. Time time.Time `json:"time,omitempty"` // Serialized holds the value of the "serialized" field. @@ -29,7 +30,8 @@ type Event struct { AlertEvents int `json:"alert_events,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the EventQuery when eager-loading is set. - Edges EventEdges `json:"edges"` + Edges EventEdges `json:"edges"` + selectValues sql.SelectValues } // EventEdges holds the relations/edges for other nodes in the graph. @@ -44,12 +46,10 @@ type EventEdges struct { // OwnerOrErr returns the Owner value or an error if the edge // was not loaded in eager-loading, or loaded but was not found. func (e EventEdges) OwnerOrErr() (*Alert, error) { - if e.loadedTypes[0] { - if e.Owner == nil { - // Edge was loaded but was not found. 
- return nil, &NotFoundError{label: alert.Label} - } + if e.Owner != nil { return e.Owner, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: alert.Label} } return nil, &NotLoadedError{edge: "owner"} } @@ -66,7 +66,7 @@ func (*Event) scanValues(columns []string) ([]any, error) { case event.FieldCreatedAt, event.FieldUpdatedAt, event.FieldTime: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Event", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -90,15 +90,13 @@ func (e *Event) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - e.CreatedAt = new(time.Time) - *e.CreatedAt = value.Time + e.CreatedAt = value.Time } case event.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - e.UpdatedAt = new(time.Time) - *e.UpdatedAt = value.Time + e.UpdatedAt = value.Time } case event.FieldTime: if value, ok := values[i].(*sql.NullTime); !ok { @@ -118,21 +116,29 @@ func (e *Event) assignValues(columns []string, values []any) error { } else if value.Valid { e.AlertEvents = int(value.Int64) } + default: + e.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Event. +// This includes values selected through modifiers, order, etc. +func (e *Event) Value(name string) (ent.Value, error) { + return e.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Event entity. func (e *Event) QueryOwner() *AlertQuery { - return (&EventClient{config: e.config}).QueryOwner(e) + return NewEventClient(e.config).QueryOwner(e) } // Update returns a builder for updating this Event. // Note that you need to call Event.Unwrap() before calling this method if this Event // was returned from a transaction, and the transaction was committed or rolled back. func (e *Event) Update() *EventUpdateOne { - return (&EventClient{config: e.config}).UpdateOne(e) + return NewEventClient(e.config).UpdateOne(e) } // Unwrap unwraps the Event entity that was returned from a transaction after it was closed, @@ -151,15 +157,11 @@ func (e *Event) String() string { var builder strings.Builder builder.WriteString("Event(") builder.WriteString(fmt.Sprintf("id=%v, ", e.ID)) - if v := e.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(e.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := e.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(e.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") builder.WriteString("time=") builder.WriteString(e.Time.Format(time.ANSIC)) @@ -175,9 +177,3 @@ func (e *Event) String() string { // Events is a parsable slice of Event. 
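Note: Event.CreatedAt and UpdatedAt switch from *time.Time to time.Time because the create builder now guarantees both through defaults() and check(); the new selectValues/Value(name) plumbing also captures columns outside the schema instead of failing the scan. Call sites get simpler, assuming a client and context:

```go
import (
	"context"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

func eventAge(ctx context.Context, client *ent.Client) (time.Duration, error) {
	e, err := client.Event.Query().First(ctx)
	if err != nil {
		return 0, err
	}
	// Previously: a nil check followed by time.Since(*e.CreatedAt).
	return time.Since(e.CreatedAt), nil
}
```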
type Events []*Event - -func (e Events) config(cfg config) { - for _i := range e { - e[_i].config = cfg - } -} diff --git a/pkg/database/ent/event/event.go b/pkg/database/ent/event/event.go index 33b9b67f8b9..c975a612669 100644 --- a/pkg/database/ent/event/event.go +++ b/pkg/database/ent/event/event.go @@ -4,6 +4,9 @@ package event import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -57,8 +60,6 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. @@ -66,3 +67,50 @@ var ( // SerializedValidator is a validator for the "serialized" field. It is called by the builders before save. SerializedValidator func(string) error ) + +// OrderOption defines the ordering options for the Event queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByTime orders the results by the time field. +func ByTime(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTime, opts...).ToFunc() +} + +// BySerialized orders the results by the serialized field. +func BySerialized(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSerialized, opts...).ToFunc() +} + +// ByAlertEvents orders the results by the alert_events field. +func ByAlertEvents(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAlertEvents, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} diff --git a/pkg/database/ent/event/where.go b/pkg/database/ent/event/where.go index 7554e59e678..d420b125026 100644 --- a/pkg/database/ent/event/where.go +++ b/pkg/database/ent/event/where.go @@ -12,477 +12,287 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. 
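Note: the generated By* helpers replace the deprecated OrderFunc style; they accept sql.OrderTermOption modifiers and, via ByOwnerField, can order by a column on the owning alert through newOwnerStep. A hedged sketch:

```go
import (
	"context"

	"entgo.io/ent/dialect/sql"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
)

func orderedEvents(ctx context.Context, client *ent.Client) ([]*ent.Event, error) {
	return client.Event.Query().
		Order(
			event.ByTime(sql.OrderDesc()),     // newest first
			event.ByOwnerField(alert.FieldID), // then by the owning alert's id
		).
		All(ctx)
}
```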
func IDNEQ(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Event(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldUpdatedAt, v)) } // Time applies equality check predicate on the "time" field. It's identical to TimeEQ. func Time(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldEQ(FieldTime, v)) } // Serialized applies equality check predicate on the "serialized" field. It's identical to SerializedEQ. func Serialized(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldEQ(FieldSerialized, v)) } // AlertEvents applies equality check predicate on the "alert_events" field. It's identical to AlertEventsEQ. func AlertEvents(v int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertEvents), v)) - }) + return predicate.Event(sql.FieldEQ(FieldAlertEvents, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. 
func CreatedAtEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Event(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Event(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. 
func UpdatedAtIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Event(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Event(sql.FieldLTE(FieldUpdatedAt, v)) } // TimeEQ applies the EQ predicate on the "time" field. func TimeEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldEQ(FieldTime, v)) } // TimeNEQ applies the NEQ predicate on the "time" field. func TimeNEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldTime, v)) } // TimeIn applies the In predicate on the "time" field. func TimeIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldTime), v...)) - }) + return predicate.Event(sql.FieldIn(FieldTime, vs...)) } // TimeNotIn applies the NotIn predicate on the "time" field. func TimeNotIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldTime), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldTime, vs...)) } // TimeGT applies the GT predicate on the "time" field. 
func TimeGT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldGT(FieldTime, v)) } // TimeGTE applies the GTE predicate on the "time" field. func TimeGTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldGTE(FieldTime, v)) } // TimeLT applies the LT predicate on the "time" field. func TimeLT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldLT(FieldTime, v)) } // TimeLTE applies the LTE predicate on the "time" field. func TimeLTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldLTE(FieldTime, v)) } // SerializedEQ applies the EQ predicate on the "serialized" field. func SerializedEQ(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldEQ(FieldSerialized, v)) } // SerializedNEQ applies the NEQ predicate on the "serialized" field. func SerializedNEQ(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldSerialized, v)) } // SerializedIn applies the In predicate on the "serialized" field. func SerializedIn(vs ...string) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSerialized), v...)) - }) + return predicate.Event(sql.FieldIn(FieldSerialized, vs...)) } // SerializedNotIn applies the NotIn predicate on the "serialized" field. func SerializedNotIn(vs ...string) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSerialized), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldSerialized, vs...)) } // SerializedGT applies the GT predicate on the "serialized" field. func SerializedGT(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldGT(FieldSerialized, v)) } // SerializedGTE applies the GTE predicate on the "serialized" field. func SerializedGTE(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldGTE(FieldSerialized, v)) } // SerializedLT applies the LT predicate on the "serialized" field. func SerializedLT(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldLT(FieldSerialized, v)) } // SerializedLTE applies the LTE predicate on the "serialized" field. func SerializedLTE(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldLTE(FieldSerialized, v)) } // SerializedContains applies the Contains predicate on the "serialized" field. 
func SerializedContains(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldContains(FieldSerialized, v)) } // SerializedHasPrefix applies the HasPrefix predicate on the "serialized" field. func SerializedHasPrefix(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldHasPrefix(FieldSerialized, v)) } // SerializedHasSuffix applies the HasSuffix predicate on the "serialized" field. func SerializedHasSuffix(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldHasSuffix(FieldSerialized, v)) } // SerializedEqualFold applies the EqualFold predicate on the "serialized" field. func SerializedEqualFold(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldEqualFold(FieldSerialized, v)) } // SerializedContainsFold applies the ContainsFold predicate on the "serialized" field. func SerializedContainsFold(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldContainsFold(FieldSerialized, v)) } // AlertEventsEQ applies the EQ predicate on the "alert_events" field. func AlertEventsEQ(v int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertEvents), v)) - }) + return predicate.Event(sql.FieldEQ(FieldAlertEvents, v)) } // AlertEventsNEQ applies the NEQ predicate on the "alert_events" field. func AlertEventsNEQ(v int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAlertEvents), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldAlertEvents, v)) } // AlertEventsIn applies the In predicate on the "alert_events" field. func AlertEventsIn(vs ...int) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAlertEvents), v...)) - }) + return predicate.Event(sql.FieldIn(FieldAlertEvents, vs...)) } // AlertEventsNotIn applies the NotIn predicate on the "alert_events" field. func AlertEventsNotIn(vs ...int) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAlertEvents), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldAlertEvents, vs...)) } // AlertEventsIsNil applies the IsNil predicate on the "alert_events" field. func AlertEventsIsNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldAlertEvents))) - }) + return predicate.Event(sql.FieldIsNull(FieldAlertEvents)) } // AlertEventsNotNil applies the NotNil predicate on the "alert_events" field. func AlertEventsNotNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldAlertEvents))) - }) + return predicate.Event(sql.FieldNotNull(FieldAlertEvents)) } // HasOwner applies the HasEdge predicate on the "owner" edge. 
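Note: the predicate rewrite (hand-rolled selector closures replaced by sql.Field* one-liners) is call-site compatible; predicates still compose exactly as before. An illustrative filter with hypothetical values:

```go
import (
	"context"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
)

func recentSSHEvents(ctx context.Context, client *ent.Client, since time.Time) ([]*ent.Event, error) {
	return client.Event.Query().
		Where(
			event.CreatedAtGTE(since),
			event.SerializedContains("ssh"), // predicates passed together AND by default
		).
		All(ctx)
}
```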
@@ -490,7 +300,6 @@ func HasOwner() predicate.Event { return predicate.Event(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -500,11 +309,7 @@ func HasOwner() predicate.Event { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). func HasOwnerWith(preds ...predicate.Alert) predicate.Event { return predicate.Event(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -515,32 +320,15 @@ func HasOwnerWith(preds ...predicate.Alert) predicate.Event { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Event) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Event(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Event) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Event(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Event) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Event(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/event_create.go b/pkg/database/ent/event_create.go index c5861305130..36747babe47 100644 --- a/pkg/database/ent/event_create.go +++ b/pkg/database/ent/event_create.go @@ -101,50 +101,8 @@ func (ec *EventCreate) Mutation() *EventMutation { // Save creates the Event in the database. func (ec *EventCreate) Save(ctx context.Context) (*Event, error) { - var ( - err error - node *Event - ) ec.defaults() - if len(ec.hooks) == 0 { - if err = ec.check(); err != nil { - return nil, err - } - node, err = ec.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = ec.check(); err != nil { - return nil, err - } - ec.mutation = mutation - if node, err = ec.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(ec.hooks) - 1; i >= 0; i-- { - if ec.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ec.hooks[i](mut) - } - v, err := mut.Mutate(ctx, ec.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Event) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from EventMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, ec.sqlSave, ec.mutation, ec.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -183,6 +141,12 @@ func (ec *EventCreate) defaults() { // check runs all checks and user-defined validators on the builder. 
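Note: EventCreate.Save goes through the same generic withHooks helper, and check() now runs inside sqlSave, so hooks observe the mutation before validation. Registering a hook is unchanged; a sketch with placeholder logging:

```go
import (
	"context"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

func logEventMutations(client *ent.Client) {
	client.Event.Use(func(next ent.Mutator) ent.Mutator {
		return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
			log.Printf("event mutation: %v", m.Op())
			return next.Mutate(ctx, m)
		})
	})
}
```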
func (ec *EventCreate) check() error { + if _, ok := ec.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Event.created_at"`)} + } + if _, ok := ec.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Event.updated_at"`)} + } if _, ok := ec.mutation.Time(); !ok { return &ValidationError{Name: "time", err: errors.New(`ent: missing required field "Event.time"`)} } @@ -198,6 +162,9 @@ func (ec *EventCreate) check() error { } func (ec *EventCreate) sqlSave(ctx context.Context) (*Event, error) { + if err := ec.check(); err != nil { + return nil, err + } _node, _spec := ec.createSpec() if err := sqlgraph.CreateNode(ctx, ec.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -207,50 +174,30 @@ func (ec *EventCreate) sqlSave(ctx context.Context) (*Event, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + ec.mutation.id = &_node.ID + ec.mutation.done = true return _node, nil } func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) { var ( _node = &Event{config: ec.config} - _spec = &sqlgraph.CreateSpec{ - Table: event.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(event.Table, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) ) if value, ok := ec.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(event.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := ec.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(event.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := ec.mutation.Time(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldTime, - }) + _spec.SetField(event.FieldTime, field.TypeTime, value) _node.Time = value } if value, ok := ec.mutation.Serialized(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: event.FieldSerialized, - }) + _spec.SetField(event.FieldSerialized, field.TypeString, value) _node.Serialized = value } if nodes := ec.mutation.OwnerIDs(); len(nodes) > 0 { @@ -261,10 +208,7 @@ func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) { Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -279,11 +223,15 @@ func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) { // EventCreateBulk is the builder for creating many Event entities in bulk. type EventCreateBulk struct { config + err error builders []*EventCreate } // Save creates the Event entities in the database. 
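Note: the only change to EventCreateBulk is the new err field, which lets a broken builder fail fast in Save; usage is unchanged. A hedged sketch, assuming an existing alert to attach to:

```go
import (
	"context"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

func createEvents(ctx context.Context, client *ent.Client, owner *ent.Alert, times []time.Time) ([]*ent.Event, error) {
	builders := make([]*ent.EventCreate, 0, len(times))
	for _, t := range times {
		builders = append(builders, client.Event.Create().
			SetTime(t).
			SetSerialized("{}"). // placeholder payload
			SetOwner(owner))
	}
	return client.Event.CreateBulk(builders...).Save(ctx)
}
```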
func (ecb *EventCreateBulk) Save(ctx context.Context) ([]*Event, error) { + if ecb.err != nil { + return nil, ecb.err + } specs := make([]*sqlgraph.CreateSpec, len(ecb.builders)) nodes := make([]*Event, len(ecb.builders)) mutators := make([]Mutator, len(ecb.builders)) @@ -300,8 +248,8 @@ func (ecb *EventCreateBulk) Save(ctx context.Context) ([]*Event, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, ecb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/event_delete.go b/pkg/database/ent/event_delete.go index 0220dc71d31..93dd1246b7e 100644 --- a/pkg/database/ent/event_delete.go +++ b/pkg/database/ent/event_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (ed *EventDelete) Where(ps ...predicate.Event) *EventDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (ed *EventDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(ed.hooks) == 0 { - affected, err = ed.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ed.mutation = mutation - affected, err = ed.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(ed.hooks) - 1; i >= 0; i-- { - if ed.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ed.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ed.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, ed.sqlExec, ed.mutation, ed.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (ed *EventDelete) ExecX(ctx context.Context) int { } func (ed *EventDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(event.Table, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) if ps := ed.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (ed *EventDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + ed.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type EventDeleteOne struct { ed *EventDelete } +// Where appends a list predicates to the EventDelete builder. +func (edo *EventDeleteOne) Where(ps ...predicate.Event) *EventDeleteOne { + edo.ed.mutation.Where(ps...) + return edo +} + // Exec executes the deletion query. func (edo *EventDeleteOne) Exec(ctx context.Context) error { n, err := edo.ed.Exec(ctx) @@ -111,5 +82,7 @@ func (edo *EventDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. 
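// EventDeleteOne also gains a Where method (just below), so a single-row
// delete can carry an extra guard. A sketch; the guard predicate is
// illustrative:
//
//	err := client.Event.DeleteOneID(id).
//		Where(event.AlertEvents(alertID)).
//		Exec(ctx)
//	if ent.IsNotFound(err) {
//		// row already gone, or the guard filtered it out
//	}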
func (edo *EventDeleteOne) ExecX(ctx context.Context) { - edo.ed.ExecX(ctx) + if err := edo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/event_query.go b/pkg/database/ent/event_query.go index 045d750f818..1493d7bd32c 100644 --- a/pkg/database/ent/event_query.go +++ b/pkg/database/ent/event_query.go @@ -18,11 +18,9 @@ import ( // EventQuery is the builder for querying Event entities. type EventQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []event.OrderOption + inters []Interceptor predicates []predicate.Event withOwner *AlertQuery // intermediate query (i.e. traversal path). @@ -36,34 +34,34 @@ func (eq *EventQuery) Where(ps ...predicate.Event) *EventQuery { return eq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (eq *EventQuery) Limit(limit int) *EventQuery { - eq.limit = &limit + eq.ctx.Limit = &limit return eq } -// Offset adds an offset step to the query. +// Offset to start from. func (eq *EventQuery) Offset(offset int) *EventQuery { - eq.offset = &offset + eq.ctx.Offset = &offset return eq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (eq *EventQuery) Unique(unique bool) *EventQuery { - eq.unique = &unique + eq.ctx.Unique = &unique return eq } -// Order adds an order step to the query. -func (eq *EventQuery) Order(o ...OrderFunc) *EventQuery { +// Order specifies how the records should be ordered. +func (eq *EventQuery) Order(o ...event.OrderOption) *EventQuery { eq.order = append(eq.order, o...) return eq } // QueryOwner chains the current query on the "owner" edge. func (eq *EventQuery) QueryOwner() *AlertQuery { - query := &AlertQuery{config: eq.config} + query := (&AlertClient{config: eq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := eq.prepareQuery(ctx); err != nil { return nil, err @@ -86,7 +84,7 @@ func (eq *EventQuery) QueryOwner() *AlertQuery { // First returns the first Event entity from the query. // Returns a *NotFoundError when no Event was found. func (eq *EventQuery) First(ctx context.Context) (*Event, error) { - nodes, err := eq.Limit(1).All(ctx) + nodes, err := eq.Limit(1).All(setContextOp(ctx, eq.ctx, "First")) if err != nil { return nil, err } @@ -109,7 +107,7 @@ func (eq *EventQuery) FirstX(ctx context.Context) *Event { // Returns a *NotFoundError when no Event ID was found. func (eq *EventQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = eq.Limit(1).IDs(ctx); err != nil { + if ids, err = eq.Limit(1).IDs(setContextOp(ctx, eq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -132,7 +130,7 @@ func (eq *EventQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Event entity is found. // Returns a *NotFoundError when no Event entities are found. func (eq *EventQuery) Only(ctx context.Context) (*Event, error) { - nodes, err := eq.Limit(2).All(ctx) + nodes, err := eq.Limit(2).All(setContextOp(ctx, eq.ctx, "Only")) if err != nil { return nil, err } @@ -160,7 +158,7 @@ func (eq *EventQuery) OnlyX(ctx context.Context) *Event { // Returns a *NotFoundError when no entities are found. 
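// Order now takes the typed event.OrderOption instead of the generic
// OrderFunc. Assuming the usual By<Field> helpers are generated for Event,
// as they are for Lock later in this diff, an ordered query reads:
//
//	evs, err := client.Event.Query().
//		Order(event.ByCreatedAt(sql.OrderDesc())).
//		Limit(10).
//		All(ctx)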
func (eq *EventQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = eq.Limit(2).IDs(ctx); err != nil { + if ids, err = eq.Limit(2).IDs(setContextOp(ctx, eq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -185,10 +183,12 @@ func (eq *EventQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Events. func (eq *EventQuery) All(ctx context.Context) ([]*Event, error) { + ctx = setContextOp(ctx, eq.ctx, "All") if err := eq.prepareQuery(ctx); err != nil { return nil, err } - return eq.sqlAll(ctx) + qr := querierAll[[]*Event, *EventQuery]() + return withInterceptors[[]*Event](ctx, eq, qr, eq.inters) } // AllX is like All, but panics if an error occurs. @@ -201,9 +201,12 @@ func (eq *EventQuery) AllX(ctx context.Context) []*Event { } // IDs executes the query and returns a list of Event IDs. -func (eq *EventQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := eq.Select(event.FieldID).Scan(ctx, &ids); err != nil { +func (eq *EventQuery) IDs(ctx context.Context) (ids []int, err error) { + if eq.ctx.Unique == nil && eq.path != nil { + eq.Unique(true) + } + ctx = setContextOp(ctx, eq.ctx, "IDs") + if err = eq.Select(event.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -220,10 +223,11 @@ func (eq *EventQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (eq *EventQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, eq.ctx, "Count") if err := eq.prepareQuery(ctx); err != nil { return 0, err } - return eq.sqlCount(ctx) + return withInterceptors[int](ctx, eq, querierCount[*EventQuery](), eq.inters) } // CountX is like Count, but panics if an error occurs. @@ -237,10 +241,15 @@ func (eq *EventQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (eq *EventQuery) Exist(ctx context.Context) (bool, error) { - if err := eq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, eq.ctx, "Exist") + switch _, err := eq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return eq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -260,22 +269,21 @@ func (eq *EventQuery) Clone() *EventQuery { } return &EventQuery{ config: eq.config, - limit: eq.limit, - offset: eq.offset, - order: append([]OrderFunc{}, eq.order...), + ctx: eq.ctx.Clone(), + order: append([]event.OrderOption{}, eq.order...), + inters: append([]Interceptor{}, eq.inters...), predicates: append([]predicate.Event{}, eq.predicates...), withOwner: eq.withOwner.Clone(), // clone intermediate query. - sql: eq.sql.Clone(), - path: eq.path, - unique: eq.unique, + sql: eq.sql.Clone(), + path: eq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. The optional arguments are used to configure the query builder of the edge. func (eq *EventQuery) WithOwner(opts ...func(*AlertQuery)) *EventQuery { - query := &AlertQuery{config: eq.config} + query := (&AlertClient{config: eq.config}).Query() for _, opt := range opts { opt(query) } @@ -298,16 +306,11 @@ func (eq *EventQuery) WithOwner(opts ...func(*AlertQuery)) *EventQuery { // Aggregate(ent.Count()). 
// Scan(ctx, &v) func (eq *EventQuery) GroupBy(field string, fields ...string) *EventGroupBy { - grbuild := &EventGroupBy{config: eq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := eq.prepareQuery(ctx); err != nil { - return nil, err - } - return eq.sqlQuery(ctx), nil - } + eq.ctx.Fields = append([]string{field}, fields...) + grbuild := &EventGroupBy{build: eq} + grbuild.flds = &eq.ctx.Fields grbuild.label = event.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -324,15 +327,30 @@ func (eq *EventQuery) GroupBy(field string, fields ...string) *EventGroupBy { // Select(event.FieldCreatedAt). // Scan(ctx, &v) func (eq *EventQuery) Select(fields ...string) *EventSelect { - eq.fields = append(eq.fields, fields...) - selbuild := &EventSelect{EventQuery: eq} - selbuild.label = event.Label - selbuild.flds, selbuild.scan = &eq.fields, selbuild.Scan - return selbuild + eq.ctx.Fields = append(eq.ctx.Fields, fields...) + sbuild := &EventSelect{EventQuery: eq} + sbuild.label = event.Label + sbuild.flds, sbuild.scan = &eq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a EventSelect configured with the given aggregations. +func (eq *EventQuery) Aggregate(fns ...AggregateFunc) *EventSelect { + return eq.Select().Aggregate(fns...) } func (eq *EventQuery) prepareQuery(ctx context.Context) error { - for _, f := range eq.fields { + for _, inter := range eq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, eq); err != nil { + return err + } + } + } + for _, f := range eq.ctx.Fields { if !event.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -392,6 +410,9 @@ func (eq *EventQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes [] } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(alert.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -411,41 +432,22 @@ func (eq *EventQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes [] func (eq *EventQuery) sqlCount(ctx context.Context) (int, error) { _spec := eq.querySpec() - _spec.Node.Columns = eq.fields - if len(eq.fields) > 0 { - _spec.Unique = eq.unique != nil && *eq.unique + _spec.Node.Columns = eq.ctx.Fields + if len(eq.ctx.Fields) > 0 { + _spec.Unique = eq.ctx.Unique != nil && *eq.ctx.Unique } return sqlgraph.CountNodes(ctx, eq.driver, _spec) } -func (eq *EventQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := eq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - Columns: event.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, - From: eq.sql, - Unique: true, - } - if unique := eq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) + _spec.From = eq.sql + if unique := eq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if eq.path != nil { + _spec.Unique = true } - if fields := 
eq.fields; len(fields) > 0 { + if fields := eq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, event.FieldID) for i := range fields { @@ -453,6 +455,9 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) } } + if eq.withOwner != nil { + _spec.Node.AddColumnOnce(event.FieldAlertEvents) + } } if ps := eq.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { @@ -461,10 +466,10 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := eq.limit; limit != nil { + if limit := eq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := eq.offset; offset != nil { + if offset := eq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := eq.order; len(ps) > 0 { @@ -480,7 +485,7 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(eq.driver.Dialect()) t1 := builder.Table(event.Table) - columns := eq.fields + columns := eq.ctx.Fields if len(columns) == 0 { columns = event.Columns } @@ -489,7 +494,7 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = eq.sql selector.Select(selector.Columns(columns...)...) } - if eq.unique != nil && *eq.unique { + if eq.ctx.Unique != nil && *eq.ctx.Unique { selector.Distinct() } for _, p := range eq.predicates { @@ -498,12 +503,12 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range eq.order { p(selector) } - if offset := eq.offset; offset != nil { + if offset := eq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := eq.limit; limit != nil { + if limit := eq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -511,13 +516,8 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { // EventGroupBy is the group-by builder for Event entities. type EventGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *EventQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -526,74 +526,77 @@ func (egb *EventGroupBy) Aggregate(fns ...AggregateFunc) *EventGroupBy { return egb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. 
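// EventGroupBy is now a thin wrapper over its parent EventQuery, and Scan
// (below) runs through the same interceptor chain as All or Count. Usage is
// unchanged; for example:
//
//	var v []struct {
//		AlertEvents int `json:"alert_events"`
//		Count       int `json:"count"`
//	}
//	err := client.Event.Query().
//		GroupBy(event.FieldAlertEvents).
//		Aggregate(ent.Count()).
//		Scan(ctx, &v)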
func (egb *EventGroupBy) Scan(ctx context.Context, v any) error { - query, err := egb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, egb.build.ctx, "GroupBy") + if err := egb.build.prepareQuery(ctx); err != nil { return err } - egb.sql = query - return egb.sqlScan(ctx, v) + return scanWithInterceptors[*EventQuery, *EventGroupBy](ctx, egb.build, egb, egb.build.inters, v) } -func (egb *EventGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range egb.fields { - if !event.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (egb *EventGroupBy) sqlScan(ctx context.Context, root *EventQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(egb.fns)) + for _, fn := range egb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*egb.flds)+len(egb.fns)) + for _, f := range *egb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := egb.sqlQuery() + selector.GroupBy(selector.Columns(*egb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := egb.driver.Query(ctx, query, args, rows); err != nil { + if err := egb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (egb *EventGroupBy) sqlQuery() *sql.Selector { - selector := egb.sql.Select() - aggregation := make([]string, 0, len(egb.fns)) - for _, fn := range egb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(egb.fields)+len(egb.fns)) - for _, f := range egb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(egb.fields...)...) -} - // EventSelect is the builder for selecting fields of Event entities. type EventSelect struct { *EventQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (es *EventSelect) Aggregate(fns ...AggregateFunc) *EventSelect { + es.fns = append(es.fns, fns...) + return es } // Scan applies the selector query and scans the result into the given value. func (es *EventSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, es.ctx, "Select") if err := es.prepareQuery(ctx); err != nil { return err } - es.sql = es.EventQuery.sqlQuery(ctx) - return es.sqlScan(ctx, v) + return scanWithInterceptors[*EventQuery, *EventSelect](ctx, es.EventQuery, es, es.inters, v) } -func (es *EventSelect) sqlScan(ctx context.Context, v any) error { +func (es *EventSelect) sqlScan(ctx context.Context, root *EventQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(es.fns)) + for _, fn := range es.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*es.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) 
+ } rows := &sql.Rows{} - query, args := es.sql.Query() + query, args := selector.Query() if err := es.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/event_update.go b/pkg/database/ent/event_update.go index fcd0cc50c99..c2f5c6cddb1 100644 --- a/pkg/database/ent/event_update.go +++ b/pkg/database/ent/event_update.go @@ -29,42 +29,12 @@ func (eu *EventUpdate) Where(ps ...predicate.Event) *EventUpdate { return eu } -// SetCreatedAt sets the "created_at" field. -func (eu *EventUpdate) SetCreatedAt(t time.Time) *EventUpdate { - eu.mutation.SetCreatedAt(t) - return eu -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (eu *EventUpdate) ClearCreatedAt() *EventUpdate { - eu.mutation.ClearCreatedAt() - return eu -} - // SetUpdatedAt sets the "updated_at" field. func (eu *EventUpdate) SetUpdatedAt(t time.Time) *EventUpdate { eu.mutation.SetUpdatedAt(t) return eu } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (eu *EventUpdate) ClearUpdatedAt() *EventUpdate { - eu.mutation.ClearUpdatedAt() - return eu -} - -// SetTime sets the "time" field. -func (eu *EventUpdate) SetTime(t time.Time) *EventUpdate { - eu.mutation.SetTime(t) - return eu -} - -// SetSerialized sets the "serialized" field. -func (eu *EventUpdate) SetSerialized(s string) *EventUpdate { - eu.mutation.SetSerialized(s) - return eu -} - // SetAlertEvents sets the "alert_events" field. func (eu *EventUpdate) SetAlertEvents(i int) *EventUpdate { eu.mutation.SetAlertEvents(i) @@ -117,41 +87,8 @@ func (eu *EventUpdate) ClearOwner() *EventUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (eu *EventUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) eu.defaults() - if len(eu.hooks) == 0 { - if err = eu.check(); err != nil { - return 0, err - } - affected, err = eu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = eu.check(); err != nil { - return 0, err - } - eu.mutation = mutation - affected, err = eu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(eu.hooks) - 1; i >= 0; i-- { - if eu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = eu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, eu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, eu.sqlSave, eu.mutation, eu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -178,37 +115,14 @@ func (eu *EventUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (eu *EventUpdate) defaults() { - if _, ok := eu.mutation.CreatedAt(); !ok && !eu.mutation.CreatedAtCleared() { - v := event.UpdateDefaultCreatedAt() - eu.mutation.SetCreatedAt(v) - } - if _, ok := eu.mutation.UpdatedAt(); !ok && !eu.mutation.UpdatedAtCleared() { + if _, ok := eu.mutation.UpdatedAt(); !ok { v := event.UpdateDefaultUpdatedAt() eu.mutation.SetUpdatedAt(v) } } -// check runs all checks and user-defined validators on the builder. 
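// With the created_at, time and serialized setters removed, the update
// builders can only touch updated_at, alert_events and the owner edge; the
// other fields become write-once at creation time. What still compiles,
// sketched:
//
//	err := client.Event.UpdateOneID(id).
//		SetAlertEvents(alertID).
//		Exec(ctx)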
-func (eu *EventUpdate) check() error { - if v, ok := eu.mutation.Serialized(); ok { - if err := event.SerializedValidator(v); err != nil { - return &ValidationError{Name: "serialized", err: fmt.Errorf(`ent: validator failed for field "Event.serialized": %w`, err)} - } - } - return nil -} - func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - Columns: event.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) if ps := eu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -216,45 +130,8 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := eu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldCreatedAt, - }) - } - if eu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldCreatedAt, - }) - } if value, ok := eu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldUpdatedAt, - }) - } - if eu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldUpdatedAt, - }) - } - if value, ok := eu.mutation.Time(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldTime, - }) - } - if value, ok := eu.mutation.Serialized(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: event.FieldSerialized, - }) + _spec.SetField(event.FieldUpdatedAt, field.TypeTime, value) } if eu.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -264,10 +141,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -280,10 +154,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -299,6 +170,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + eu.mutation.done = true return n, nil } @@ -310,42 +182,12 @@ type EventUpdateOne struct { mutation *EventMutation } -// SetCreatedAt sets the "created_at" field. -func (euo *EventUpdateOne) SetCreatedAt(t time.Time) *EventUpdateOne { - euo.mutation.SetCreatedAt(t) - return euo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (euo *EventUpdateOne) ClearCreatedAt() *EventUpdateOne { - euo.mutation.ClearCreatedAt() - return euo -} - // SetUpdatedAt sets the "updated_at" field. 
func (euo *EventUpdateOne) SetUpdatedAt(t time.Time) *EventUpdateOne { euo.mutation.SetUpdatedAt(t) return euo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (euo *EventUpdateOne) ClearUpdatedAt() *EventUpdateOne { - euo.mutation.ClearUpdatedAt() - return euo -} - -// SetTime sets the "time" field. -func (euo *EventUpdateOne) SetTime(t time.Time) *EventUpdateOne { - euo.mutation.SetTime(t) - return euo -} - -// SetSerialized sets the "serialized" field. -func (euo *EventUpdateOne) SetSerialized(s string) *EventUpdateOne { - euo.mutation.SetSerialized(s) - return euo -} - // SetAlertEvents sets the "alert_events" field. func (euo *EventUpdateOne) SetAlertEvents(i int) *EventUpdateOne { euo.mutation.SetAlertEvents(i) @@ -396,6 +238,12 @@ func (euo *EventUpdateOne) ClearOwner() *EventUpdateOne { return euo } +// Where appends a list predicates to the EventUpdate builder. +func (euo *EventUpdateOne) Where(ps ...predicate.Event) *EventUpdateOne { + euo.mutation.Where(ps...) + return euo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (euo *EventUpdateOne) Select(field string, fields ...string) *EventUpdateOne { @@ -405,47 +253,8 @@ func (euo *EventUpdateOne) Select(field string, fields ...string) *EventUpdateOn // Save executes the query and returns the updated Event entity. func (euo *EventUpdateOne) Save(ctx context.Context) (*Event, error) { - var ( - err error - node *Event - ) euo.defaults() - if len(euo.hooks) == 0 { - if err = euo.check(); err != nil { - return nil, err - } - node, err = euo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = euo.check(); err != nil { - return nil, err - } - euo.mutation = mutation - node, err = euo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(euo.hooks) - 1; i >= 0; i-- { - if euo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = euo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, euo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Event) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from EventMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, euo.sqlSave, euo.mutation, euo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -472,37 +281,14 @@ func (euo *EventUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (euo *EventUpdateOne) defaults() { - if _, ok := euo.mutation.CreatedAt(); !ok && !euo.mutation.CreatedAtCleared() { - v := event.UpdateDefaultCreatedAt() - euo.mutation.SetCreatedAt(v) - } - if _, ok := euo.mutation.UpdatedAt(); !ok && !euo.mutation.UpdatedAtCleared() { + if _, ok := euo.mutation.UpdatedAt(); !ok { v := event.UpdateDefaultUpdatedAt() euo.mutation.SetUpdatedAt(v) } } -// check runs all checks and user-defined validators on the builder. 
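// Like the delete builder, EventUpdateOne now exposes Where (added further
// down), so a targeted update can refuse to run when a guard fails. A sketch
// with an illustrative guard:
//
//	ev, err := client.Event.UpdateOneID(id).
//		Where(event.AlertEvents(alertID)).
//		SetUpdatedAt(time.Now()).
//		Save(ctx)
//	if ent.IsNotFound(err) {
//		// the guard filtered the row out, ev is nil
//	}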
-func (euo *EventUpdateOne) check() error { - if v, ok := euo.mutation.Serialized(); ok { - if err := event.SerializedValidator(v); err != nil { - return &ValidationError{Name: "serialized", err: fmt.Errorf(`ent: validator failed for field "Event.serialized": %w`, err)} - } - } - return nil -} - func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - Columns: event.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) id, ok := euo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Event.id" for update`)} @@ -527,45 +313,8 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error } } } - if value, ok := euo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldCreatedAt, - }) - } - if euo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldCreatedAt, - }) - } if value, ok := euo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldUpdatedAt, - }) - } - if euo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldUpdatedAt, - }) - } - if value, ok := euo.mutation.Time(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldTime, - }) - } - if value, ok := euo.mutation.Serialized(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: event.FieldSerialized, - }) + _spec.SetField(event.FieldUpdatedAt, field.TypeTime, value) } if euo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -575,10 +324,7 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -591,10 +337,7 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -613,5 +356,6 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error } return nil, err } + euo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/generate.go b/pkg/database/ent/generate.go index 9f3a916c7a4..8ada999d7ab 100644 --- a/pkg/database/ent/generate.go +++ b/pkg/database/ent/generate.go @@ -1,4 +1,4 @@ package ent -//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate ./schema +//go:generate go run -mod=mod entgo.io/ent/cmd/ent@v0.13.1 generate ./schema diff --git a/pkg/database/ent/helpers.go b/pkg/database/ent/helpers.go new file mode 100644 
index 00000000000..9b30ce451e0 --- /dev/null +++ b/pkg/database/ent/helpers.go @@ -0,0 +1,25 @@ +package ent + +func (m *Machine) GetOsname() string { + return m.Osname +} + +func (b *Bouncer) GetOsname() string { + return b.Osname +} + +func (m *Machine) GetOsversion() string { + return m.Osversion +} + +func (b *Bouncer) GetOsversion() string { + return b.Osversion +} + +func (m *Machine) GetFeatureflags() string { + return m.Featureflags +} + +func (b *Bouncer) GetFeatureflags() string { + return b.Featureflags +} diff --git a/pkg/database/ent/hook/hook.go b/pkg/database/ent/hook/hook.go index 85ab00b01fb..62cc07820d0 100644 --- a/pkg/database/ent/hook/hook.go +++ b/pkg/database/ent/hook/hook.go @@ -15,11 +15,10 @@ type AlertFunc func(context.Context, *ent.AlertMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f AlertFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AlertMutation", m) + if mv, ok := m.(*ent.AlertMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AlertMutation", m) } // The BouncerFunc type is an adapter to allow the use of ordinary @@ -28,11 +27,10 @@ type BouncerFunc func(context.Context, *ent.BouncerMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f BouncerFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.BouncerMutation", m) + if mv, ok := m.(*ent.BouncerMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.BouncerMutation", m) } // The ConfigItemFunc type is an adapter to allow the use of ordinary @@ -41,11 +39,10 @@ type ConfigItemFunc func(context.Context, *ent.ConfigItemMutation) (ent.Value, e // Mutate calls f(ctx, m). func (f ConfigItemFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ConfigItemMutation", m) + if mv, ok := m.(*ent.ConfigItemMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ConfigItemMutation", m) } // The DecisionFunc type is an adapter to allow the use of ordinary @@ -54,11 +51,10 @@ type DecisionFunc func(context.Context, *ent.DecisionMutation) (ent.Value, error // Mutate calls f(ctx, m). func (f DecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DecisionMutation", m) + if mv, ok := m.(*ent.DecisionMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DecisionMutation", m) } // The EventFunc type is an adapter to allow the use of ordinary @@ -67,11 +63,22 @@ type EventFunc func(context.Context, *ent.EventMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f EventFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. 
expect *ent.EventMutation", m) + if mv, ok := m.(*ent.EventMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.EventMutation", m) +} + +// The LockFunc type is an adapter to allow the use of ordinary +// function as Lock mutator. +type LockFunc func(context.Context, *ent.LockMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f LockFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.LockMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.LockMutation", m) } // The MachineFunc type is an adapter to allow the use of ordinary @@ -80,11 +87,10 @@ type MachineFunc func(context.Context, *ent.MachineMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f MachineFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MachineMutation", m) + if mv, ok := m.(*ent.MachineMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MachineMutation", m) } // The MetaFunc type is an adapter to allow the use of ordinary @@ -93,11 +99,22 @@ type MetaFunc func(context.Context, *ent.MetaMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f MetaFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MetaMutation", m) + if mv, ok := m.(*ent.MetaMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MetaMutation", m) +} + +// The MetricFunc type is an adapter to allow the use of ordinary +// function as Metric mutator. +type MetricFunc func(context.Context, *ent.MetricMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f MetricFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.MetricMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MetricMutation", m) } // Condition is a hook condition function. diff --git a/pkg/database/ent/lock.go b/pkg/database/ent/lock.go new file mode 100644 index 00000000000..85556a30644 --- /dev/null +++ b/pkg/database/ent/lock.go @@ -0,0 +1,117 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" +) + +// Lock is the model entity for the Lock schema. +type Lock struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. 
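// The new LockFunc and MetricFunc adapters follow the same if-ok pattern the
// existing adapters were rewritten to, so mutation hooks can be attached to
// the new entities as usual. A sketch (the logging body is illustrative):
//
//	client.Lock.Use(func(next ent.Mutator) ent.Mutator {
//		return hook.LockFunc(func(ctx context.Context, m *ent.LockMutation) (ent.Value, error) {
//			log.Printf("lock mutation: %v", m.Op())
//			return next.Mutate(ctx, m)
//		})
//	})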
+func (*Lock) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case lock.FieldID: + values[i] = new(sql.NullInt64) + case lock.FieldName: + values[i] = new(sql.NullString) + case lock.FieldCreatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Lock fields. +func (l *Lock) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case lock.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + l.ID = int(value.Int64) + case lock.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + l.Name = value.String + } + case lock.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + l.CreatedAt = value.Time + } + default: + l.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Lock. +// This includes values selected through modifiers, order, etc. +func (l *Lock) Value(name string) (ent.Value, error) { + return l.selectValues.Get(name) +} + +// Update returns a builder for updating this Lock. +// Note that you need to call Lock.Unwrap() before calling this method if this Lock +// was returned from a transaction, and the transaction was committed or rolled back. +func (l *Lock) Update() *LockUpdateOne { + return NewLockClient(l.config).UpdateOne(l) +} + +// Unwrap unwraps the Lock entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (l *Lock) Unwrap() *Lock { + _tx, ok := l.config.driver.(*txDriver) + if !ok { + panic("ent: Lock is not a transactional entity") + } + l.config.driver = _tx.drv + return l +} + +// String implements the fmt.Stringer. +func (l *Lock) String() string { + var builder strings.Builder + builder.WriteString("Lock(") + builder.WriteString(fmt.Sprintf("id=%v, ", l.ID)) + builder.WriteString("name=") + builder.WriteString(l.Name) + builder.WriteString(", ") + builder.WriteString("created_at=") + builder.WriteString(l.CreatedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// Locks is a parsable slice of Lock. +type Locks []*Lock diff --git a/pkg/database/ent/lock/lock.go b/pkg/database/ent/lock/lock.go new file mode 100644 index 00000000000..d0143470a75 --- /dev/null +++ b/pkg/database/ent/lock/lock.go @@ -0,0 +1,62 @@ +// Code generated by ent, DO NOT EDIT. + +package lock + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the lock type in the database. + Label = "lock" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldCreatedAt holds the string denoting the created_at field in the database. 
+ FieldCreatedAt = "created_at" + // Table holds the table name of the lock in the database. + Table = "locks" +) + +// Columns holds all SQL columns for lock fields. +var Columns = []string{ + FieldID, + FieldName, + FieldCreatedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time +) + +// OrderOption defines the ordering options for the Lock queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} diff --git a/pkg/database/ent/lock/where.go b/pkg/database/ent/lock/where.go new file mode 100644 index 00000000000..cf59362d203 --- /dev/null +++ b/pkg/database/ent/lock/where.go @@ -0,0 +1,185 @@ +// Code generated by ent, DO NOT EDIT. + +package lock + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Lock { + return predicate.Lock(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Lock { + return predicate.Lock(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Lock { + return predicate.Lock(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Lock { + return predicate.Lock(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Lock { + return predicate.Lock(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Lock { + return predicate.Lock(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Lock { + return predicate.Lock(sql.FieldLTE(FieldID, id)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldName, v)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldCreatedAt, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. 
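// The ByName and ByCreatedAt helpers above give lock queries the same typed
// ordering as the other entities, e.g.:
//
//	locks, err := client.Lock.Query().
//		Order(lock.ByCreatedAt(sql.OrderDesc())).
//		All(ctx)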
+func NameNEQ(v string) predicate.Lock { + return predicate.Lock(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.Lock { + return predicate.Lock(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.Lock { + return predicate.Lock(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.Lock { + return predicate.Lock(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.Lock { + return predicate.Lock(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.Lock { + return predicate.Lock(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.Lock { + return predicate.Lock(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.Lock { + return predicate.Lock(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.Lock { + return predicate.Lock(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.Lock { + return predicate.Lock(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.Lock { + return predicate.Lock(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.Lock { + return predicate.Lock(sql.FieldContainsFold(FieldName, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.Lock { + return predicate.Lock(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.Lock { + return predicate.Lock(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldLTE(FieldCreatedAt, v)) +} + +// And groups predicates with the AND operator between them. 
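// Together, the name and created_at predicates are enough to find stale
// locks. A sketch, with an illustrative lock name and timeout:
//
//	stale, err := client.Lock.Query().
//		Where(
//			lock.NameEQ("housekeeping"),
//			lock.CreatedAtLT(time.Now().Add(-15*time.Minute)),
//		).
//		All(ctx)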
+func And(predicates ...predicate.Lock) predicate.Lock { + return predicate.Lock(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Lock) predicate.Lock { + return predicate.Lock(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Lock) predicate.Lock { + return predicate.Lock(sql.NotPredicates(p)) +} diff --git a/pkg/database/ent/lock_create.go b/pkg/database/ent/lock_create.go new file mode 100644 index 00000000000..e2c29c88324 --- /dev/null +++ b/pkg/database/ent/lock_create.go @@ -0,0 +1,215 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" +) + +// LockCreate is the builder for creating a Lock entity. +type LockCreate struct { + config + mutation *LockMutation + hooks []Hook +} + +// SetName sets the "name" field. +func (lc *LockCreate) SetName(s string) *LockCreate { + lc.mutation.SetName(s) + return lc +} + +// SetCreatedAt sets the "created_at" field. +func (lc *LockCreate) SetCreatedAt(t time.Time) *LockCreate { + lc.mutation.SetCreatedAt(t) + return lc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (lc *LockCreate) SetNillableCreatedAt(t *time.Time) *LockCreate { + if t != nil { + lc.SetCreatedAt(*t) + } + return lc +} + +// Mutation returns the LockMutation object of the builder. +func (lc *LockCreate) Mutation() *LockMutation { + return lc.mutation +} + +// Save creates the Lock in the database. +func (lc *LockCreate) Save(ctx context.Context) (*Lock, error) { + lc.defaults() + return withHooks(ctx, lc.sqlSave, lc.mutation, lc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (lc *LockCreate) SaveX(ctx context.Context) *Lock { + v, err := lc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (lc *LockCreate) Exec(ctx context.Context) error { + _, err := lc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (lc *LockCreate) ExecX(ctx context.Context) { + if err := lc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (lc *LockCreate) defaults() { + if _, ok := lc.mutation.CreatedAt(); !ok { + v := lock.DefaultCreatedAt() + lc.mutation.SetCreatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (lc *LockCreate) check() error { + if _, ok := lc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Lock.name"`)} + } + if _, ok := lc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Lock.created_at"`)} + } + return nil +} + +func (lc *LockCreate) sqlSave(ctx context.Context) (*Lock, error) { + if err := lc.check(); err != nil { + return nil, err + } + _node, _spec := lc.createSpec() + if err := sqlgraph.CreateNode(ctx, lc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + lc.mutation.id = &_node.ID + lc.mutation.done = true + return _node, nil +} + +func (lc *LockCreate) createSpec() (*Lock, *sqlgraph.CreateSpec) { + var ( + _node = &Lock{config: lc.config} + _spec = sqlgraph.NewCreateSpec(lock.Table, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + ) + if value, ok := lc.mutation.Name(); ok { + _spec.SetField(lock.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := lc.mutation.CreatedAt(); ok { + _spec.SetField(lock.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + return _node, _spec +} + +// LockCreateBulk is the builder for creating many Lock entities in bulk. +type LockCreateBulk struct { + config + err error + builders []*LockCreate +} + +// Save creates the Lock entities in the database. +func (lcb *LockCreateBulk) Save(ctx context.Context) ([]*Lock, error) { + if lcb.err != nil { + return nil, lcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(lcb.builders)) + nodes := make([]*Lock, len(lcb.builders)) + mutators := make([]Mutator, len(lcb.builders)) + for i := range lcb.builders { + func(i int, root context.Context) { + builder := lcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*LockMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, lcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, lcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, lcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (lcb *LockCreateBulk) SaveX(ctx context.Context) []*Lock { + v, err := lcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
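// If the schema declares the lock name unique (the natural choice for a lock
// table, though not visible in this file), acquiring a lock reduces to a
// plain insert, and a constraint error means "already held". A sketch under
// that assumption:
//
//	err := client.Lock.Create().
//		SetName("housekeeping"). // illustrative name
//		Exec(ctx)
//	if ent.IsConstraintError(err) {
//		// another process holds the lock
//	}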
+func (lcb *LockCreateBulk) Exec(ctx context.Context) error { + _, err := lcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (lcb *LockCreateBulk) ExecX(ctx context.Context) { + if err := lcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/lock_delete.go b/pkg/database/ent/lock_delete.go new file mode 100644 index 00000000000..2275c608f75 --- /dev/null +++ b/pkg/database/ent/lock_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// LockDelete is the builder for deleting a Lock entity. +type LockDelete struct { + config + hooks []Hook + mutation *LockMutation +} + +// Where appends a list predicates to the LockDelete builder. +func (ld *LockDelete) Where(ps ...predicate.Lock) *LockDelete { + ld.mutation.Where(ps...) + return ld +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (ld *LockDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, ld.sqlExec, ld.mutation, ld.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (ld *LockDelete) ExecX(ctx context.Context) int { + n, err := ld.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (ld *LockDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(lock.Table, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + if ps := ld.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, ld.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + ld.mutation.done = true + return affected, err +} + +// LockDeleteOne is the builder for deleting a single Lock entity. +type LockDeleteOne struct { + ld *LockDelete +} + +// Where appends a list predicates to the LockDelete builder. +func (ldo *LockDeleteOne) Where(ps ...predicate.Lock) *LockDeleteOne { + ldo.ld.mutation.Where(ps...) + return ldo +} + +// Exec executes the deletion query. +func (ldo *LockDeleteOne) Exec(ctx context.Context) error { + n, err := ldo.ld.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{lock.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (ldo *LockDeleteOne) ExecX(ctx context.Context) { + if err := ldo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/lock_query.go b/pkg/database/ent/lock_query.go new file mode 100644 index 00000000000..75e5da48a94 --- /dev/null +++ b/pkg/database/ent/lock_query.go @@ -0,0 +1,526 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// LockQuery is the builder for querying Lock entities. +type LockQuery struct { + config + ctx *QueryContext + order []lock.OrderOption + inters []Interceptor + predicates []predicate.Lock + // intermediate query (i.e. traversal path). 
+ sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the LockQuery builder. +func (lq *LockQuery) Where(ps ...predicate.Lock) *LockQuery { + lq.predicates = append(lq.predicates, ps...) + return lq +} + +// Limit the number of records to be returned by this query. +func (lq *LockQuery) Limit(limit int) *LockQuery { + lq.ctx.Limit = &limit + return lq +} + +// Offset to start from. +func (lq *LockQuery) Offset(offset int) *LockQuery { + lq.ctx.Offset = &offset + return lq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (lq *LockQuery) Unique(unique bool) *LockQuery { + lq.ctx.Unique = &unique + return lq +} + +// Order specifies how the records should be ordered. +func (lq *LockQuery) Order(o ...lock.OrderOption) *LockQuery { + lq.order = append(lq.order, o...) + return lq +} + +// First returns the first Lock entity from the query. +// Returns a *NotFoundError when no Lock was found. +func (lq *LockQuery) First(ctx context.Context) (*Lock, error) { + nodes, err := lq.Limit(1).All(setContextOp(ctx, lq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{lock.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (lq *LockQuery) FirstX(ctx context.Context) *Lock { + node, err := lq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Lock ID from the query. +// Returns a *NotFoundError when no Lock ID was found. +func (lq *LockQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = lq.Limit(1).IDs(setContextOp(ctx, lq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{lock.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (lq *LockQuery) FirstIDX(ctx context.Context) int { + id, err := lq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Lock entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Lock entity is found. +// Returns a *NotFoundError when no Lock entities are found. +func (lq *LockQuery) Only(ctx context.Context) (*Lock, error) { + nodes, err := lq.Limit(2).All(setContextOp(ctx, lq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{lock.Label} + default: + return nil, &NotSingularError{lock.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (lq *LockQuery) OnlyX(ctx context.Context) *Lock { + node, err := lq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Lock ID in the query. +// Returns a *NotSingularError when more than one Lock ID is found. +// Returns a *NotFoundError when no entities are found. +func (lq *LockQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = lq.Limit(2).IDs(setContextOp(ctx, lq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{lock.Label} + default: + err = &NotSingularError{lock.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. 
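Another hedged sketch: the delete and query builders above compose in the usual ent style. lock.NameEQ is assumed from lock/where.go and DeleteOne from the generated client, neither of which is part of this hunk.

package locksketch

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/lock"
)

// releaseLock looks up a lock by name and removes it. Only returns the
// generated *NotFoundError / *NotSingularError, so a miss is detectable.
func releaseLock(ctx context.Context, client *ent.Client, name string) error {
	l, err := client.Lock.Query().
		Where(lock.NameEQ(name)). // assumed predicate from lock/where.go
		Only(ctx)
	if ent.IsNotFound(err) {
		return nil // nothing to release
	}
	if err != nil {
		return err
	}
	return client.Lock.DeleteOne(l).Exec(ctx)
}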
+func (lq *LockQuery) OnlyIDX(ctx context.Context) int { + id, err := lq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Locks. +func (lq *LockQuery) All(ctx context.Context) ([]*Lock, error) { + ctx = setContextOp(ctx, lq.ctx, "All") + if err := lq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Lock, *LockQuery]() + return withInterceptors[[]*Lock](ctx, lq, qr, lq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (lq *LockQuery) AllX(ctx context.Context) []*Lock { + nodes, err := lq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Lock IDs. +func (lq *LockQuery) IDs(ctx context.Context) (ids []int, err error) { + if lq.ctx.Unique == nil && lq.path != nil { + lq.Unique(true) + } + ctx = setContextOp(ctx, lq.ctx, "IDs") + if err = lq.Select(lock.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (lq *LockQuery) IDsX(ctx context.Context) []int { + ids, err := lq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (lq *LockQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, lq.ctx, "Count") + if err := lq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, lq, querierCount[*LockQuery](), lq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (lq *LockQuery) CountX(ctx context.Context) int { + count, err := lq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (lq *LockQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, lq.ctx, "Exist") + switch _, err := lq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (lq *LockQuery) ExistX(ctx context.Context) bool { + exist, err := lq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the LockQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (lq *LockQuery) Clone() *LockQuery { + if lq == nil { + return nil + } + return &LockQuery{ + config: lq.config, + ctx: lq.ctx.Clone(), + order: append([]lock.OrderOption{}, lq.order...), + inters: append([]Interceptor{}, lq.inters...), + predicates: append([]predicate.Lock{}, lq.predicates...), + // clone intermediate query. + sql: lq.sql.Clone(), + path: lq.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// Name string `json:"name"` +// Count int `json:"count,omitempty"` +// } +// +// client.Lock.Query(). +// GroupBy(lock.FieldName). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (lq *LockQuery) GroupBy(field string, fields ...string) *LockGroupBy { + lq.ctx.Fields = append([]string{field}, fields...) 
+ grbuild := &LockGroupBy{build: lq} + grbuild.flds = &lq.ctx.Fields + grbuild.label = lock.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// Name string `json:"name"` +// } +// +// client.Lock.Query(). +// Select(lock.FieldName). +// Scan(ctx, &v) +func (lq *LockQuery) Select(fields ...string) *LockSelect { + lq.ctx.Fields = append(lq.ctx.Fields, fields...) + sbuild := &LockSelect{LockQuery: lq} + sbuild.label = lock.Label + sbuild.flds, sbuild.scan = &lq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a LockSelect configured with the given aggregations. +func (lq *LockQuery) Aggregate(fns ...AggregateFunc) *LockSelect { + return lq.Select().Aggregate(fns...) +} + +func (lq *LockQuery) prepareQuery(ctx context.Context) error { + for _, inter := range lq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, lq); err != nil { + return err + } + } + } + for _, f := range lq.ctx.Fields { + if !lock.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if lq.path != nil { + prev, err := lq.path(ctx) + if err != nil { + return err + } + lq.sql = prev + } + return nil +} + +func (lq *LockQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Lock, error) { + var ( + nodes = []*Lock{} + _spec = lq.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Lock).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Lock{config: lq.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, lq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (lq *LockQuery) sqlCount(ctx context.Context) (int, error) { + _spec := lq.querySpec() + _spec.Node.Columns = lq.ctx.Fields + if len(lq.ctx.Fields) > 0 { + _spec.Unique = lq.ctx.Unique != nil && *lq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, lq.driver, _spec) +} + +func (lq *LockQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(lock.Table, lock.Columns, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + _spec.From = lq.sql + if unique := lq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if lq.path != nil { + _spec.Unique = true + } + if fields := lq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, lock.FieldID) + for i := range fields { + if fields[i] != lock.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := lq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := lq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := lq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := lq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (lq *LockQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := 
sql.Dialect(lq.driver.Dialect()) + t1 := builder.Table(lock.Table) + columns := lq.ctx.Fields + if len(columns) == 0 { + columns = lock.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if lq.sql != nil { + selector = lq.sql + selector.Select(selector.Columns(columns...)...) + } + if lq.ctx.Unique != nil && *lq.ctx.Unique { + selector.Distinct() + } + for _, p := range lq.predicates { + p(selector) + } + for _, p := range lq.order { + p(selector) + } + if offset := lq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := lq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// LockGroupBy is the group-by builder for Lock entities. +type LockGroupBy struct { + selector + build *LockQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (lgb *LockGroupBy) Aggregate(fns ...AggregateFunc) *LockGroupBy { + lgb.fns = append(lgb.fns, fns...) + return lgb +} + +// Scan applies the selector query and scans the result into the given value. +func (lgb *LockGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, lgb.build.ctx, "GroupBy") + if err := lgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*LockQuery, *LockGroupBy](ctx, lgb.build, lgb, lgb.build.inters, v) +} + +func (lgb *LockGroupBy) sqlScan(ctx context.Context, root *LockQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(lgb.fns)) + for _, fn := range lgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*lgb.flds)+len(lgb.fns)) + for _, f := range *lgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*lgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := lgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// LockSelect is the builder for selecting fields of Lock entities. +type LockSelect struct { + *LockQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ls *LockSelect) Aggregate(fns ...AggregateFunc) *LockSelect { + ls.fns = append(ls.fns, fns...) + return ls +} + +// Scan applies the selector query and scans the result into the given value. +func (ls *LockSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ls.ctx, "Select") + if err := ls.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*LockQuery, *LockSelect](ctx, ls.LockQuery, ls, ls.inters, v) +} + +func (ls *LockSelect) sqlScan(ctx context.Context, root *LockQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ls.fns)) + for _, fn := range ls.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ls.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) 
+ } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ls.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/pkg/database/ent/lock_update.go b/pkg/database/ent/lock_update.go new file mode 100644 index 00000000000..934e68c0762 --- /dev/null +++ b/pkg/database/ent/lock_update.go @@ -0,0 +1,175 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// LockUpdate is the builder for updating Lock entities. +type LockUpdate struct { + config + hooks []Hook + mutation *LockMutation +} + +// Where appends a list predicates to the LockUpdate builder. +func (lu *LockUpdate) Where(ps ...predicate.Lock) *LockUpdate { + lu.mutation.Where(ps...) + return lu +} + +// Mutation returns the LockMutation object of the builder. +func (lu *LockUpdate) Mutation() *LockMutation { + return lu.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (lu *LockUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, lu.sqlSave, lu.mutation, lu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (lu *LockUpdate) SaveX(ctx context.Context) int { + affected, err := lu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (lu *LockUpdate) Exec(ctx context.Context) error { + _, err := lu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (lu *LockUpdate) ExecX(ctx context.Context) { + if err := lu.Exec(ctx); err != nil { + panic(err) + } +} + +func (lu *LockUpdate) sqlSave(ctx context.Context) (n int, err error) { + _spec := sqlgraph.NewUpdateSpec(lock.Table, lock.Columns, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + if ps := lu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if n, err = sqlgraph.UpdateNodes(ctx, lu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{lock.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + lu.mutation.done = true + return n, nil +} + +// LockUpdateOne is the builder for updating a single Lock entity. +type LockUpdateOne struct { + config + fields []string + hooks []Hook + mutation *LockMutation +} + +// Mutation returns the LockMutation object of the builder. +func (luo *LockUpdateOne) Mutation() *LockMutation { + return luo.mutation +} + +// Where appends a list predicates to the LockUpdate builder. +func (luo *LockUpdateOne) Where(ps ...predicate.Lock) *LockUpdateOne { + luo.mutation.Where(ps...) + return luo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (luo *LockUpdateOne) Select(field string, fields ...string) *LockUpdateOne { + luo.fields = append([]string{field}, fields...) + return luo +} + +// Save executes the query and returns the updated Lock entity. 
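The Count/Exist plumbing in lock_query.go above makes lock probes one-liners; a short sketch under the same assumptions (lock.NameEQ from lock/where.go):

package locksketch

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/lock"
)

// lockHeld reports whether a lock row with the given name exists;
// Exist maps a FirstID miss to (false, nil) rather than an error.
func lockHeld(ctx context.Context, client *ent.Client, name string) (bool, error) {
	return client.Lock.Query().
		Where(lock.NameEQ(name)).
		Exist(ctx)
}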
+func (luo *LockUpdateOne) Save(ctx context.Context) (*Lock, error) { + return withHooks(ctx, luo.sqlSave, luo.mutation, luo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (luo *LockUpdateOne) SaveX(ctx context.Context) *Lock { + node, err := luo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (luo *LockUpdateOne) Exec(ctx context.Context) error { + _, err := luo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (luo *LockUpdateOne) ExecX(ctx context.Context) { + if err := luo.Exec(ctx); err != nil { + panic(err) + } +} + +func (luo *LockUpdateOne) sqlSave(ctx context.Context) (_node *Lock, err error) { + _spec := sqlgraph.NewUpdateSpec(lock.Table, lock.Columns, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + id, ok := luo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Lock.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := luo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, lock.FieldID) + for _, f := range fields { + if !lock.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != lock.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := luo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + _node = &Lock{config: luo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, luo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{lock.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + luo.mutation.done = true + return _node, nil +} diff --git a/pkg/database/ent/machine.go b/pkg/database/ent/machine.go index dc2b18ee81c..76127065791 100644 --- a/pkg/database/ent/machine.go +++ b/pkg/database/ent/machine.go @@ -3,12 +3,15 @@ package ent import ( + "encoding/json" "fmt" "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" ) // Machine is the model entity for the Machine schema. @@ -17,9 +20,9 @@ type Machine struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` // LastPush holds the value of the "last_push" field. LastPush *time.Time `json:"last_push,omitempty"` // LastHeartbeat holds the value of the "last_heartbeat" field. @@ -36,13 +39,22 @@ type Machine struct { Version string `json:"version,omitempty"` // IsValidated holds the value of the "isValidated" field. IsValidated bool `json:"isValidated,omitempty"` - // Status holds the value of the "status" field. - Status string `json:"status,omitempty"` // AuthType holds the value of the "auth_type" field. AuthType string `json:"auth_type"` + // Osname holds the value of the "osname" field. 
+ Osname string `json:"osname,omitempty"` + // Osversion holds the value of the "osversion" field. + Osversion string `json:"osversion,omitempty"` + // Featureflags holds the value of the "featureflags" field. + Featureflags string `json:"featureflags,omitempty"` + // Hubstate holds the value of the "hubstate" field. + Hubstate map[string][]schema.ItemState `json:"hubstate,omitempty"` + // Datasources holds the value of the "datasources" field. + Datasources map[string]int64 `json:"datasources,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the MachineQuery when eager-loading is set. - Edges MachineEdges `json:"edges"` + Edges MachineEdges `json:"edges"` + selectValues sql.SelectValues } // MachineEdges holds the relations/edges for other nodes in the graph. @@ -68,16 +80,18 @@ func (*Machine) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { + case machine.FieldHubstate, machine.FieldDatasources: + values[i] = new([]byte) case machine.FieldIsValidated: values[i] = new(sql.NullBool) case machine.FieldID: values[i] = new(sql.NullInt64) - case machine.FieldMachineId, machine.FieldPassword, machine.FieldIpAddress, machine.FieldScenarios, machine.FieldVersion, machine.FieldStatus, machine.FieldAuthType: + case machine.FieldMachineId, machine.FieldPassword, machine.FieldIpAddress, machine.FieldScenarios, machine.FieldVersion, machine.FieldAuthType, machine.FieldOsname, machine.FieldOsversion, machine.FieldFeatureflags: values[i] = new(sql.NullString) case machine.FieldCreatedAt, machine.FieldUpdatedAt, machine.FieldLastPush, machine.FieldLastHeartbeat: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Machine", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -101,15 +115,13 @@ func (m *Machine) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - m.CreatedAt = new(time.Time) - *m.CreatedAt = value.Time + m.CreatedAt = value.Time } case machine.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - m.UpdatedAt = new(time.Time) - *m.UpdatedAt = value.Time + m.UpdatedAt = value.Time } case machine.FieldLastPush: if value, ok := values[i].(*sql.NullTime); !ok { @@ -161,33 +173,69 @@ func (m *Machine) assignValues(columns []string, values []any) error { } else if value.Valid { m.IsValidated = value.Bool } - case machine.FieldStatus: - if value, ok := values[i].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field status", values[i]) - } else if value.Valid { - m.Status = value.String - } case machine.FieldAuthType: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field auth_type", values[i]) } else if value.Valid { m.AuthType = value.String } + case machine.FieldOsname: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field osname", values[i]) + } else if value.Valid { + m.Osname = value.String + } + case machine.FieldOsversion: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field osversion", values[i]) + } else if value.Valid { + m.Osversion = value.String + } + case 
machine.FieldFeatureflags: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field featureflags", values[i]) + } else if value.Valid { + m.Featureflags = value.String + } + case machine.FieldHubstate: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field hubstate", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &m.Hubstate); err != nil { + return fmt.Errorf("unmarshal field hubstate: %w", err) + } + } + case machine.FieldDatasources: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field datasources", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &m.Datasources); err != nil { + return fmt.Errorf("unmarshal field datasources: %w", err) + } + } + default: + m.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Machine. +// This includes values selected through modifiers, order, etc. +func (m *Machine) Value(name string) (ent.Value, error) { + return m.selectValues.Get(name) +} + // QueryAlerts queries the "alerts" edge of the Machine entity. func (m *Machine) QueryAlerts() *AlertQuery { - return (&MachineClient{config: m.config}).QueryAlerts(m) + return NewMachineClient(m.config).QueryAlerts(m) } // Update returns a builder for updating this Machine. // Note that you need to call Machine.Unwrap() before calling this method if this Machine // was returned from a transaction, and the transaction was committed or rolled back. func (m *Machine) Update() *MachineUpdateOne { - return (&MachineClient{config: m.config}).UpdateOne(m) + return NewMachineClient(m.config).UpdateOne(m) } // Unwrap unwraps the Machine entity that was returned from a transaction after it was closed, @@ -206,15 +254,11 @@ func (m *Machine) String() string { var builder strings.Builder builder.WriteString("Machine(") builder.WriteString(fmt.Sprintf("id=%v, ", m.ID)) - if v := m.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(m.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := m.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(m.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") if v := m.LastPush; v != nil { builder.WriteString("last_push=") @@ -243,20 +287,26 @@ func (m *Machine) String() string { builder.WriteString("isValidated=") builder.WriteString(fmt.Sprintf("%v", m.IsValidated)) builder.WriteString(", ") - builder.WriteString("status=") - builder.WriteString(m.Status) - builder.WriteString(", ") builder.WriteString("auth_type=") builder.WriteString(m.AuthType) + builder.WriteString(", ") + builder.WriteString("osname=") + builder.WriteString(m.Osname) + builder.WriteString(", ") + builder.WriteString("osversion=") + builder.WriteString(m.Osversion) + builder.WriteString(", ") + builder.WriteString("featureflags=") + builder.WriteString(m.Featureflags) + builder.WriteString(", ") + builder.WriteString("hubstate=") + builder.WriteString(fmt.Sprintf("%v", m.Hubstate)) + builder.WriteString(", ") + builder.WriteString("datasources=") + builder.WriteString(fmt.Sprintf("%v", m.Datasources)) builder.WriteByte(')') return builder.String() } // Machines is a parsable slice 
of Machine. type Machines []*Machine - -func (m Machines) config(cfg config) { - for _i := range m { - m[_i].config = cfg - } -} diff --git a/pkg/database/ent/machine/machine.go b/pkg/database/ent/machine/machine.go index e6900dd21e1..009e6e19c35 100644 --- a/pkg/database/ent/machine/machine.go +++ b/pkg/database/ent/machine/machine.go @@ -4,6 +4,9 @@ package machine import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -31,10 +34,18 @@ const ( FieldVersion = "version" // FieldIsValidated holds the string denoting the isvalidated field in the database. FieldIsValidated = "is_validated" - // FieldStatus holds the string denoting the status field in the database. - FieldStatus = "status" // FieldAuthType holds the string denoting the auth_type field in the database. FieldAuthType = "auth_type" + // FieldOsname holds the string denoting the osname field in the database. + FieldOsname = "osname" + // FieldOsversion holds the string denoting the osversion field in the database. + FieldOsversion = "osversion" + // FieldFeatureflags holds the string denoting the featureflags field in the database. + FieldFeatureflags = "featureflags" + // FieldHubstate holds the string denoting the hubstate field in the database. + FieldHubstate = "hubstate" + // FieldDatasources holds the string denoting the datasources field in the database. + FieldDatasources = "datasources" // EdgeAlerts holds the string denoting the alerts edge name in mutations. EdgeAlerts = "alerts" // Table holds the table name of the machine in the database. @@ -61,8 +72,12 @@ var Columns = []string{ FieldScenarios, FieldVersion, FieldIsValidated, - FieldStatus, FieldAuthType, + FieldOsname, + FieldOsversion, + FieldFeatureflags, + FieldHubstate, + FieldDatasources, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -78,20 +93,12 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. UpdateDefaultUpdatedAt func() time.Time // DefaultLastPush holds the default value on creation for the "last_push" field. DefaultLastPush func() time.Time - // UpdateDefaultLastPush holds the default value on update for the "last_push" field. - UpdateDefaultLastPush func() time.Time - // DefaultLastHeartbeat holds the default value on creation for the "last_heartbeat" field. - DefaultLastHeartbeat func() time.Time - // UpdateDefaultLastHeartbeat holds the default value on update for the "last_heartbeat" field. - UpdateDefaultLastHeartbeat func() time.Time // ScenariosValidator is a validator for the "scenarios" field. It is called by the builders before save. ScenariosValidator func(string) error // DefaultIsValidated holds the default value on creation for the "isValidated" field. @@ -99,3 +106,102 @@ var ( // DefaultAuthType holds the default value on creation for the "auth_type" field. DefaultAuthType string ) + +// OrderOption defines the ordering options for the Machine queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. 
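Beyond the regenerated boilerplate, the substantive Machine change is the schema itself: status is gone, created_at/updated_at become non-nillable, and osname, osversion, featureflags plus the JSON-encoded hubstate and datasources columns are added (see assignValues above). A hedged sketch; the Set* methods live in machine_create.go outside this hunk, so their exact signatures are assumptions, as is the sql.OrderDesc option:

package machinesketch

import (
	"context"

	"entgo.io/ent/dialect/sql"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/schema"
)

// register creates a machine carrying the new OS metadata; hubstate is
// marshalled to JSON on write and unmarshalled by assignValues on read.
func register(ctx context.Context, client *ent.Client) error {
	_, err := client.Machine.Create().
		SetMachineId("host-1"). // hypothetical values throughout
		SetPassword("secret").
		SetIpAddress("192.0.2.10").
		SetOsname("ubuntu").
		SetOsversion("22.04").
		SetHubstate(map[string][]schema.ItemState{}). // assumed setter for the JSON column
		Save(ctx)
	return err
}

// newestHeartbeats uses the OrderOption helpers defined just below.
func newestHeartbeats(ctx context.Context, client *ent.Client) ([]*ent.Machine, error) {
	return client.Machine.Query().
		Order(machine.ByLastHeartbeat(sql.OrderDesc())).
		Limit(10).
		All(ctx)
}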
+func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByLastPush orders the results by the last_push field. +func ByLastPush(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastPush, opts...).ToFunc() +} + +// ByLastHeartbeat orders the results by the last_heartbeat field. +func ByLastHeartbeat(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastHeartbeat, opts...).ToFunc() +} + +// ByMachineId orders the results by the machineId field. +func ByMachineId(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMachineId, opts...).ToFunc() +} + +// ByPassword orders the results by the password field. +func ByPassword(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassword, opts...).ToFunc() +} + +// ByIpAddress orders the results by the ipAddress field. +func ByIpAddress(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIpAddress, opts...).ToFunc() +} + +// ByScenarios orders the results by the scenarios field. +func ByScenarios(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenarios, opts...).ToFunc() +} + +// ByVersion orders the results by the version field. +func ByVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVersion, opts...).ToFunc() +} + +// ByIsValidated orders the results by the isValidated field. +func ByIsValidated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIsValidated, opts...).ToFunc() +} + +// ByAuthType orders the results by the auth_type field. +func ByAuthType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAuthType, opts...).ToFunc() +} + +// ByOsname orders the results by the osname field. +func ByOsname(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOsname, opts...).ToFunc() +} + +// ByOsversion orders the results by the osversion field. +func ByOsversion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOsversion, opts...).ToFunc() +} + +// ByFeatureflags orders the results by the featureflags field. +func ByFeatureflags(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFeatureflags, opts...).ToFunc() +} + +// ByAlertsCount orders the results by alerts count. +func ByAlertsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAlertsStep(), opts...) + } +} + +// ByAlerts orders the results by alerts terms. +func ByAlerts(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAlertsStep(), append([]sql.OrderTerm{term}, terms...)...) 
+ } +} +func newAlertsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AlertsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AlertsTable, AlertsColumn), + ) +} diff --git a/pkg/database/ent/machine/where.go b/pkg/database/ent/machine/where.go index 7d0227731cc..de523510f33 100644 --- a/pkg/database/ent/machine/where.go +++ b/pkg/database/ent/machine/where.go @@ -12,1218 +12,962 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldUpdatedAt, v)) } // LastPush applies equality check predicate on the "last_push" field. It's identical to LastPushEQ. 
func LastPush(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastPush, v)) } // LastHeartbeat applies equality check predicate on the "last_heartbeat" field. It's identical to LastHeartbeatEQ. func LastHeartbeat(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastHeartbeat, v)) } // MachineId applies equality check predicate on the "machineId" field. It's identical to MachineIdEQ. func MachineId(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldMachineId, v)) } // Password applies equality check predicate on the "password" field. It's identical to PasswordEQ. func Password(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldPassword, v)) } // IpAddress applies equality check predicate on the "ipAddress" field. It's identical to IpAddressEQ. func IpAddress(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIpAddress, v)) } // Scenarios applies equality check predicate on the "scenarios" field. It's identical to ScenariosEQ. func Scenarios(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldScenarios, v)) } // Version applies equality check predicate on the "version" field. It's identical to VersionEQ. func Version(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldVersion, v)) } // IsValidated applies equality check predicate on the "isValidated" field. It's identical to IsValidatedEQ. func IsValidated(v bool) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIsValidated), v)) - }) -} - -// Status applies equality check predicate on the "status" field. It's identical to StatusEQ. -func Status(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIsValidated, v)) } // AuthType applies equality check predicate on the "auth_type" field. It's identical to AuthTypeEQ. func AuthType(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldAuthType, v)) +} + +// Osname applies equality check predicate on the "osname" field. It's identical to OsnameEQ. +func Osname(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldOsname, v)) +} + +// Osversion applies equality check predicate on the "osversion" field. It's identical to OsversionEQ. +func Osversion(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldOsversion, v)) +} + +// Featureflags applies equality check predicate on the "featureflags" field. It's identical to FeatureflagsEQ. 
+func Featureflags(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldFeatureflags, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Machine(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. 
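The bulk of machine/where.go is a mechanical rewrite: every hand-rolled func(s *sql.Selector) closure collapses into the equivalent sql.Field* helper from newer entgo releases, the Status predicates disappear along with the field, and CreatedAtIsNil/NotNil (and the updated_at pair) are dropped because those columns are no longer nullable. Call sites compose exactly as before; a short sketch:

package machinesketch

import (
	"context"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
)

// recentUbuntu combines a new-field predicate with a rewritten time
// predicate; both are plain sql.Field* one-liners now.
func recentUbuntu(ctx context.Context, client *ent.Client, since time.Time) ([]*ent.Machine, error) {
	return client.Machine.Query().
		Where(
			machine.OsnameEQ("ubuntu"),
			machine.CreatedAtGTE(since),
		).
		All(ctx)
}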
func UpdatedAtNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Machine(sql.FieldLTE(FieldUpdatedAt, v)) } // LastPushEQ applies the EQ predicate on the "last_push" field. func LastPushEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastPush, v)) } // LastPushNEQ applies the NEQ predicate on the "last_push" field. func LastPushNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldLastPush, v)) } // LastPushIn applies the In predicate on the "last_push" field. func LastPushIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldLastPush), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldLastPush, vs...)) } // LastPushNotIn applies the NotIn predicate on the "last_push" field. 
func LastPushNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldLastPush), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldLastPush, vs...)) } // LastPushGT applies the GT predicate on the "last_push" field. func LastPushGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldGT(FieldLastPush, v)) } // LastPushGTE applies the GTE predicate on the "last_push" field. func LastPushGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldLastPush, v)) } // LastPushLT applies the LT predicate on the "last_push" field. func LastPushLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldLT(FieldLastPush, v)) } // LastPushLTE applies the LTE predicate on the "last_push" field. func LastPushLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldLastPush, v)) } // LastPushIsNil applies the IsNil predicate on the "last_push" field. func LastPushIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldLastPush))) - }) + return predicate.Machine(sql.FieldIsNull(FieldLastPush)) } // LastPushNotNil applies the NotNil predicate on the "last_push" field. func LastPushNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldLastPush))) - }) + return predicate.Machine(sql.FieldNotNull(FieldLastPush)) } // LastHeartbeatEQ applies the EQ predicate on the "last_heartbeat" field. func LastHeartbeatEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastHeartbeat, v)) } // LastHeartbeatNEQ applies the NEQ predicate on the "last_heartbeat" field. func LastHeartbeatNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldLastHeartbeat, v)) } // LastHeartbeatIn applies the In predicate on the "last_heartbeat" field. func LastHeartbeatIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldLastHeartbeat), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldLastHeartbeat, vs...)) } // LastHeartbeatNotIn applies the NotIn predicate on the "last_heartbeat" field. func LastHeartbeatNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldLastHeartbeat), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldLastHeartbeat, vs...)) } // LastHeartbeatGT applies the GT predicate on the "last_heartbeat" field. 
func LastHeartbeatGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldGT(FieldLastHeartbeat, v)) } // LastHeartbeatGTE applies the GTE predicate on the "last_heartbeat" field. func LastHeartbeatGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldLastHeartbeat, v)) } // LastHeartbeatLT applies the LT predicate on the "last_heartbeat" field. func LastHeartbeatLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldLT(FieldLastHeartbeat, v)) } // LastHeartbeatLTE applies the LTE predicate on the "last_heartbeat" field. func LastHeartbeatLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldLastHeartbeat, v)) } // LastHeartbeatIsNil applies the IsNil predicate on the "last_heartbeat" field. func LastHeartbeatIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldLastHeartbeat))) - }) + return predicate.Machine(sql.FieldIsNull(FieldLastHeartbeat)) } // LastHeartbeatNotNil applies the NotNil predicate on the "last_heartbeat" field. func LastHeartbeatNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldLastHeartbeat))) - }) + return predicate.Machine(sql.FieldNotNull(FieldLastHeartbeat)) } // MachineIdEQ applies the EQ predicate on the "machineId" field. func MachineIdEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldMachineId, v)) } // MachineIdNEQ applies the NEQ predicate on the "machineId" field. func MachineIdNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldMachineId, v)) } // MachineIdIn applies the In predicate on the "machineId" field. func MachineIdIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldMachineId), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldMachineId, vs...)) } // MachineIdNotIn applies the NotIn predicate on the "machineId" field. func MachineIdNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldMachineId), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldMachineId, vs...)) } // MachineIdGT applies the GT predicate on the "machineId" field. func MachineIdGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldGT(FieldMachineId, v)) } // MachineIdGTE applies the GTE predicate on the "machineId" field. 
func MachineIdGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldMachineId, v)) } // MachineIdLT applies the LT predicate on the "machineId" field. func MachineIdLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldLT(FieldMachineId, v)) } // MachineIdLTE applies the LTE predicate on the "machineId" field. func MachineIdLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldMachineId, v)) } // MachineIdContains applies the Contains predicate on the "machineId" field. func MachineIdContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldContains(FieldMachineId, v)) } // MachineIdHasPrefix applies the HasPrefix predicate on the "machineId" field. func MachineIdHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldMachineId, v)) } // MachineIdHasSuffix applies the HasSuffix predicate on the "machineId" field. func MachineIdHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldMachineId, v)) } // MachineIdEqualFold applies the EqualFold predicate on the "machineId" field. func MachineIdEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldMachineId, v)) } // MachineIdContainsFold applies the ContainsFold predicate on the "machineId" field. func MachineIdContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldMachineId, v)) } // PasswordEQ applies the EQ predicate on the "password" field. func PasswordEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldPassword, v)) } // PasswordNEQ applies the NEQ predicate on the "password" field. func PasswordNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldPassword, v)) } // PasswordIn applies the In predicate on the "password" field. func PasswordIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldPassword), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldPassword, vs...)) } // PasswordNotIn applies the NotIn predicate on the "password" field. 
func PasswordNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldPassword), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldPassword, vs...)) } // PasswordGT applies the GT predicate on the "password" field. func PasswordGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldGT(FieldPassword, v)) } // PasswordGTE applies the GTE predicate on the "password" field. func PasswordGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldPassword, v)) } // PasswordLT applies the LT predicate on the "password" field. func PasswordLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldLT(FieldPassword, v)) } // PasswordLTE applies the LTE predicate on the "password" field. func PasswordLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldPassword, v)) } // PasswordContains applies the Contains predicate on the "password" field. func PasswordContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldContains(FieldPassword, v)) } // PasswordHasPrefix applies the HasPrefix predicate on the "password" field. func PasswordHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldPassword, v)) } // PasswordHasSuffix applies the HasSuffix predicate on the "password" field. func PasswordHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldPassword, v)) } // PasswordEqualFold applies the EqualFold predicate on the "password" field. func PasswordEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldPassword, v)) } // PasswordContainsFold applies the ContainsFold predicate on the "password" field. func PasswordContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldPassword, v)) } // IpAddressEQ applies the EQ predicate on the "ipAddress" field. func IpAddressEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIpAddress, v)) } // IpAddressNEQ applies the NEQ predicate on the "ipAddress" field. func IpAddressNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldIpAddress, v)) } // IpAddressIn applies the In predicate on the "ipAddress" field. 
func IpAddressIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldIpAddress), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldIpAddress, vs...)) } // IpAddressNotIn applies the NotIn predicate on the "ipAddress" field. func IpAddressNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldIpAddress), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldIpAddress, vs...)) } // IpAddressGT applies the GT predicate on the "ipAddress" field. func IpAddressGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldGT(FieldIpAddress, v)) } // IpAddressGTE applies the GTE predicate on the "ipAddress" field. func IpAddressGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldIpAddress, v)) } // IpAddressLT applies the LT predicate on the "ipAddress" field. func IpAddressLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldLT(FieldIpAddress, v)) } // IpAddressLTE applies the LTE predicate on the "ipAddress" field. func IpAddressLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldIpAddress, v)) } // IpAddressContains applies the Contains predicate on the "ipAddress" field. func IpAddressContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldContains(FieldIpAddress, v)) } // IpAddressHasPrefix applies the HasPrefix predicate on the "ipAddress" field. func IpAddressHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldIpAddress, v)) } // IpAddressHasSuffix applies the HasSuffix predicate on the "ipAddress" field. func IpAddressHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldIpAddress, v)) } // IpAddressEqualFold applies the EqualFold predicate on the "ipAddress" field. func IpAddressEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldIpAddress, v)) } // IpAddressContainsFold applies the ContainsFold predicate on the "ipAddress" field. func IpAddressContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldIpAddress, v)) } // ScenariosEQ applies the EQ predicate on the "scenarios" field. 
func ScenariosEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldScenarios, v)) } // ScenariosNEQ applies the NEQ predicate on the "scenarios" field. func ScenariosNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldScenarios, v)) } // ScenariosIn applies the In predicate on the "scenarios" field. func ScenariosIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenarios), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldScenarios, vs...)) } // ScenariosNotIn applies the NotIn predicate on the "scenarios" field. func ScenariosNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenarios), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldScenarios, vs...)) } // ScenariosGT applies the GT predicate on the "scenarios" field. func ScenariosGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldGT(FieldScenarios, v)) } // ScenariosGTE applies the GTE predicate on the "scenarios" field. func ScenariosGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldScenarios, v)) } // ScenariosLT applies the LT predicate on the "scenarios" field. func ScenariosLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldLT(FieldScenarios, v)) } // ScenariosLTE applies the LTE predicate on the "scenarios" field. func ScenariosLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldScenarios, v)) } // ScenariosContains applies the Contains predicate on the "scenarios" field. func ScenariosContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldContains(FieldScenarios, v)) } // ScenariosHasPrefix applies the HasPrefix predicate on the "scenarios" field. func ScenariosHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldScenarios, v)) } // ScenariosHasSuffix applies the HasSuffix predicate on the "scenarios" field. func ScenariosHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldScenarios, v)) } // ScenariosIsNil applies the IsNil predicate on the "scenarios" field. 
func ScenariosIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldScenarios))) - }) + return predicate.Machine(sql.FieldIsNull(FieldScenarios)) } // ScenariosNotNil applies the NotNil predicate on the "scenarios" field. func ScenariosNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldScenarios))) - }) + return predicate.Machine(sql.FieldNotNull(FieldScenarios)) } // ScenariosEqualFold applies the EqualFold predicate on the "scenarios" field. func ScenariosEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldScenarios, v)) } // ScenariosContainsFold applies the ContainsFold predicate on the "scenarios" field. func ScenariosContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldScenarios, v)) } // VersionEQ applies the EQ predicate on the "version" field. func VersionEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldVersion, v)) } // VersionNEQ applies the NEQ predicate on the "version" field. func VersionNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldVersion, v)) } // VersionIn applies the In predicate on the "version" field. func VersionIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldVersion), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldVersion, vs...)) } // VersionNotIn applies the NotIn predicate on the "version" field. func VersionNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldVersion), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldVersion, vs...)) } // VersionGT applies the GT predicate on the "version" field. func VersionGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldGT(FieldVersion, v)) } // VersionGTE applies the GTE predicate on the "version" field. func VersionGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldVersion, v)) } // VersionLT applies the LT predicate on the "version" field. func VersionLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldLT(FieldVersion, v)) } // VersionLTE applies the LTE predicate on the "version" field. func VersionLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldVersion, v)) } // VersionContains applies the Contains predicate on the "version" field. 
func VersionContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldContains(FieldVersion, v)) } // VersionHasPrefix applies the HasPrefix predicate on the "version" field. func VersionHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldVersion, v)) } // VersionHasSuffix applies the HasSuffix predicate on the "version" field. func VersionHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldVersion, v)) } // VersionIsNil applies the IsNil predicate on the "version" field. func VersionIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldVersion))) - }) + return predicate.Machine(sql.FieldIsNull(FieldVersion)) } // VersionNotNil applies the NotNil predicate on the "version" field. func VersionNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldVersion))) - }) + return predicate.Machine(sql.FieldNotNull(FieldVersion)) } // VersionEqualFold applies the EqualFold predicate on the "version" field. func VersionEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldVersion, v)) } // VersionContainsFold applies the ContainsFold predicate on the "version" field. func VersionContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldVersion, v)) } // IsValidatedEQ applies the EQ predicate on the "isValidated" field. func IsValidatedEQ(v bool) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIsValidated), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIsValidated, v)) } // IsValidatedNEQ applies the NEQ predicate on the "isValidated" field. func IsValidatedNEQ(v bool) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIsValidated), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldIsValidated, v)) } -// StatusEQ applies the EQ predicate on the "status" field. -func StatusEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStatus), v)) - }) +// AuthTypeEQ applies the EQ predicate on the "auth_type" field. +func AuthTypeEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldAuthType, v)) } -// StatusNEQ applies the NEQ predicate on the "status" field. -func StatusNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStatus), v)) - }) +// AuthTypeNEQ applies the NEQ predicate on the "auth_type" field. +func AuthTypeNEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldNEQ(FieldAuthType, v)) } -// StatusIn applies the In predicate on the "status" field. 
-func StatusIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStatus), v...)) - }) +// AuthTypeIn applies the In predicate on the "auth_type" field. +func AuthTypeIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldIn(FieldAuthType, vs...)) } -// StatusNotIn applies the NotIn predicate on the "status" field. -func StatusNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStatus), v...)) - }) +// AuthTypeNotIn applies the NotIn predicate on the "auth_type" field. +func AuthTypeNotIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldNotIn(FieldAuthType, vs...)) } -// StatusGT applies the GT predicate on the "status" field. -func StatusGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStatus), v)) - }) +// AuthTypeGT applies the GT predicate on the "auth_type" field. +func AuthTypeGT(v string) predicate.Machine { + return predicate.Machine(sql.FieldGT(FieldAuthType, v)) } -// StatusGTE applies the GTE predicate on the "status" field. -func StatusGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStatus), v)) - }) +// AuthTypeGTE applies the GTE predicate on the "auth_type" field. +func AuthTypeGTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldGTE(FieldAuthType, v)) } -// StatusLT applies the LT predicate on the "status" field. -func StatusLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStatus), v)) - }) +// AuthTypeLT applies the LT predicate on the "auth_type" field. +func AuthTypeLT(v string) predicate.Machine { + return predicate.Machine(sql.FieldLT(FieldAuthType, v)) } -// StatusLTE applies the LTE predicate on the "status" field. -func StatusLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStatus), v)) - }) +// AuthTypeLTE applies the LTE predicate on the "auth_type" field. +func AuthTypeLTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldLTE(FieldAuthType, v)) } -// StatusContains applies the Contains predicate on the "status" field. -func StatusContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldStatus), v)) - }) +// AuthTypeContains applies the Contains predicate on the "auth_type" field. +func AuthTypeContains(v string) predicate.Machine { + return predicate.Machine(sql.FieldContains(FieldAuthType, v)) } -// StatusHasPrefix applies the HasPrefix predicate on the "status" field. -func StatusHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldStatus), v)) - }) +// AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field. +func AuthTypeHasPrefix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasPrefix(FieldAuthType, v)) } -// StatusHasSuffix applies the HasSuffix predicate on the "status" field. -func StatusHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldStatus), v)) - }) +// AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field. 
+func AuthTypeHasSuffix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasSuffix(FieldAuthType, v)) } -// StatusIsNil applies the IsNil predicate on the "status" field. -func StatusIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStatus))) - }) +// AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field. +func AuthTypeEqualFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldEqualFold(FieldAuthType, v)) } -// StatusNotNil applies the NotNil predicate on the "status" field. -func StatusNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStatus))) - }) +// AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field. +func AuthTypeContainsFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldContainsFold(FieldAuthType, v)) } -// StatusEqualFold applies the EqualFold predicate on the "status" field. -func StatusEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldStatus), v)) - }) +// OsnameEQ applies the EQ predicate on the "osname" field. +func OsnameEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldOsname, v)) } -// StatusContainsFold applies the ContainsFold predicate on the "status" field. -func StatusContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldStatus), v)) - }) +// OsnameNEQ applies the NEQ predicate on the "osname" field. +func OsnameNEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldNEQ(FieldOsname, v)) } -// AuthTypeEQ applies the EQ predicate on the "auth_type" field. -func AuthTypeEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAuthType), v)) - }) +// OsnameIn applies the In predicate on the "osname" field. +func OsnameIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldIn(FieldOsname, vs...)) } -// AuthTypeNEQ applies the NEQ predicate on the "auth_type" field. -func AuthTypeNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAuthType), v)) - }) +// OsnameNotIn applies the NotIn predicate on the "osname" field. +func OsnameNotIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldNotIn(FieldOsname, vs...)) } -// AuthTypeIn applies the In predicate on the "auth_type" field. -func AuthTypeIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAuthType), v...)) - }) +// OsnameGT applies the GT predicate on the "osname" field. +func OsnameGT(v string) predicate.Machine { + return predicate.Machine(sql.FieldGT(FieldOsname, v)) } -// AuthTypeNotIn applies the NotIn predicate on the "auth_type" field. -func AuthTypeNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAuthType), v...)) - }) +// OsnameGTE applies the GTE predicate on the "osname" field. +func OsnameGTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldGTE(FieldOsname, v)) } -// AuthTypeGT applies the GT predicate on the "auth_type" field. 
-func AuthTypeGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldAuthType), v)) - }) +// OsnameLT applies the LT predicate on the "osname" field. +func OsnameLT(v string) predicate.Machine { + return predicate.Machine(sql.FieldLT(FieldOsname, v)) } -// AuthTypeGTE applies the GTE predicate on the "auth_type" field. -func AuthTypeGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldAuthType), v)) - }) +// OsnameLTE applies the LTE predicate on the "osname" field. +func OsnameLTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldLTE(FieldOsname, v)) } -// AuthTypeLT applies the LT predicate on the "auth_type" field. -func AuthTypeLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldAuthType), v)) - }) +// OsnameContains applies the Contains predicate on the "osname" field. +func OsnameContains(v string) predicate.Machine { + return predicate.Machine(sql.FieldContains(FieldOsname, v)) } -// AuthTypeLTE applies the LTE predicate on the "auth_type" field. -func AuthTypeLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldAuthType), v)) - }) +// OsnameHasPrefix applies the HasPrefix predicate on the "osname" field. +func OsnameHasPrefix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasPrefix(FieldOsname, v)) } -// AuthTypeContains applies the Contains predicate on the "auth_type" field. -func AuthTypeContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldAuthType), v)) - }) +// OsnameHasSuffix applies the HasSuffix predicate on the "osname" field. +func OsnameHasSuffix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasSuffix(FieldOsname, v)) } -// AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field. -func AuthTypeHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldAuthType), v)) - }) +// OsnameIsNil applies the IsNil predicate on the "osname" field. +func OsnameIsNil() predicate.Machine { + return predicate.Machine(sql.FieldIsNull(FieldOsname)) } -// AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field. -func AuthTypeHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldAuthType), v)) - }) +// OsnameNotNil applies the NotNil predicate on the "osname" field. +func OsnameNotNil() predicate.Machine { + return predicate.Machine(sql.FieldNotNull(FieldOsname)) } -// AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field. -func AuthTypeEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldAuthType), v)) - }) +// OsnameEqualFold applies the EqualFold predicate on the "osname" field. +func OsnameEqualFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldEqualFold(FieldOsname, v)) } -// AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field. -func AuthTypeContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldAuthType), v)) - }) +// OsnameContainsFold applies the ContainsFold predicate on the "osname" field. 
+func OsnameContainsFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldContainsFold(FieldOsname, v)) +} + +// OsversionEQ applies the EQ predicate on the "osversion" field. +func OsversionEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldOsversion, v)) +} + +// OsversionNEQ applies the NEQ predicate on the "osversion" field. +func OsversionNEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldNEQ(FieldOsversion, v)) +} + +// OsversionIn applies the In predicate on the "osversion" field. +func OsversionIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldIn(FieldOsversion, vs...)) +} + +// OsversionNotIn applies the NotIn predicate on the "osversion" field. +func OsversionNotIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldNotIn(FieldOsversion, vs...)) +} + +// OsversionGT applies the GT predicate on the "osversion" field. +func OsversionGT(v string) predicate.Machine { + return predicate.Machine(sql.FieldGT(FieldOsversion, v)) +} + +// OsversionGTE applies the GTE predicate on the "osversion" field. +func OsversionGTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldGTE(FieldOsversion, v)) +} + +// OsversionLT applies the LT predicate on the "osversion" field. +func OsversionLT(v string) predicate.Machine { + return predicate.Machine(sql.FieldLT(FieldOsversion, v)) +} + +// OsversionLTE applies the LTE predicate on the "osversion" field. +func OsversionLTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldLTE(FieldOsversion, v)) +} + +// OsversionContains applies the Contains predicate on the "osversion" field. +func OsversionContains(v string) predicate.Machine { + return predicate.Machine(sql.FieldContains(FieldOsversion, v)) +} + +// OsversionHasPrefix applies the HasPrefix predicate on the "osversion" field. +func OsversionHasPrefix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasPrefix(FieldOsversion, v)) +} + +// OsversionHasSuffix applies the HasSuffix predicate on the "osversion" field. +func OsversionHasSuffix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasSuffix(FieldOsversion, v)) +} + +// OsversionIsNil applies the IsNil predicate on the "osversion" field. +func OsversionIsNil() predicate.Machine { + return predicate.Machine(sql.FieldIsNull(FieldOsversion)) +} + +// OsversionNotNil applies the NotNil predicate on the "osversion" field. +func OsversionNotNil() predicate.Machine { + return predicate.Machine(sql.FieldNotNull(FieldOsversion)) +} + +// OsversionEqualFold applies the EqualFold predicate on the "osversion" field. +func OsversionEqualFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldEqualFold(FieldOsversion, v)) +} + +// OsversionContainsFold applies the ContainsFold predicate on the "osversion" field. +func OsversionContainsFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldContainsFold(FieldOsversion, v)) +} + +// FeatureflagsEQ applies the EQ predicate on the "featureflags" field. +func FeatureflagsEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldFeatureflags, v)) +} + +// FeatureflagsNEQ applies the NEQ predicate on the "featureflags" field. +func FeatureflagsNEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldNEQ(FieldFeatureflags, v)) +} + +// FeatureflagsIn applies the In predicate on the "featureflags" field. 
+func FeatureflagsIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldIn(FieldFeatureflags, vs...)) +} + +// FeatureflagsNotIn applies the NotIn predicate on the "featureflags" field. +func FeatureflagsNotIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldNotIn(FieldFeatureflags, vs...)) +} + +// FeatureflagsGT applies the GT predicate on the "featureflags" field. +func FeatureflagsGT(v string) predicate.Machine { + return predicate.Machine(sql.FieldGT(FieldFeatureflags, v)) +} + +// FeatureflagsGTE applies the GTE predicate on the "featureflags" field. +func FeatureflagsGTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldGTE(FieldFeatureflags, v)) +} + +// FeatureflagsLT applies the LT predicate on the "featureflags" field. +func FeatureflagsLT(v string) predicate.Machine { + return predicate.Machine(sql.FieldLT(FieldFeatureflags, v)) +} + +// FeatureflagsLTE applies the LTE predicate on the "featureflags" field. +func FeatureflagsLTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldLTE(FieldFeatureflags, v)) +} + +// FeatureflagsContains applies the Contains predicate on the "featureflags" field. +func FeatureflagsContains(v string) predicate.Machine { + return predicate.Machine(sql.FieldContains(FieldFeatureflags, v)) +} + +// FeatureflagsHasPrefix applies the HasPrefix predicate on the "featureflags" field. +func FeatureflagsHasPrefix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasPrefix(FieldFeatureflags, v)) +} + +// FeatureflagsHasSuffix applies the HasSuffix predicate on the "featureflags" field. +func FeatureflagsHasSuffix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasSuffix(FieldFeatureflags, v)) +} + +// FeatureflagsIsNil applies the IsNil predicate on the "featureflags" field. +func FeatureflagsIsNil() predicate.Machine { + return predicate.Machine(sql.FieldIsNull(FieldFeatureflags)) +} + +// FeatureflagsNotNil applies the NotNil predicate on the "featureflags" field. +func FeatureflagsNotNil() predicate.Machine { + return predicate.Machine(sql.FieldNotNull(FieldFeatureflags)) +} + +// FeatureflagsEqualFold applies the EqualFold predicate on the "featureflags" field. +func FeatureflagsEqualFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldEqualFold(FieldFeatureflags, v)) +} + +// FeatureflagsContainsFold applies the ContainsFold predicate on the "featureflags" field. +func FeatureflagsContainsFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldContainsFold(FieldFeatureflags, v)) +} + +// HubstateIsNil applies the IsNil predicate on the "hubstate" field. +func HubstateIsNil() predicate.Machine { + return predicate.Machine(sql.FieldIsNull(FieldHubstate)) +} + +// HubstateNotNil applies the NotNil predicate on the "hubstate" field. +func HubstateNotNil() predicate.Machine { + return predicate.Machine(sql.FieldNotNull(FieldHubstate)) +} + +// DatasourcesIsNil applies the IsNil predicate on the "datasources" field. +func DatasourcesIsNil() predicate.Machine { + return predicate.Machine(sql.FieldIsNull(FieldDatasources)) +} + +// DatasourcesNotNil applies the NotNil predicate on the "datasources" field. +func DatasourcesNotNil() predicate.Machine { + return predicate.Machine(sql.FieldNotNull(FieldDatasources)) } // HasAlerts applies the HasEdge predicate on the "alerts" edge. 
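[note] The bulk of the rewrite above is mechanical: newer ent releases ship typed predicate helpers in entgo.io/ent/dialect/sql (sql.FieldEQ, sql.FieldIn, sql.FieldIsNull, and friends), so each generated predicate collapses from a hand-rolled *sql.Selector closure to a one-liner, and the []any conversion loops in the In/NotIn variants disappear because sql.FieldIn is generic over the value type. The hunk is not purely cosmetic, though: the Status* predicates go away with the removed "status" field, and new predicates appear for osname, osversion, featureflags, hubstate and datasources. Call sites are unaffected; a minimal sketch of how these generated predicates are consumed (the *ent.Client construction is elided and "client" is assumed to be initialized):

    // Querying machines with the regenerated predicates; the call-site
    // API is identical before and after this change, only the generated
    // bodies are shorter.
    machines, err := client.Machine.Query().
        Where(
            machine.IsValidatedEQ(true),
            machine.AuthTypeEQ("password"),
            machine.LastHeartbeatGT(time.Now().Add(-10*time.Minute)),
        ).
        All(ctx)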
@@ -1231,7 +975,6 @@ func HasAlerts() predicate.Machine { return predicate.Machine(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(AlertsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, AlertsTable, AlertsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -1241,11 +984,7 @@ func HasAlerts() predicate.Machine { // HasAlertsWith applies the HasEdge predicate on the "alerts" edge with a given conditions (other predicates). func HasAlertsWith(preds ...predicate.Alert) predicate.Machine { return predicate.Machine(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(AlertsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, AlertsTable, AlertsColumn), - ) + step := newAlertsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -1256,32 +995,15 @@ func HasAlertsWith(preds ...predicate.Alert) predicate.Machine { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Machine) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Machine(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Machine) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Machine(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Machine) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Machine(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/machine_create.go b/pkg/database/ent/machine_create.go index efe02782f6b..fba8400798c 100644 --- a/pkg/database/ent/machine_create.go +++ b/pkg/database/ent/machine_create.go @@ -12,6 +12,7 @@ import ( "entgo.io/ent/schema/field" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" ) // MachineCreate is the builder for creating a Machine entity. @@ -137,34 +138,74 @@ func (mc *MachineCreate) SetNillableIsValidated(b *bool) *MachineCreate { return mc } -// SetStatus sets the "status" field. -func (mc *MachineCreate) SetStatus(s string) *MachineCreate { - mc.mutation.SetStatus(s) +// SetAuthType sets the "auth_type" field. +func (mc *MachineCreate) SetAuthType(s string) *MachineCreate { + mc.mutation.SetAuthType(s) return mc } -// SetNillableStatus sets the "status" field if the given value is not nil. -func (mc *MachineCreate) SetNillableStatus(s *string) *MachineCreate { +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (mc *MachineCreate) SetNillableAuthType(s *string) *MachineCreate { if s != nil { - mc.SetStatus(*s) + mc.SetAuthType(*s) } return mc } -// SetAuthType sets the "auth_type" field. -func (mc *MachineCreate) SetAuthType(s string) *MachineCreate { - mc.mutation.SetAuthType(s) +// SetOsname sets the "osname" field. +func (mc *MachineCreate) SetOsname(s string) *MachineCreate { + mc.mutation.SetOsname(s) return mc } -// SetNillableAuthType sets the "auth_type" field if the given value is not nil. 
-func (mc *MachineCreate) SetNillableAuthType(s *string) *MachineCreate { +// SetNillableOsname sets the "osname" field if the given value is not nil. +func (mc *MachineCreate) SetNillableOsname(s *string) *MachineCreate { if s != nil { - mc.SetAuthType(*s) + mc.SetOsname(*s) } return mc } +// SetOsversion sets the "osversion" field. +func (mc *MachineCreate) SetOsversion(s string) *MachineCreate { + mc.mutation.SetOsversion(s) + return mc +} + +// SetNillableOsversion sets the "osversion" field if the given value is not nil. +func (mc *MachineCreate) SetNillableOsversion(s *string) *MachineCreate { + if s != nil { + mc.SetOsversion(*s) + } + return mc +} + +// SetFeatureflags sets the "featureflags" field. +func (mc *MachineCreate) SetFeatureflags(s string) *MachineCreate { + mc.mutation.SetFeatureflags(s) + return mc +} + +// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil. +func (mc *MachineCreate) SetNillableFeatureflags(s *string) *MachineCreate { + if s != nil { + mc.SetFeatureflags(*s) + } + return mc +} + +// SetHubstate sets the "hubstate" field. +func (mc *MachineCreate) SetHubstate(ms map[string][]schema.ItemState) *MachineCreate { + mc.mutation.SetHubstate(ms) + return mc +} + +// SetDatasources sets the "datasources" field. +func (mc *MachineCreate) SetDatasources(m map[string]int64) *MachineCreate { + mc.mutation.SetDatasources(m) + return mc +} + // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. func (mc *MachineCreate) AddAlertIDs(ids ...int) *MachineCreate { mc.mutation.AddAlertIDs(ids...) @@ -187,50 +228,8 @@ func (mc *MachineCreate) Mutation() *MachineMutation { // Save creates the Machine in the database. func (mc *MachineCreate) Save(ctx context.Context) (*Machine, error) { - var ( - err error - node *Machine - ) mc.defaults() - if len(mc.hooks) == 0 { - if err = mc.check(); err != nil { - return nil, err - } - node, err = mc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mc.check(); err != nil { - return nil, err - } - mc.mutation = mutation - if node, err = mc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mc.hooks) - 1; i >= 0; i-- { - if mc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Machine) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MachineMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, mc.sqlSave, mc.mutation, mc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -269,10 +268,6 @@ func (mc *MachineCreate) defaults() { v := machine.DefaultLastPush() mc.mutation.SetLastPush(v) } - if _, ok := mc.mutation.LastHeartbeat(); !ok { - v := machine.DefaultLastHeartbeat() - mc.mutation.SetLastHeartbeat(v) - } if _, ok := mc.mutation.IsValidated(); !ok { v := machine.DefaultIsValidated mc.mutation.SetIsValidated(v) @@ -285,6 +280,12 @@ func (mc *MachineCreate) defaults() { // check runs all checks and user-defined validators on the builder. 
func (mc *MachineCreate) check() error { + if _, ok := mc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Machine.created_at"`)} + } + if _, ok := mc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Machine.updated_at"`)} + } if _, ok := mc.mutation.MachineId(); !ok { return &ValidationError{Name: "machineId", err: errors.New(`ent: missing required field "Machine.machineId"`)} } @@ -309,6 +310,9 @@ func (mc *MachineCreate) check() error { } func (mc *MachineCreate) sqlSave(ctx context.Context) (*Machine, error) { + if err := mc.check(); err != nil { + return nil, err + } _node, _spec := mc.createSpec() if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -318,116 +322,80 @@ func (mc *MachineCreate) sqlSave(ctx context.Context) (*Machine, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + mc.mutation.id = &_node.ID + mc.mutation.done = true return _node, nil } func (mc *MachineCreate) createSpec() (*Machine, *sqlgraph.CreateSpec) { var ( _node = &Machine{config: mc.config} - _spec = &sqlgraph.CreateSpec{ - Table: machine.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(machine.Table, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) ) if value, ok := mc.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(machine.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := mc.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(machine.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := mc.mutation.LastPush(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastPush, - }) + _spec.SetField(machine.FieldLastPush, field.TypeTime, value) _node.LastPush = &value } if value, ok := mc.mutation.LastHeartbeat(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastHeartbeat, - }) + _spec.SetField(machine.FieldLastHeartbeat, field.TypeTime, value) _node.LastHeartbeat = &value } if value, ok := mc.mutation.MachineId(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldMachineId, - }) + _spec.SetField(machine.FieldMachineId, field.TypeString, value) _node.MachineId = value } if value, ok := mc.mutation.Password(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldPassword, - }) + _spec.SetField(machine.FieldPassword, field.TypeString, value) _node.Password = value } if value, ok := mc.mutation.IpAddress(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldIpAddress, - }) + _spec.SetField(machine.FieldIpAddress, field.TypeString, value) _node.IpAddress = value } if value, ok := mc.mutation.Scenarios(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: 
field.TypeString, - Value: value, - Column: machine.FieldScenarios, - }) + _spec.SetField(machine.FieldScenarios, field.TypeString, value) _node.Scenarios = value } if value, ok := mc.mutation.Version(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldVersion, - }) + _spec.SetField(machine.FieldVersion, field.TypeString, value) _node.Version = value } if value, ok := mc.mutation.IsValidated(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: machine.FieldIsValidated, - }) + _spec.SetField(machine.FieldIsValidated, field.TypeBool, value) _node.IsValidated = value } - if value, ok := mc.mutation.Status(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldStatus, - }) - _node.Status = value - } if value, ok := mc.mutation.AuthType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldAuthType, - }) + _spec.SetField(machine.FieldAuthType, field.TypeString, value) _node.AuthType = value } + if value, ok := mc.mutation.Osname(); ok { + _spec.SetField(machine.FieldOsname, field.TypeString, value) + _node.Osname = value + } + if value, ok := mc.mutation.Osversion(); ok { + _spec.SetField(machine.FieldOsversion, field.TypeString, value) + _node.Osversion = value + } + if value, ok := mc.mutation.Featureflags(); ok { + _spec.SetField(machine.FieldFeatureflags, field.TypeString, value) + _node.Featureflags = value + } + if value, ok := mc.mutation.Hubstate(); ok { + _spec.SetField(machine.FieldHubstate, field.TypeJSON, value) + _node.Hubstate = value + } + if value, ok := mc.mutation.Datasources(); ok { + _spec.SetField(machine.FieldDatasources, field.TypeJSON, value) + _node.Datasources = value + } if nodes := mc.mutation.AlertsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -436,10 +404,7 @@ func (mc *MachineCreate) createSpec() (*Machine, *sqlgraph.CreateSpec) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -453,11 +418,15 @@ func (mc *MachineCreate) createSpec() (*Machine, *sqlgraph.CreateSpec) { // MachineCreateBulk is the builder for creating many Machine entities in bulk. type MachineCreateBulk struct { config + err error builders []*MachineCreate } // Save creates the Machine entities in the database. 
func (mcb *MachineCreateBulk) Save(ctx context.Context) ([]*Machine, error) { + if mcb.err != nil { + return nil, mcb.err + } specs := make([]*sqlgraph.CreateSpec, len(mcb.builders)) nodes := make([]*Machine, len(mcb.builders)) mutators := make([]Mutator, len(mcb.builders)) @@ -474,8 +443,8 @@ func (mcb *MachineCreateBulk) Save(ctx context.Context) ([]*Machine, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, mcb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/machine_delete.go b/pkg/database/ent/machine_delete.go index bead8acb46d..ac3aa751d5e 100644 --- a/pkg/database/ent/machine_delete.go +++ b/pkg/database/ent/machine_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (md *MachineDelete) Where(ps ...predicate.Machine) *MachineDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (md *MachineDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(md.hooks) == 0 { - affected, err = md.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - md.mutation = mutation - affected, err = md.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(md.hooks) - 1; i >= 0; i-- { - if md.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = md.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, md.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, md.sqlExec, md.mutation, md.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (md *MachineDelete) ExecX(ctx context.Context) int { } func (md *MachineDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: machine.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(machine.Table, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) if ps := md.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (md *MachineDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + md.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MachineDeleteOne struct { md *MachineDelete } +// Where appends a list predicates to the MachineDelete builder. +func (mdo *MachineDeleteOne) Where(ps ...predicate.Machine) *MachineDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + // Exec executes the deletion query. func (mdo *MachineDeleteOne) Exec(ctx context.Context) error { n, err := mdo.md.Exec(ctx) @@ -111,5 +82,7 @@ func (mdo *MachineDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. 
func (mdo *MachineDeleteOne) ExecX(ctx context.Context) { - mdo.md.ExecX(ctx) + if err := mdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/machine_query.go b/pkg/database/ent/machine_query.go index 2839142196b..462c2cf35b1 100644 --- a/pkg/database/ent/machine_query.go +++ b/pkg/database/ent/machine_query.go @@ -19,11 +19,9 @@ import ( // MachineQuery is the builder for querying Machine entities. type MachineQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []machine.OrderOption + inters []Interceptor predicates []predicate.Machine withAlerts *AlertQuery // intermediate query (i.e. traversal path). @@ -37,34 +35,34 @@ func (mq *MachineQuery) Where(ps ...predicate.Machine) *MachineQuery { return mq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mq *MachineQuery) Limit(limit int) *MachineQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } -// Offset adds an offset step to the query. +// Offset to start from. func (mq *MachineQuery) Offset(offset int) *MachineQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MachineQuery) Unique(unique bool) *MachineQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } -// Order adds an order step to the query. -func (mq *MachineQuery) Order(o ...OrderFunc) *MachineQuery { +// Order specifies how the records should be ordered. +func (mq *MachineQuery) Order(o ...machine.OrderOption) *MachineQuery { mq.order = append(mq.order, o...) return mq } // QueryAlerts chains the current query on the "alerts" edge. func (mq *MachineQuery) QueryAlerts() *AlertQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -87,7 +85,7 @@ func (mq *MachineQuery) QueryAlerts() *AlertQuery { // First returns the first Machine entity from the query. // Returns a *NotFoundError when no Machine was found. func (mq *MachineQuery) First(ctx context.Context) (*Machine, error) { - nodes, err := mq.Limit(1).All(ctx) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -110,7 +108,7 @@ func (mq *MachineQuery) FirstX(ctx context.Context) *Machine { // Returns a *NotFoundError when no Machine ID was found. func (mq *MachineQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(1).IDs(ctx); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -133,7 +131,7 @@ func (mq *MachineQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Machine entity is found. // Returns a *NotFoundError when no Machine entities are found. func (mq *MachineQuery) Only(ctx context.Context) (*Machine, error) { - nodes, err := mq.Limit(2).All(ctx) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -161,7 +159,7 @@ func (mq *MachineQuery) OnlyX(ctx context.Context) *Machine { // Returns a *NotFoundError when no entities are found. 
func (mq *MachineQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(2).IDs(ctx); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -186,10 +184,12 @@ func (mq *MachineQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Machines. func (mq *MachineQuery) All(ctx context.Context) ([]*Machine, error) { + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } - return mq.sqlAll(ctx) + qr := querierAll[[]*Machine, *MachineQuery]() + return withInterceptors[[]*Machine](ctx, mq, qr, mq.inters) } // AllX is like All, but panics if an error occurs. @@ -202,9 +202,12 @@ func (mq *MachineQuery) AllX(ctx context.Context) []*Machine { } // IDs executes the query and returns a list of Machine IDs. -func (mq *MachineQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mq.Select(machine.FieldID).Scan(ctx, &ids); err != nil { +func (mq *MachineQuery) IDs(ctx context.Context) (ids []int, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(machine.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -221,10 +224,11 @@ func (mq *MachineQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mq *MachineQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } - return mq.sqlCount(ctx) + return withInterceptors[int](ctx, mq, querierCount[*MachineQuery](), mq.inters) } // CountX is like Count, but panics if an error occurs. @@ -238,10 +242,15 @@ func (mq *MachineQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MachineQuery) Exist(ctx context.Context) (bool, error) { - if err := mq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return mq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -261,22 +270,21 @@ func (mq *MachineQuery) Clone() *MachineQuery { } return &MachineQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, - order: append([]OrderFunc{}, mq.order...), + ctx: mq.ctx.Clone(), + order: append([]machine.OrderOption{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Machine{}, mq.predicates...), withAlerts: mq.withAlerts.Clone(), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } // WithAlerts tells the query-builder to eager-load the nodes that are connected to // the "alerts" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MachineQuery) WithAlerts(opts ...func(*AlertQuery)) *MachineQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -299,16 +307,11 @@ func (mq *MachineQuery) WithAlerts(opts ...func(*AlertQuery)) *MachineQuery { // Aggregate(ent.Count()). 
// Scan(ctx, &v) func (mq *MachineQuery) GroupBy(field string, fields ...string) *MachineGroupBy { - grbuild := &MachineGroupBy{config: mq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mq.prepareQuery(ctx); err != nil { - return nil, err - } - return mq.sqlQuery(ctx), nil - } + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MachineGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields grbuild.label = machine.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -325,15 +328,30 @@ func (mq *MachineQuery) GroupBy(field string, fields ...string) *MachineGroupBy // Select(machine.FieldCreatedAt). // Scan(ctx, &v) func (mq *MachineQuery) Select(fields ...string) *MachineSelect { - mq.fields = append(mq.fields, fields...) - selbuild := &MachineSelect{MachineQuery: mq} - selbuild.label = machine.Label - selbuild.flds, selbuild.scan = &mq.fields, selbuild.Scan - return selbuild + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MachineSelect{MachineQuery: mq} + sbuild.label = machine.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MachineSelect configured with the given aggregations. +func (mq *MachineQuery) Aggregate(fns ...AggregateFunc) *MachineSelect { + return mq.Select().Aggregate(fns...) } func (mq *MachineQuery) prepareQuery(ctx context.Context) error { - for _, f := range mq.fields { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { if !machine.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -396,7 +414,7 @@ func (mq *MachineQuery) loadAlerts(ctx context.Context, query *AlertQuery, nodes } query.withFKs = true query.Where(predicate.Alert(func(s *sql.Selector) { - s.Where(sql.InValues(machine.AlertsColumn, fks...)) + s.Where(sql.InValues(s.C(machine.AlertsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -409,7 +427,7 @@ func (mq *MachineQuery) loadAlerts(ctx context.Context, query *AlertQuery, nodes } node, ok := nodeids[*fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "machine_alerts" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "machine_alerts" returned %v for node %v`, *fk, n.ID) } assign(node, n) } @@ -418,41 +436,22 @@ func (mq *MachineQuery) loadAlerts(ctx context.Context, query *AlertQuery, nodes func (mq *MachineQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } -func (mq *MachineQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := mq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (mq *MachineQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - 
Table: machine.Table, - Columns: machine.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, - From: mq.sql, - Unique: true, - } - if unique := mq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(machine.Table, machine.Columns, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true } - if fields := mq.fields; len(fields) > 0 { + if fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, machine.FieldID) for i := range fields { @@ -468,10 +467,10 @@ func (mq *MachineQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -487,7 +486,7 @@ func (mq *MachineQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(machine.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = machine.Columns } @@ -496,7 +495,7 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) } - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -505,12 +504,12 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -518,13 +517,8 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector { // MachineGroupBy is the group-by builder for Machine entities. type MachineGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MachineQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -533,74 +527,77 @@ func (mgb *MachineGroupBy) Aggregate(fns ...AggregateFunc) *MachineGroupBy { return mgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. 
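GroupBy and Select now share the query's ctx.Fields, the new MachineQuery.Aggregate is shorthand for Select().Aggregate(), and prepareQuery additionally walks mq.inters and runs any Traverser before field validation. A hedged sketch following the doc-comment pattern already present in this file, assuming the ctx and client from the sketch above (Int is one of the generated scan helpers on MachineSelect, not shown in this hunk):

    // Count machines per version; the column/struct-tag pairing mirrors the
    // GroupBy example in the generated doc comments.
    var rows []struct {
    	Version string `json:"version"`
    	Count   int    `json:"count"`
    }
    if err := client.Machine.Query().
    	GroupBy(machine.FieldVersion).
    	Aggregate(ent.Count()).
    	Scan(ctx, &rows); err != nil {
    	log.Fatal(err)
    }

    // Aggregate without GroupBy delegates to Select().Aggregate(...).
    total, err := client.Machine.Query().Aggregate(ent.Count()).Int(ctx)
    if err != nil {
    	log.Fatal(err)
    }
    log.Printf("total machines: %d", total)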
func (mgb *MachineGroupBy) Scan(ctx context.Context, v any) error { - query, err := mgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { return err } - mgb.sql = query - return mgb.sqlScan(ctx, v) + return scanWithInterceptors[*MachineQuery, *MachineGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) } -func (mgb *MachineGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range mgb.fields { - if !machine.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mgb *MachineGroupBy) sqlScan(ctx context.Context, root *MachineQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mgb.sqlQuery() + selector.GroupBy(selector.Columns(*mgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mgb *MachineGroupBy) sqlQuery() *sql.Selector { - selector := mgb.sql.Select() - aggregation := make([]string, 0, len(mgb.fns)) - for _, fn := range mgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) - for _, f := range mgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mgb.fields...)...) -} - // MachineSelect is the builder for selecting fields of Machine entities. type MachineSelect struct { *MachineQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MachineSelect) Aggregate(fns ...AggregateFunc) *MachineSelect { + ms.fns = append(ms.fns, fns...) + return ms } // Scan applies the selector query and scans the result into the given value. func (ms *MachineSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } - ms.sql = ms.MachineQuery.sqlQuery(ctx) - return ms.sqlScan(ctx, v) + return scanWithInterceptors[*MachineQuery, *MachineSelect](ctx, ms.MachineQuery, ms, ms.inters, v) } -func (ms *MachineSelect) sqlScan(ctx context.Context, v any) error { +func (ms *MachineSelect) sqlScan(ctx context.Context, root *MachineQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) 
+ case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ms.sql.Query() + query, args := selector.Query() if err := ms.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/machine_update.go b/pkg/database/ent/machine_update.go index de9f8d12460..531baabf0d6 100644 --- a/pkg/database/ent/machine_update.go +++ b/pkg/database/ent/machine_update.go @@ -14,6 +14,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" ) // MachineUpdate is the builder for updating Machine entities. @@ -29,36 +30,26 @@ func (mu *MachineUpdate) Where(ps ...predicate.Machine) *MachineUpdate { return mu } -// SetCreatedAt sets the "created_at" field. -func (mu *MachineUpdate) SetCreatedAt(t time.Time) *MachineUpdate { - mu.mutation.SetCreatedAt(t) - return mu -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (mu *MachineUpdate) ClearCreatedAt() *MachineUpdate { - mu.mutation.ClearCreatedAt() - return mu -} - // SetUpdatedAt sets the "updated_at" field. func (mu *MachineUpdate) SetUpdatedAt(t time.Time) *MachineUpdate { mu.mutation.SetUpdatedAt(t) return mu } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (mu *MachineUpdate) ClearUpdatedAt() *MachineUpdate { - mu.mutation.ClearUpdatedAt() - return mu -} - // SetLastPush sets the "last_push" field. func (mu *MachineUpdate) SetLastPush(t time.Time) *MachineUpdate { mu.mutation.SetLastPush(t) return mu } +// SetNillableLastPush sets the "last_push" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableLastPush(t *time.Time) *MachineUpdate { + if t != nil { + mu.SetLastPush(*t) + } + return mu +} + // ClearLastPush clears the value of the "last_push" field. func (mu *MachineUpdate) ClearLastPush() *MachineUpdate { mu.mutation.ClearLastPush() @@ -71,15 +62,17 @@ func (mu *MachineUpdate) SetLastHeartbeat(t time.Time) *MachineUpdate { return mu } -// ClearLastHeartbeat clears the value of the "last_heartbeat" field. -func (mu *MachineUpdate) ClearLastHeartbeat() *MachineUpdate { - mu.mutation.ClearLastHeartbeat() +// SetNillableLastHeartbeat sets the "last_heartbeat" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableLastHeartbeat(t *time.Time) *MachineUpdate { + if t != nil { + mu.SetLastHeartbeat(*t) + } return mu } -// SetMachineId sets the "machineId" field. -func (mu *MachineUpdate) SetMachineId(s string) *MachineUpdate { - mu.mutation.SetMachineId(s) +// ClearLastHeartbeat clears the value of the "last_heartbeat" field. +func (mu *MachineUpdate) ClearLastHeartbeat() *MachineUpdate { + mu.mutation.ClearLastHeartbeat() return mu } @@ -89,12 +82,28 @@ func (mu *MachineUpdate) SetPassword(s string) *MachineUpdate { return mu } +// SetNillablePassword sets the "password" field if the given value is not nil. +func (mu *MachineUpdate) SetNillablePassword(s *string) *MachineUpdate { + if s != nil { + mu.SetPassword(*s) + } + return mu +} + // SetIpAddress sets the "ipAddress" field. func (mu *MachineUpdate) SetIpAddress(s string) *MachineUpdate { mu.mutation.SetIpAddress(s) return mu } +// SetNillableIpAddress sets the "ipAddress" field if the given value is not nil. 
+func (mu *MachineUpdate) SetNillableIpAddress(s *string) *MachineUpdate { + if s != nil { + mu.SetIpAddress(*s) + } + return mu +} + // SetScenarios sets the "scenarios" field. func (mu *MachineUpdate) SetScenarios(s string) *MachineUpdate { mu.mutation.SetScenarios(s) @@ -149,40 +158,104 @@ func (mu *MachineUpdate) SetNillableIsValidated(b *bool) *MachineUpdate { return mu } -// SetStatus sets the "status" field. -func (mu *MachineUpdate) SetStatus(s string) *MachineUpdate { - mu.mutation.SetStatus(s) +// SetAuthType sets the "auth_type" field. +func (mu *MachineUpdate) SetAuthType(s string) *MachineUpdate { + mu.mutation.SetAuthType(s) return mu } -// SetNillableStatus sets the "status" field if the given value is not nil. -func (mu *MachineUpdate) SetNillableStatus(s *string) *MachineUpdate { +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableAuthType(s *string) *MachineUpdate { if s != nil { - mu.SetStatus(*s) + mu.SetAuthType(*s) } return mu } -// ClearStatus clears the value of the "status" field. -func (mu *MachineUpdate) ClearStatus() *MachineUpdate { - mu.mutation.ClearStatus() +// SetOsname sets the "osname" field. +func (mu *MachineUpdate) SetOsname(s string) *MachineUpdate { + mu.mutation.SetOsname(s) return mu } -// SetAuthType sets the "auth_type" field. -func (mu *MachineUpdate) SetAuthType(s string) *MachineUpdate { - mu.mutation.SetAuthType(s) +// SetNillableOsname sets the "osname" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableOsname(s *string) *MachineUpdate { + if s != nil { + mu.SetOsname(*s) + } return mu } -// SetNillableAuthType sets the "auth_type" field if the given value is not nil. -func (mu *MachineUpdate) SetNillableAuthType(s *string) *MachineUpdate { +// ClearOsname clears the value of the "osname" field. +func (mu *MachineUpdate) ClearOsname() *MachineUpdate { + mu.mutation.ClearOsname() + return mu +} + +// SetOsversion sets the "osversion" field. +func (mu *MachineUpdate) SetOsversion(s string) *MachineUpdate { + mu.mutation.SetOsversion(s) + return mu +} + +// SetNillableOsversion sets the "osversion" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableOsversion(s *string) *MachineUpdate { if s != nil { - mu.SetAuthType(*s) + mu.SetOsversion(*s) + } + return mu +} + +// ClearOsversion clears the value of the "osversion" field. +func (mu *MachineUpdate) ClearOsversion() *MachineUpdate { + mu.mutation.ClearOsversion() + return mu +} + +// SetFeatureflags sets the "featureflags" field. +func (mu *MachineUpdate) SetFeatureflags(s string) *MachineUpdate { + mu.mutation.SetFeatureflags(s) + return mu +} + +// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableFeatureflags(s *string) *MachineUpdate { + if s != nil { + mu.SetFeatureflags(*s) } return mu } +// ClearFeatureflags clears the value of the "featureflags" field. +func (mu *MachineUpdate) ClearFeatureflags() *MachineUpdate { + mu.mutation.ClearFeatureflags() + return mu +} + +// SetHubstate sets the "hubstate" field. +func (mu *MachineUpdate) SetHubstate(ms map[string][]schema.ItemState) *MachineUpdate { + mu.mutation.SetHubstate(ms) + return mu +} + +// ClearHubstate clears the value of the "hubstate" field. +func (mu *MachineUpdate) ClearHubstate() *MachineUpdate { + mu.mutation.ClearHubstate() + return mu +} + +// SetDatasources sets the "datasources" field. 
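The update builder gains SetNillable* variants (password, ipAddress, and the timestamp fields above), so optional values can be forwarded without caller-side nil checks. A sketch under that assumption; the patch struct and helper are hypothetical:

    // imports assumed: context, github.com/crowdsecurity/crowdsec/pkg/database/ent

    // machinePatch is a hypothetical API payload with optional fields.
    type machinePatch struct {
    	Password  *string
    	IpAddress *string
    }

    // applyPatch forwards the optional fields directly; the SetNillable*
    // setters are no-ops when the pointer is nil.
    func applyPatch(ctx context.Context, client *ent.Client, p machinePatch) (int, error) {
    	return client.Machine.Update().
    		SetNillablePassword(p.Password).
    		SetNillableIpAddress(p.IpAddress).
    		Save(ctx)
    }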
+func (mu *MachineUpdate) SetDatasources(m map[string]int64) *MachineUpdate { + mu.mutation.SetDatasources(m) + return mu +} + +// ClearDatasources clears the value of the "datasources" field. +func (mu *MachineUpdate) ClearDatasources() *MachineUpdate { + mu.mutation.ClearDatasources() + return mu +} + // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. func (mu *MachineUpdate) AddAlertIDs(ids ...int) *MachineUpdate { mu.mutation.AddAlertIDs(ids...) @@ -226,41 +299,8 @@ func (mu *MachineUpdate) RemoveAlerts(a ...*Alert) *MachineUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (mu *MachineUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) mu.defaults() - if len(mu.hooks) == 0 { - if err = mu.check(); err != nil { - return 0, err - } - affected, err = mu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mu.check(); err != nil { - return 0, err - } - mu.mutation = mutation - affected, err = mu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mu.hooks) - 1; i >= 0; i-- { - if mu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, mu.sqlSave, mu.mutation, mu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -287,22 +327,10 @@ func (mu *MachineUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (mu *MachineUpdate) defaults() { - if _, ok := mu.mutation.CreatedAt(); !ok && !mu.mutation.CreatedAtCleared() { - v := machine.UpdateDefaultCreatedAt() - mu.mutation.SetCreatedAt(v) - } - if _, ok := mu.mutation.UpdatedAt(); !ok && !mu.mutation.UpdatedAtCleared() { + if _, ok := mu.mutation.UpdatedAt(); !ok { v := machine.UpdateDefaultUpdatedAt() mu.mutation.SetUpdatedAt(v) } - if _, ok := mu.mutation.LastPush(); !ok && !mu.mutation.LastPushCleared() { - v := machine.UpdateDefaultLastPush() - mu.mutation.SetLastPush(v) - } - if _, ok := mu.mutation.LastHeartbeat(); !ok && !mu.mutation.LastHeartbeatCleared() { - v := machine.UpdateDefaultLastHeartbeat() - mu.mutation.SetLastHeartbeat(v) - } } // check runs all checks and user-defined validators on the builder. 
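Alongside dropping the machineId and status setters, the builder picks up the new osname/osversion/featureflags string columns and two JSON-backed ones: hubstate (map[string][]schema.ItemState, via the schema import added at the top of this file) and datasources (map[string]int64). A sketch; the predicate name, map contents and update target are assumptions for illustration:

    // Assumes ctx and client from the first sketch. MachineIdEQ is assumed by
    // analogy with the meta predicates later in this diff.
    n, err := client.Machine.Update().
    	Where(machine.MachineIdEQ("example-machine")).
    	SetHubstate(map[string][]schema.ItemState{}). // stored as a JSON column
    	SetDatasources(map[string]int64{"file": 1}).  // illustrative contents
    	Save(ctx)
    if err != nil {
    	log.Fatal(err)
    }
    log.Printf("updated %d machines", n)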
@@ -316,16 +344,10 @@ func (mu *MachineUpdate) check() error { } func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: machine.Table, - Columns: machine.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, + if err := mu.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(machine.Table, machine.Columns, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) if ps := mu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -333,131 +355,74 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := mu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldCreatedAt, - }) - } - if mu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldCreatedAt, - }) - } if value, ok := mu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldUpdatedAt, - }) - } - if mu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldUpdatedAt, - }) + _spec.SetField(machine.FieldUpdatedAt, field.TypeTime, value) } if value, ok := mu.mutation.LastPush(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastPush, - }) + _spec.SetField(machine.FieldLastPush, field.TypeTime, value) } if mu.mutation.LastPushCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastPush, - }) + _spec.ClearField(machine.FieldLastPush, field.TypeTime) } if value, ok := mu.mutation.LastHeartbeat(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastHeartbeat, - }) + _spec.SetField(machine.FieldLastHeartbeat, field.TypeTime, value) } if mu.mutation.LastHeartbeatCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastHeartbeat, - }) - } - if value, ok := mu.mutation.MachineId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldMachineId, - }) + _spec.ClearField(machine.FieldLastHeartbeat, field.TypeTime) } if value, ok := mu.mutation.Password(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldPassword, - }) + _spec.SetField(machine.FieldPassword, field.TypeString, value) } if value, ok := mu.mutation.IpAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldIpAddress, - }) + _spec.SetField(machine.FieldIpAddress, field.TypeString, value) } if value, ok := mu.mutation.Scenarios(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldScenarios, - }) + _spec.SetField(machine.FieldScenarios, field.TypeString, value) } if 
mu.mutation.ScenariosCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldScenarios, - }) + _spec.ClearField(machine.FieldScenarios, field.TypeString) } if value, ok := mu.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldVersion, - }) + _spec.SetField(machine.FieldVersion, field.TypeString, value) } if mu.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldVersion, - }) + _spec.ClearField(machine.FieldVersion, field.TypeString) } if value, ok := mu.mutation.IsValidated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: machine.FieldIsValidated, - }) - } - if value, ok := mu.mutation.Status(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldStatus, - }) - } - if mu.mutation.StatusCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldStatus, - }) + _spec.SetField(machine.FieldIsValidated, field.TypeBool, value) } if value, ok := mu.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldAuthType, - }) + _spec.SetField(machine.FieldAuthType, field.TypeString, value) + } + if value, ok := mu.mutation.Osname(); ok { + _spec.SetField(machine.FieldOsname, field.TypeString, value) + } + if mu.mutation.OsnameCleared() { + _spec.ClearField(machine.FieldOsname, field.TypeString) + } + if value, ok := mu.mutation.Osversion(); ok { + _spec.SetField(machine.FieldOsversion, field.TypeString, value) + } + if mu.mutation.OsversionCleared() { + _spec.ClearField(machine.FieldOsversion, field.TypeString) + } + if value, ok := mu.mutation.Featureflags(); ok { + _spec.SetField(machine.FieldFeatureflags, field.TypeString, value) + } + if mu.mutation.FeatureflagsCleared() { + _spec.ClearField(machine.FieldFeatureflags, field.TypeString) + } + if value, ok := mu.mutation.Hubstate(); ok { + _spec.SetField(machine.FieldHubstate, field.TypeJSON, value) + } + if mu.mutation.HubstateCleared() { + _spec.ClearField(machine.FieldHubstate, field.TypeJSON) + } + if value, ok := mu.mutation.Datasources(); ok { + _spec.SetField(machine.FieldDatasources, field.TypeJSON, value) + } + if mu.mutation.DatasourcesCleared() { + _spec.ClearField(machine.FieldDatasources, field.TypeJSON) } if mu.mutation.AlertsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -467,10 +432,7 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -483,10 +445,7 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -502,10 +461,7 @@ func (mu 
*MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -521,6 +477,7 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mu.mutation.done = true return n, nil } @@ -532,36 +489,26 @@ type MachineUpdateOne struct { mutation *MachineMutation } -// SetCreatedAt sets the "created_at" field. -func (muo *MachineUpdateOne) SetCreatedAt(t time.Time) *MachineUpdateOne { - muo.mutation.SetCreatedAt(t) - return muo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (muo *MachineUpdateOne) ClearCreatedAt() *MachineUpdateOne { - muo.mutation.ClearCreatedAt() - return muo -} - // SetUpdatedAt sets the "updated_at" field. func (muo *MachineUpdateOne) SetUpdatedAt(t time.Time) *MachineUpdateOne { muo.mutation.SetUpdatedAt(t) return muo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (muo *MachineUpdateOne) ClearUpdatedAt() *MachineUpdateOne { - muo.mutation.ClearUpdatedAt() - return muo -} - // SetLastPush sets the "last_push" field. func (muo *MachineUpdateOne) SetLastPush(t time.Time) *MachineUpdateOne { muo.mutation.SetLastPush(t) return muo } +// SetNillableLastPush sets the "last_push" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableLastPush(t *time.Time) *MachineUpdateOne { + if t != nil { + muo.SetLastPush(*t) + } + return muo +} + // ClearLastPush clears the value of the "last_push" field. func (muo *MachineUpdateOne) ClearLastPush() *MachineUpdateOne { muo.mutation.ClearLastPush() @@ -574,15 +521,17 @@ func (muo *MachineUpdateOne) SetLastHeartbeat(t time.Time) *MachineUpdateOne { return muo } -// ClearLastHeartbeat clears the value of the "last_heartbeat" field. -func (muo *MachineUpdateOne) ClearLastHeartbeat() *MachineUpdateOne { - muo.mutation.ClearLastHeartbeat() +// SetNillableLastHeartbeat sets the "last_heartbeat" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableLastHeartbeat(t *time.Time) *MachineUpdateOne { + if t != nil { + muo.SetLastHeartbeat(*t) + } return muo } -// SetMachineId sets the "machineId" field. -func (muo *MachineUpdateOne) SetMachineId(s string) *MachineUpdateOne { - muo.mutation.SetMachineId(s) +// ClearLastHeartbeat clears the value of the "last_heartbeat" field. +func (muo *MachineUpdateOne) ClearLastHeartbeat() *MachineUpdateOne { + muo.mutation.ClearLastHeartbeat() return muo } @@ -592,12 +541,28 @@ func (muo *MachineUpdateOne) SetPassword(s string) *MachineUpdateOne { return muo } +// SetNillablePassword sets the "password" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillablePassword(s *string) *MachineUpdateOne { + if s != nil { + muo.SetPassword(*s) + } + return muo +} + // SetIpAddress sets the "ipAddress" field. func (muo *MachineUpdateOne) SetIpAddress(s string) *MachineUpdateOne { muo.mutation.SetIpAddress(s) return muo } +// SetNillableIpAddress sets the "ipAddress" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableIpAddress(s *string) *MachineUpdateOne { + if s != nil { + muo.SetIpAddress(*s) + } + return muo +} + // SetScenarios sets the "scenarios" field. 
func (muo *MachineUpdateOne) SetScenarios(s string) *MachineUpdateOne { muo.mutation.SetScenarios(s) @@ -652,40 +617,104 @@ func (muo *MachineUpdateOne) SetNillableIsValidated(b *bool) *MachineUpdateOne { return muo } -// SetStatus sets the "status" field. -func (muo *MachineUpdateOne) SetStatus(s string) *MachineUpdateOne { - muo.mutation.SetStatus(s) +// SetAuthType sets the "auth_type" field. +func (muo *MachineUpdateOne) SetAuthType(s string) *MachineUpdateOne { + muo.mutation.SetAuthType(s) return muo } -// SetNillableStatus sets the "status" field if the given value is not nil. -func (muo *MachineUpdateOne) SetNillableStatus(s *string) *MachineUpdateOne { +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableAuthType(s *string) *MachineUpdateOne { if s != nil { - muo.SetStatus(*s) + muo.SetAuthType(*s) } return muo } -// ClearStatus clears the value of the "status" field. -func (muo *MachineUpdateOne) ClearStatus() *MachineUpdateOne { - muo.mutation.ClearStatus() +// SetOsname sets the "osname" field. +func (muo *MachineUpdateOne) SetOsname(s string) *MachineUpdateOne { + muo.mutation.SetOsname(s) return muo } -// SetAuthType sets the "auth_type" field. -func (muo *MachineUpdateOne) SetAuthType(s string) *MachineUpdateOne { - muo.mutation.SetAuthType(s) +// SetNillableOsname sets the "osname" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableOsname(s *string) *MachineUpdateOne { + if s != nil { + muo.SetOsname(*s) + } return muo } -// SetNillableAuthType sets the "auth_type" field if the given value is not nil. -func (muo *MachineUpdateOne) SetNillableAuthType(s *string) *MachineUpdateOne { +// ClearOsname clears the value of the "osname" field. +func (muo *MachineUpdateOne) ClearOsname() *MachineUpdateOne { + muo.mutation.ClearOsname() + return muo +} + +// SetOsversion sets the "osversion" field. +func (muo *MachineUpdateOne) SetOsversion(s string) *MachineUpdateOne { + muo.mutation.SetOsversion(s) + return muo +} + +// SetNillableOsversion sets the "osversion" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableOsversion(s *string) *MachineUpdateOne { if s != nil { - muo.SetAuthType(*s) + muo.SetOsversion(*s) } return muo } +// ClearOsversion clears the value of the "osversion" field. +func (muo *MachineUpdateOne) ClearOsversion() *MachineUpdateOne { + muo.mutation.ClearOsversion() + return muo +} + +// SetFeatureflags sets the "featureflags" field. +func (muo *MachineUpdateOne) SetFeatureflags(s string) *MachineUpdateOne { + muo.mutation.SetFeatureflags(s) + return muo +} + +// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableFeatureflags(s *string) *MachineUpdateOne { + if s != nil { + muo.SetFeatureflags(*s) + } + return muo +} + +// ClearFeatureflags clears the value of the "featureflags" field. +func (muo *MachineUpdateOne) ClearFeatureflags() *MachineUpdateOne { + muo.mutation.ClearFeatureflags() + return muo +} + +// SetHubstate sets the "hubstate" field. +func (muo *MachineUpdateOne) SetHubstate(ms map[string][]schema.ItemState) *MachineUpdateOne { + muo.mutation.SetHubstate(ms) + return muo +} + +// ClearHubstate clears the value of the "hubstate" field. +func (muo *MachineUpdateOne) ClearHubstate() *MachineUpdateOne { + muo.mutation.ClearHubstate() + return muo +} + +// SetDatasources sets the "datasources" field. 
+func (muo *MachineUpdateOne) SetDatasources(m map[string]int64) *MachineUpdateOne { + muo.mutation.SetDatasources(m) + return muo +} + +// ClearDatasources clears the value of the "datasources" field. +func (muo *MachineUpdateOne) ClearDatasources() *MachineUpdateOne { + muo.mutation.ClearDatasources() + return muo +} + // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. func (muo *MachineUpdateOne) AddAlertIDs(ids ...int) *MachineUpdateOne { muo.mutation.AddAlertIDs(ids...) @@ -727,6 +756,12 @@ func (muo *MachineUpdateOne) RemoveAlerts(a ...*Alert) *MachineUpdateOne { return muo.RemoveAlertIDs(ids...) } +// Where appends a list predicates to the MachineUpdate builder. +func (muo *MachineUpdateOne) Where(ps ...predicate.Machine) *MachineUpdateOne { + muo.mutation.Where(ps...) + return muo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (muo *MachineUpdateOne) Select(field string, fields ...string) *MachineUpdateOne { @@ -736,47 +771,8 @@ func (muo *MachineUpdateOne) Select(field string, fields ...string) *MachineUpda // Save executes the query and returns the updated Machine entity. func (muo *MachineUpdateOne) Save(ctx context.Context) (*Machine, error) { - var ( - err error - node *Machine - ) muo.defaults() - if len(muo.hooks) == 0 { - if err = muo.check(); err != nil { - return nil, err - } - node, err = muo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = muo.check(); err != nil { - return nil, err - } - muo.mutation = mutation - node, err = muo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(muo.hooks) - 1; i >= 0; i-- { - if muo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = muo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, muo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Machine) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MachineMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, muo.sqlSave, muo.mutation, muo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -803,22 +799,10 @@ func (muo *MachineUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (muo *MachineUpdateOne) defaults() { - if _, ok := muo.mutation.CreatedAt(); !ok && !muo.mutation.CreatedAtCleared() { - v := machine.UpdateDefaultCreatedAt() - muo.mutation.SetCreatedAt(v) - } - if _, ok := muo.mutation.UpdatedAt(); !ok && !muo.mutation.UpdatedAtCleared() { + if _, ok := muo.mutation.UpdatedAt(); !ok { v := machine.UpdateDefaultUpdatedAt() muo.mutation.SetUpdatedAt(v) } - if _, ok := muo.mutation.LastPush(); !ok && !muo.mutation.LastPushCleared() { - v := machine.UpdateDefaultLastPush() - muo.mutation.SetLastPush(v) - } - if _, ok := muo.mutation.LastHeartbeat(); !ok && !muo.mutation.LastHeartbeatCleared() { - v := machine.UpdateDefaultLastHeartbeat() - muo.mutation.SetLastHeartbeat(v) - } } // check runs all checks and user-defined validators on the builder. 
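Save collapses the hand-rolled hook loop into withHooks (with mutation.done set after sqlSave), and MachineUpdateOne gains a Where method for predicate-guarded single-row updates. A sketch of both, again assuming the ctx and client from the first sketch; the IsValidated predicate name is an assumption:

    // Hooks registered on the client keep firing through the new withHooks path.
    client.Machine.Use(func(next ent.Mutator) ent.Mutator {
    	return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
    		log.Printf("machine mutation: %v", m.Op())
    		return next.Mutate(ctx, m)
    	})
    })

    // Where on UpdateOne turns "update row 1" into "update row 1 only if ...".
    updated, err := client.Machine.UpdateOneID(1).
    	Where(machine.IsValidatedEQ(true)).
    	SetLastHeartbeat(time.Now()).
    	Save(ctx)
    if err != nil {
    	log.Fatal(err)
    }
    log.Printf("heartbeat refreshed for %s", updated.MachineId)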
@@ -832,16 +816,10 @@ func (muo *MachineUpdateOne) check() error { } func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: machine.Table, - Columns: machine.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, + if err := muo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(machine.Table, machine.Columns, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) id, ok := muo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Machine.id" for update`)} @@ -866,131 +844,74 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e } } } - if value, ok := muo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldCreatedAt, - }) - } - if muo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldCreatedAt, - }) - } if value, ok := muo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldUpdatedAt, - }) - } - if muo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldUpdatedAt, - }) + _spec.SetField(machine.FieldUpdatedAt, field.TypeTime, value) } if value, ok := muo.mutation.LastPush(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastPush, - }) + _spec.SetField(machine.FieldLastPush, field.TypeTime, value) } if muo.mutation.LastPushCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastPush, - }) + _spec.ClearField(machine.FieldLastPush, field.TypeTime) } if value, ok := muo.mutation.LastHeartbeat(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastHeartbeat, - }) + _spec.SetField(machine.FieldLastHeartbeat, field.TypeTime, value) } if muo.mutation.LastHeartbeatCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastHeartbeat, - }) - } - if value, ok := muo.mutation.MachineId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldMachineId, - }) + _spec.ClearField(machine.FieldLastHeartbeat, field.TypeTime) } if value, ok := muo.mutation.Password(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldPassword, - }) + _spec.SetField(machine.FieldPassword, field.TypeString, value) } if value, ok := muo.mutation.IpAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldIpAddress, - }) + _spec.SetField(machine.FieldIpAddress, field.TypeString, value) } if value, ok := muo.mutation.Scenarios(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldScenarios, - }) + _spec.SetField(machine.FieldScenarios, 
field.TypeString, value) } if muo.mutation.ScenariosCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldScenarios, - }) + _spec.ClearField(machine.FieldScenarios, field.TypeString) } if value, ok := muo.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldVersion, - }) + _spec.SetField(machine.FieldVersion, field.TypeString, value) } if muo.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldVersion, - }) + _spec.ClearField(machine.FieldVersion, field.TypeString) } if value, ok := muo.mutation.IsValidated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: machine.FieldIsValidated, - }) - } - if value, ok := muo.mutation.Status(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldStatus, - }) - } - if muo.mutation.StatusCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldStatus, - }) + _spec.SetField(machine.FieldIsValidated, field.TypeBool, value) } if value, ok := muo.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldAuthType, - }) + _spec.SetField(machine.FieldAuthType, field.TypeString, value) + } + if value, ok := muo.mutation.Osname(); ok { + _spec.SetField(machine.FieldOsname, field.TypeString, value) + } + if muo.mutation.OsnameCleared() { + _spec.ClearField(machine.FieldOsname, field.TypeString) + } + if value, ok := muo.mutation.Osversion(); ok { + _spec.SetField(machine.FieldOsversion, field.TypeString, value) + } + if muo.mutation.OsversionCleared() { + _spec.ClearField(machine.FieldOsversion, field.TypeString) + } + if value, ok := muo.mutation.Featureflags(); ok { + _spec.SetField(machine.FieldFeatureflags, field.TypeString, value) + } + if muo.mutation.FeatureflagsCleared() { + _spec.ClearField(machine.FieldFeatureflags, field.TypeString) + } + if value, ok := muo.mutation.Hubstate(); ok { + _spec.SetField(machine.FieldHubstate, field.TypeJSON, value) + } + if muo.mutation.HubstateCleared() { + _spec.ClearField(machine.FieldHubstate, field.TypeJSON) + } + if value, ok := muo.mutation.Datasources(); ok { + _spec.SetField(machine.FieldDatasources, field.TypeJSON, value) + } + if muo.mutation.DatasourcesCleared() { + _spec.ClearField(machine.FieldDatasources, field.TypeJSON) } if muo.mutation.AlertsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -1000,10 +921,7 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1016,10 +934,7 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k 
:= range nodes { @@ -1035,10 +950,7 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1057,5 +969,6 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e } return nil, err } + muo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/meta.go b/pkg/database/ent/meta.go index 660f1a4db73..7e29627957c 100644 --- a/pkg/database/ent/meta.go +++ b/pkg/database/ent/meta.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" @@ -18,9 +19,9 @@ type Meta struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` // Key holds the value of the "key" field. Key string `json:"key,omitempty"` // Value holds the value of the "value" field. @@ -29,7 +30,8 @@ type Meta struct { AlertMetas int `json:"alert_metas,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the MetaQuery when eager-loading is set. - Edges MetaEdges `json:"edges"` + Edges MetaEdges `json:"edges"` + selectValues sql.SelectValues } // MetaEdges holds the relations/edges for other nodes in the graph. @@ -44,12 +46,10 @@ type MetaEdges struct { // OwnerOrErr returns the Owner value or an error if the edge // was not loaded in eager-loading, or loaded but was not found. func (e MetaEdges) OwnerOrErr() (*Alert, error) { - if e.loadedTypes[0] { - if e.Owner == nil { - // Edge was loaded but was not found. 
- return nil, &NotFoundError{label: alert.Label} - } + if e.Owner != nil { return e.Owner, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: alert.Label} } return nil, &NotLoadedError{edge: "owner"} } @@ -66,7 +66,7 @@ func (*Meta) scanValues(columns []string) ([]any, error) { case meta.FieldCreatedAt, meta.FieldUpdatedAt: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Meta", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -90,15 +90,13 @@ func (m *Meta) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - m.CreatedAt = new(time.Time) - *m.CreatedAt = value.Time + m.CreatedAt = value.Time } case meta.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - m.UpdatedAt = new(time.Time) - *m.UpdatedAt = value.Time + m.UpdatedAt = value.Time } case meta.FieldKey: if value, ok := values[i].(*sql.NullString); !ok { @@ -118,21 +116,29 @@ func (m *Meta) assignValues(columns []string, values []any) error { } else if value.Valid { m.AlertMetas = int(value.Int64) } + default: + m.selectValues.Set(columns[i], values[i]) } } return nil } +// GetValue returns the ent.Value that was dynamically selected and assigned to the Meta. +// This includes values selected through modifiers, order, etc. +func (m *Meta) GetValue(name string) (ent.Value, error) { + return m.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Meta entity. func (m *Meta) QueryOwner() *AlertQuery { - return (&MetaClient{config: m.config}).QueryOwner(m) + return NewMetaClient(m.config).QueryOwner(m) } // Update returns a builder for updating this Meta. // Note that you need to call Meta.Unwrap() before calling this method if this Meta // was returned from a transaction, and the transaction was committed or rolled back. func (m *Meta) Update() *MetaUpdateOne { - return (&MetaClient{config: m.config}).UpdateOne(m) + return NewMetaClient(m.config).UpdateOne(m) } // Unwrap unwraps the Meta entity that was returned from a transaction after it was closed, @@ -151,15 +157,11 @@ func (m *Meta) String() string { var builder strings.Builder builder.WriteString("Meta(") builder.WriteString(fmt.Sprintf("id=%v, ", m.ID)) - if v := m.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(m.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := m.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(m.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") builder.WriteString("key=") builder.WriteString(m.Key) @@ -175,9 +177,3 @@ func (m *Meta) String() string { // MetaSlice is a parsable slice of Meta. 
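On Meta, created_at/updated_at switch from *time.Time to plain time.Time, unknown columns are captured in selectValues instead of failing scanValues, and the new GetValue accessor exposes dynamically selected values. For callers the main effect is dropping nil checks around the timestamps:

    // Assumes ctx and client from the first sketch, and at least one Meta row.
    m, err := client.Meta.Query().First(ctx)
    if err != nil {
    	log.Fatal(err)
    }
    // Previously m.CreatedAt was *time.Time and needed a nil guard.
    log.Println(m.CreatedAt.Format(time.ANSIC))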
type MetaSlice []*Meta - -func (m MetaSlice) config(cfg config) { - for _i := range m { - m[_i].config = cfg - } -} diff --git a/pkg/database/ent/meta/meta.go b/pkg/database/ent/meta/meta.go index 6d10f258919..ff41361616a 100644 --- a/pkg/database/ent/meta/meta.go +++ b/pkg/database/ent/meta/meta.go @@ -4,6 +4,9 @@ package meta import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -57,8 +60,6 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. @@ -66,3 +67,50 @@ var ( // ValueValidator is a validator for the "value" field. It is called by the builders before save. ValueValidator func(string) error ) + +// OrderOption defines the ordering options for the Meta queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByKey orders the results by the key field. +func ByKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldKey, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} + +// ByAlertMetas orders the results by the alert_metas field. +func ByAlertMetas(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAlertMetas, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} diff --git a/pkg/database/ent/meta/where.go b/pkg/database/ent/meta/where.go index 479792fd4a6..6d5d54c0482 100644 --- a/pkg/database/ent/meta/where.go +++ b/pkg/database/ent/meta/where.go @@ -12,512 +12,312 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. 
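meta/meta.go now generates typed OrderOption helpers, including ByOwnerField, which orders through the owner edge via sqlgraph.OrderByNeighborTerms. A sketch; sql here is entgo.io/ent/dialect/sql, whose stock OrderDesc option is assumed, and alert.FieldID is the usual generated id constant:

    metas, err := client.Meta.Query().
    	Order(
    		meta.ByCreatedAt(sql.OrderDesc()),
    		meta.ByOwnerField(alert.FieldID), // joins the owner edge and orders by its id
    	).
    	All(ctx)
    if err != nil {
    	log.Fatal(err)
    }
    log.Printf("fetched %d meta rows", len(metas))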
func IDNEQ(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldUpdatedAt, v)) } // Key applies equality check predicate on the "key" field. It's identical to KeyEQ. func Key(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldKey, v)) } // Value applies equality check predicate on the "value" field. It's identical to ValueEQ. func Value(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldValue, v)) } // AlertMetas applies equality check predicate on the "alert_metas" field. It's identical to AlertMetasEQ. func AlertMetas(v int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertMetas), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldAlertMetas, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. 
func CreatedAtNEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Meta(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. 
func UpdatedAtNotIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Meta(sql.FieldLTE(FieldUpdatedAt, v)) } // KeyEQ applies the EQ predicate on the "key" field. func KeyEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldKey, v)) } // KeyNEQ applies the NEQ predicate on the "key" field. func KeyNEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldKey, v)) } // KeyIn applies the In predicate on the "key" field. func KeyIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldKey), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldKey, vs...)) } // KeyNotIn applies the NotIn predicate on the "key" field. func KeyNotIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldKey), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldKey, vs...)) } // KeyGT applies the GT predicate on the "key" field. func KeyGT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldGT(FieldKey, v)) } // KeyGTE applies the GTE predicate on the "key" field. func KeyGTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldKey, v)) } // KeyLT applies the LT predicate on the "key" field. 
func KeyLT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldLT(FieldKey, v)) } // KeyLTE applies the LTE predicate on the "key" field. func KeyLTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldLTE(FieldKey, v)) } // KeyContains applies the Contains predicate on the "key" field. func KeyContains(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldContains(FieldKey, v)) } // KeyHasPrefix applies the HasPrefix predicate on the "key" field. func KeyHasPrefix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldHasPrefix(FieldKey, v)) } // KeyHasSuffix applies the HasSuffix predicate on the "key" field. func KeyHasSuffix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldHasSuffix(FieldKey, v)) } // KeyEqualFold applies the EqualFold predicate on the "key" field. func KeyEqualFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldEqualFold(FieldKey, v)) } // KeyContainsFold applies the ContainsFold predicate on the "key" field. func KeyContainsFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldContainsFold(FieldKey, v)) } // ValueEQ applies the EQ predicate on the "value" field. func ValueEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "value" field. func ValueNEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "value" field. func ValueIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "value" field. func ValueNotIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "value" field. func ValueGT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "value" field. func ValueGTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "value" field. 
func ValueLT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "value" field. func ValueLTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "value" field. func ValueContains(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "value" field. func ValueHasPrefix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "value" field. func ValueHasSuffix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "value" field. func ValueEqualFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "value" field. func ValueContainsFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldContainsFold(FieldValue, v)) } // AlertMetasEQ applies the EQ predicate on the "alert_metas" field. func AlertMetasEQ(v int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertMetas), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldAlertMetas, v)) } // AlertMetasNEQ applies the NEQ predicate on the "alert_metas" field. func AlertMetasNEQ(v int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAlertMetas), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldAlertMetas, v)) } // AlertMetasIn applies the In predicate on the "alert_metas" field. func AlertMetasIn(vs ...int) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAlertMetas), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldAlertMetas, vs...)) } // AlertMetasNotIn applies the NotIn predicate on the "alert_metas" field. func AlertMetasNotIn(vs ...int) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAlertMetas), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldAlertMetas, vs...)) } // AlertMetasIsNil applies the IsNil predicate on the "alert_metas" field. func AlertMetasIsNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldAlertMetas))) - }) + return predicate.Meta(sql.FieldIsNull(FieldAlertMetas)) } // AlertMetasNotNil applies the NotNil predicate on the "alert_metas" field. 
func AlertMetasNotNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldAlertMetas))) - }) + return predicate.Meta(sql.FieldNotNull(FieldAlertMetas)) } // HasOwner applies the HasEdge predicate on the "owner" edge. @@ -525,7 +325,6 @@ func HasOwner() predicate.Meta { return predicate.Meta(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -535,11 +334,7 @@ func HasOwner() predicate.Meta { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). func HasOwnerWith(preds ...predicate.Alert) predicate.Meta { return predicate.Meta(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -550,32 +345,15 @@ func HasOwnerWith(preds ...predicate.Alert) predicate.Meta { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Meta) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Meta(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Meta) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Meta(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Meta) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Meta(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/meta_create.go b/pkg/database/ent/meta_create.go index df4f6315911..321c4bd7ab4 100644 --- a/pkg/database/ent/meta_create.go +++ b/pkg/database/ent/meta_create.go @@ -101,50 +101,8 @@ func (mc *MetaCreate) Mutation() *MetaMutation { // Save creates the Meta in the database. 
func (mc *MetaCreate) Save(ctx context.Context) (*Meta, error) { - var ( - err error - node *Meta - ) mc.defaults() - if len(mc.hooks) == 0 { - if err = mc.check(); err != nil { - return nil, err - } - node, err = mc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mc.check(); err != nil { - return nil, err - } - mc.mutation = mutation - if node, err = mc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mc.hooks) - 1; i >= 0; i-- { - if mc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Meta) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MetaMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, mc.sqlSave, mc.mutation, mc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -183,6 +141,12 @@ func (mc *MetaCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (mc *MetaCreate) check() error { + if _, ok := mc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Meta.created_at"`)} + } + if _, ok := mc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Meta.updated_at"`)} + } if _, ok := mc.mutation.Key(); !ok { return &ValidationError{Name: "key", err: errors.New(`ent: missing required field "Meta.key"`)} } @@ -198,6 +162,9 @@ func (mc *MetaCreate) check() error { } func (mc *MetaCreate) sqlSave(ctx context.Context) (*Meta, error) { + if err := mc.check(); err != nil { + return nil, err + } _node, _spec := mc.createSpec() if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -207,50 +174,30 @@ func (mc *MetaCreate) sqlSave(ctx context.Context) (*Meta, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + mc.mutation.id = &_node.ID + mc.mutation.done = true return _node, nil } func (mc *MetaCreate) createSpec() (*Meta, *sqlgraph.CreateSpec) { var ( _node = &Meta{config: mc.config} - _spec = &sqlgraph.CreateSpec{ - Table: meta.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(meta.Table, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) ) if value, ok := mc.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(meta.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := mc.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(meta.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := mc.mutation.Key(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldKey, - }) + _spec.SetField(meta.FieldKey, field.TypeString, value) _node.Key = value } if value, ok 
:= mc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldValue, - }) + _spec.SetField(meta.FieldValue, field.TypeString, value) _node.Value = value } if nodes := mc.mutation.OwnerIDs(); len(nodes) > 0 { @@ -261,10 +208,7 @@ func (mc *MetaCreate) createSpec() (*Meta, *sqlgraph.CreateSpec) { Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -279,11 +223,15 @@ func (mc *MetaCreate) createSpec() (*Meta, *sqlgraph.CreateSpec) { // MetaCreateBulk is the builder for creating many Meta entities in bulk. type MetaCreateBulk struct { config + err error builders []*MetaCreate } // Save creates the Meta entities in the database. func (mcb *MetaCreateBulk) Save(ctx context.Context) ([]*Meta, error) { + if mcb.err != nil { + return nil, mcb.err + } specs := make([]*sqlgraph.CreateSpec, len(mcb.builders)) nodes := make([]*Meta, len(mcb.builders)) mutators := make([]Mutator, len(mcb.builders)) @@ -300,8 +248,8 @@ func (mcb *MetaCreateBulk) Save(ctx context.Context) ([]*Meta, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, mcb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/meta_delete.go b/pkg/database/ent/meta_delete.go index e1e49d2acdc..ee25dd07eb9 100644 --- a/pkg/database/ent/meta_delete.go +++ b/pkg/database/ent/meta_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (md *MetaDelete) Where(ps ...predicate.Meta) *MetaDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (md *MetaDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(md.hooks) == 0 { - affected, err = md.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - md.mutation = mutation - affected, err = md.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(md.hooks) - 1; i >= 0; i-- { - if md.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = md.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, md.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, md.sqlExec, md.mutation, md.hooks) } // ExecX is like Exec, but panics if an error occurs. 
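The Save and Exec rewrites in these builders all delegate to a single generated generic helper, withHooks, instead of inlining the hook-dispatch loop in every method. A minimal sketch of that pattern, using simplified stand-ins for ent's Mutation, Mutator and Hook types (the real generated helper lives in ent.go and carries extra type parameters):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// Simplified stand-ins for ent's runtime types.
type Mutation interface{ Op() string }

type Mutator interface {
	Mutate(ctx context.Context, m Mutation) (any, error)
}

// MutateFunc adapts a plain function to the Mutator interface.
type MutateFunc func(ctx context.Context, m Mutation) (any, error)

func (f MutateFunc) Mutate(ctx context.Context, m Mutation) (any, error) { return f(ctx, m) }

// Hook decorates a Mutator, e.g. with logging or validation.
type Hook func(Mutator) Mutator

// withHooks runs exec through the hook chain; V is the builder's result
// type (*Meta for Save, int for Exec on updates and deletes).
func withHooks[V any](ctx context.Context, exec func(context.Context) (V, error), mutation Mutation, hooks []Hook) (V, error) {
	var zero V
	if len(hooks) == 0 {
		return exec(ctx)
	}
	var result V
	var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (any, error) {
		v, err := exec(ctx)
		result = v
		return v, err
	})
	// Wrap from the last hook to the first, so hooks[0] runs outermost.
	for i := len(hooks) - 1; i >= 0; i-- {
		if hooks[i] == nil {
			return zero, errors.New("ent: uninitialized hook (forgotten import ent/runtime?)")
		}
		mut = hooks[i](mut)
	}
	if _, err := mut.Mutate(ctx, mutation); err != nil {
		return zero, err
	}
	return result, nil
}

type fakeMutation struct{}

func (fakeMutation) Op() string { return "OpCreate" }

func main() {
	logging := func(next Mutator) Mutator {
		return MutateFunc(func(ctx context.Context, m Mutation) (any, error) {
			fmt.Println("mutating:", m.Op())
			return next.Mutate(ctx, m)
		})
	}
	n, err := withHooks(context.Background(),
		func(context.Context) (int, error) { return 1, nil },
		fakeMutation{}, []Hook{logging})
	fmt.Println(n, err) // prints "mutating: OpCreate", then "1 <nil>"
}
```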
@@ -68,15 +40,7 @@ func (md *MetaDelete) ExecX(ctx context.Context) int { } func (md *MetaDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(meta.Table, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) if ps := md.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (md *MetaDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + md.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MetaDeleteOne struct { md *MetaDelete } +// Where appends a list of predicates to the MetaDelete builder. +func (mdo *MetaDeleteOne) Where(ps ...predicate.Meta) *MetaDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + // Exec executes the deletion query. func (mdo *MetaDeleteOne) Exec(ctx context.Context) error { n, err := mdo.md.Exec(ctx) @@ -111,5 +82,7 @@ func (mdo *MetaDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mdo *MetaDeleteOne) ExecX(ctx context.Context) { - mdo.md.ExecX(ctx) + if err := mdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/meta_query.go b/pkg/database/ent/meta_query.go index d6fd4f3d522..87d91d09e0e 100644 --- a/pkg/database/ent/meta_query.go +++ b/pkg/database/ent/meta_query.go @@ -18,11 +18,9 @@ import ( // MetaQuery is the builder for querying Meta entities. type MetaQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []meta.OrderOption + inters []Interceptor predicates []predicate.Meta withOwner *AlertQuery // intermediate query (i.e. traversal path). @@ -36,34 +34,34 @@ func (mq *MetaQuery) Where(ps ...predicate.Meta) *MetaQuery { return mq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mq *MetaQuery) Limit(limit int) *MetaQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } -// Offset adds an offset step to the query. +// Offset to start from. func (mq *MetaQuery) Offset(offset int) *MetaQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MetaQuery) Unique(unique bool) *MetaQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } -// Order adds an order step to the query. -func (mq *MetaQuery) Order(o ...OrderFunc) *MetaQuery { +// Order specifies how the records should be ordered. func (mq *MetaQuery) Order(o ...meta.OrderOption) *MetaQuery { mq.order = append(mq.order, o...) return mq } // QueryOwner chains the current query on the "owner" edge. func (mq *MetaQuery) QueryOwner() *AlertQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -86,7 +84,7 @@ func (mq *MetaQuery) QueryOwner() *AlertQuery { // First returns the first Meta entity from the query. // Returns a *NotFoundError when no Meta was found.
func (mq *MetaQuery) First(ctx context.Context) (*Meta, error) { - nodes, err := mq.Limit(1).All(ctx) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -109,7 +107,7 @@ func (mq *MetaQuery) FirstX(ctx context.Context) *Meta { // Returns a *NotFoundError when no Meta ID was found. func (mq *MetaQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(1).IDs(ctx); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -132,7 +130,7 @@ func (mq *MetaQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Meta entity is found. // Returns a *NotFoundError when no Meta entities are found. func (mq *MetaQuery) Only(ctx context.Context) (*Meta, error) { - nodes, err := mq.Limit(2).All(ctx) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -160,7 +158,7 @@ func (mq *MetaQuery) OnlyX(ctx context.Context) *Meta { // Returns a *NotFoundError when no entities are found. func (mq *MetaQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(2).IDs(ctx); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -185,10 +183,12 @@ func (mq *MetaQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MetaSlice. func (mq *MetaQuery) All(ctx context.Context) ([]*Meta, error) { + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } - return mq.sqlAll(ctx) + qr := querierAll[[]*Meta, *MetaQuery]() + return withInterceptors[[]*Meta](ctx, mq, qr, mq.inters) } // AllX is like All, but panics if an error occurs. @@ -201,9 +201,12 @@ func (mq *MetaQuery) AllX(ctx context.Context) []*Meta { } // IDs executes the query and returns a list of Meta IDs. -func (mq *MetaQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mq.Select(meta.FieldID).Scan(ctx, &ids); err != nil { +func (mq *MetaQuery) IDs(ctx context.Context) (ids []int, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(meta.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -220,10 +223,11 @@ func (mq *MetaQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mq *MetaQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } - return mq.sqlCount(ctx) + return withInterceptors[int](ctx, mq, querierCount[*MetaQuery](), mq.inters) } // CountX is like Count, but panics if an error occurs. @@ -237,10 +241,15 @@ func (mq *MetaQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MetaQuery) Exist(ctx context.Context) (bool, error) { - if err := mq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return mq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. 
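Above, the builder's separate limit, offset, unique and fields members collapse into one shared *QueryContext, every public entry point stamps its operation name with setContextOp before running, and Exist is no longer a dedicated sqlExist query: it is derived from FirstID. A sketch of that derivation, assuming a local stand-in for ent's generated NotFoundError:

```go
package example

import (
	"context"
	"errors"
	"fmt"
)

// NotFoundError stands in for the generated ent type.
type NotFoundError struct{ label string }

func (e *NotFoundError) Error() string { return "ent: " + e.label + " not found" }

// IsNotFound mirrors the generated helper.
func IsNotFound(err error) bool {
	var nf *NotFoundError
	return errors.As(err, &nf)
}

// existViaFirstID reproduces the new Exist implementation: run the query
// with LIMIT 1 via FirstID and map the three possible outcomes.
func existViaFirstID(ctx context.Context, firstID func(context.Context) (int, error)) (bool, error) {
	switch _, err := firstID(ctx); {
	case IsNotFound(err):
		return false, nil // no matching row is a normal "false", not an error
	case err != nil:
		return false, fmt.Errorf("ent: check existence: %w", err)
	default:
		return true, nil
	}
}
```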
@@ -260,22 +269,21 @@ func (mq *MetaQuery) Clone() *MetaQuery { } return &MetaQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, - order: append([]OrderFunc{}, mq.order...), + ctx: mq.ctx.Clone(), + order: append([]meta.OrderOption{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Meta{}, mq.predicates...), withOwner: mq.withOwner.Clone(), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MetaQuery) WithOwner(opts ...func(*AlertQuery)) *MetaQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -298,16 +306,11 @@ func (mq *MetaQuery) WithOwner(opts ...func(*AlertQuery)) *MetaQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (mq *MetaQuery) GroupBy(field string, fields ...string) *MetaGroupBy { - grbuild := &MetaGroupBy{config: mq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mq.prepareQuery(ctx); err != nil { - return nil, err - } - return mq.sqlQuery(ctx), nil - } + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MetaGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields grbuild.label = meta.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -324,15 +327,30 @@ func (mq *MetaQuery) GroupBy(field string, fields ...string) *MetaGroupBy { // Select(meta.FieldCreatedAt). // Scan(ctx, &v) func (mq *MetaQuery) Select(fields ...string) *MetaSelect { - mq.fields = append(mq.fields, fields...) - selbuild := &MetaSelect{MetaQuery: mq} - selbuild.label = meta.Label - selbuild.flds, selbuild.scan = &mq.fields, selbuild.Scan - return selbuild + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MetaSelect{MetaQuery: mq} + sbuild.label = meta.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MetaSelect configured with the given aggregations. +func (mq *MetaQuery) Aggregate(fns ...AggregateFunc) *MetaSelect { + return mq.Select().Aggregate(fns...) 
} func (mq *MetaQuery) prepareQuery(ctx context.Context) error { - for _, f := range mq.fields { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { if !meta.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -392,6 +410,9 @@ func (mq *MetaQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes []* } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(alert.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -411,41 +432,22 @@ func (mq *MetaQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes []* func (mq *MetaQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } -func (mq *MetaQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := mq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - Columns: meta.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, - From: mq.sql, - Unique: true, - } - if unique := mq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(meta.Table, meta.Columns, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true } - if fields := mq.fields; len(fields) > 0 { + if fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, meta.FieldID) for i := range fields { @@ -453,6 +455,9 @@ func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) } } + if mq.withOwner != nil { + _spec.Node.AddColumnOnce(meta.FieldAlertMetas) + } } if ps := mq.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { @@ -461,10 +466,10 @@ func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -480,7 +485,7 @@ func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(meta.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = meta.Columns } @@ -489,7 +494,7 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) 
} - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -498,12 +503,12 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -511,13 +516,8 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { // MetaGroupBy is the group-by builder for Meta entities. type MetaGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MetaQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -526,74 +526,77 @@ func (mgb *MetaGroupBy) Aggregate(fns ...AggregateFunc) *MetaGroupBy { return mgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (mgb *MetaGroupBy) Scan(ctx context.Context, v any) error { - query, err := mgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { return err } - mgb.sql = query - return mgb.sqlScan(ctx, v) + return scanWithInterceptors[*MetaQuery, *MetaGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) } -func (mgb *MetaGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range mgb.fields { - if !meta.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mgb *MetaGroupBy) sqlScan(ctx context.Context, root *MetaQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mgb.sqlQuery() + selector.GroupBy(selector.Columns(*mgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mgb *MetaGroupBy) sqlQuery() *sql.Selector { - selector := mgb.sql.Select() - aggregation := make([]string, 0, len(mgb.fns)) - for _, fn := range mgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) - for _, f := range mgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mgb.fields...)...) 
-} - // MetaSelect is the builder for selecting fields of Meta entities. type MetaSelect struct { *MetaQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MetaSelect) Aggregate(fns ...AggregateFunc) *MetaSelect { + ms.fns = append(ms.fns, fns...) + return ms } // Scan applies the selector query and scans the result into the given value. func (ms *MetaSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } - ms.sql = ms.MetaQuery.sqlQuery(ctx) - return ms.sqlScan(ctx, v) + return scanWithInterceptors[*MetaQuery, *MetaSelect](ctx, ms.MetaQuery, ms, ms.inters, v) } -func (ms *MetaSelect) sqlScan(ctx context.Context, v any) error { +func (ms *MetaSelect) sqlScan(ctx context.Context, root *MetaQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ms.sql.Query() + query, args := selector.Query() if err := ms.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/meta_update.go b/pkg/database/ent/meta_update.go index 67a198dddfa..bdf622eb6c3 100644 --- a/pkg/database/ent/meta_update.go +++ b/pkg/database/ent/meta_update.go @@ -29,42 +29,12 @@ func (mu *MetaUpdate) Where(ps ...predicate.Meta) *MetaUpdate { return mu } -// SetCreatedAt sets the "created_at" field. -func (mu *MetaUpdate) SetCreatedAt(t time.Time) *MetaUpdate { - mu.mutation.SetCreatedAt(t) - return mu -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (mu *MetaUpdate) ClearCreatedAt() *MetaUpdate { - mu.mutation.ClearCreatedAt() - return mu -} - // SetUpdatedAt sets the "updated_at" field. func (mu *MetaUpdate) SetUpdatedAt(t time.Time) *MetaUpdate { mu.mutation.SetUpdatedAt(t) return mu } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (mu *MetaUpdate) ClearUpdatedAt() *MetaUpdate { - mu.mutation.ClearUpdatedAt() - return mu -} - -// SetKey sets the "key" field. -func (mu *MetaUpdate) SetKey(s string) *MetaUpdate { - mu.mutation.SetKey(s) - return mu -} - -// SetValue sets the "value" field. -func (mu *MetaUpdate) SetValue(s string) *MetaUpdate { - mu.mutation.SetValue(s) - return mu -} - // SetAlertMetas sets the "alert_metas" field. func (mu *MetaUpdate) SetAlertMetas(i int) *MetaUpdate { mu.mutation.SetAlertMetas(i) @@ -117,41 +87,8 @@ func (mu *MetaUpdate) ClearOwner() *MetaUpdate { // Save executes the query and returns the number of nodes affected by the update operation. 
func (mu *MetaUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) mu.defaults() - if len(mu.hooks) == 0 { - if err = mu.check(); err != nil { - return 0, err - } - affected, err = mu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mu.check(); err != nil { - return 0, err - } - mu.mutation = mutation - affected, err = mu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mu.hooks) - 1; i >= 0; i-- { - if mu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, mu.sqlSave, mu.mutation, mu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -178,37 +115,14 @@ func (mu *MetaUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (mu *MetaUpdate) defaults() { - if _, ok := mu.mutation.CreatedAt(); !ok && !mu.mutation.CreatedAtCleared() { - v := meta.UpdateDefaultCreatedAt() - mu.mutation.SetCreatedAt(v) - } - if _, ok := mu.mutation.UpdatedAt(); !ok && !mu.mutation.UpdatedAtCleared() { + if _, ok := mu.mutation.UpdatedAt(); !ok { v := meta.UpdateDefaultUpdatedAt() mu.mutation.SetUpdatedAt(v) } } -// check runs all checks and user-defined validators on the builder. -func (mu *MetaUpdate) check() error { - if v, ok := mu.mutation.Value(); ok { - if err := meta.ValueValidator(v); err != nil { - return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "Meta.value": %w`, err)} - } - } - return nil -} - func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - Columns: meta.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(meta.Table, meta.Columns, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) if ps := mu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -216,45 +130,8 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := mu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldCreatedAt, - }) - } - if mu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldCreatedAt, - }) - } if value, ok := mu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldUpdatedAt, - }) - } - if mu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldUpdatedAt, - }) - } - if value, ok := mu.mutation.Key(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldKey, - }) - } - if value, ok := mu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: 
meta.FieldValue, - }) + _spec.SetField(meta.FieldUpdatedAt, field.TypeTime, value) } if mu.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -264,10 +141,7 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -280,10 +154,7 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -299,6 +170,7 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mu.mutation.done = true return n, nil } @@ -310,42 +182,12 @@ type MetaUpdateOne struct { mutation *MetaMutation } -// SetCreatedAt sets the "created_at" field. -func (muo *MetaUpdateOne) SetCreatedAt(t time.Time) *MetaUpdateOne { - muo.mutation.SetCreatedAt(t) - return muo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (muo *MetaUpdateOne) ClearCreatedAt() *MetaUpdateOne { - muo.mutation.ClearCreatedAt() - return muo -} - // SetUpdatedAt sets the "updated_at" field. func (muo *MetaUpdateOne) SetUpdatedAt(t time.Time) *MetaUpdateOne { muo.mutation.SetUpdatedAt(t) return muo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (muo *MetaUpdateOne) ClearUpdatedAt() *MetaUpdateOne { - muo.mutation.ClearUpdatedAt() - return muo -} - -// SetKey sets the "key" field. -func (muo *MetaUpdateOne) SetKey(s string) *MetaUpdateOne { - muo.mutation.SetKey(s) - return muo -} - -// SetValue sets the "value" field. -func (muo *MetaUpdateOne) SetValue(s string) *MetaUpdateOne { - muo.mutation.SetValue(s) - return muo -} - // SetAlertMetas sets the "alert_metas" field. func (muo *MetaUpdateOne) SetAlertMetas(i int) *MetaUpdateOne { muo.mutation.SetAlertMetas(i) @@ -396,6 +238,12 @@ func (muo *MetaUpdateOne) ClearOwner() *MetaUpdateOne { return muo } +// Where appends a list of predicates to the MetaUpdate builder. +func (muo *MetaUpdateOne) Where(ps ...predicate.Meta) *MetaUpdateOne { + muo.mutation.Where(ps...) + return muo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (muo *MetaUpdateOne) Select(field string, fields ...string) *MetaUpdateOne { @@ -405,47 +253,8 @@ func (muo *MetaUpdateOne) Select(field string, fields ...string) *MetaUpdateOne // Save executes the query and returns the updated Meta entity.
func (muo *MetaUpdateOne) Save(ctx context.Context) (*Meta, error) { - var ( - err error - node *Meta - ) muo.defaults() - if len(muo.hooks) == 0 { - if err = muo.check(); err != nil { - return nil, err - } - node, err = muo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = muo.check(); err != nil { - return nil, err - } - muo.mutation = mutation - node, err = muo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(muo.hooks) - 1; i >= 0; i-- { - if muo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = muo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, muo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Meta) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MetaMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, muo.sqlSave, muo.mutation, muo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -472,37 +281,14 @@ func (muo *MetaUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (muo *MetaUpdateOne) defaults() { - if _, ok := muo.mutation.CreatedAt(); !ok && !muo.mutation.CreatedAtCleared() { - v := meta.UpdateDefaultCreatedAt() - muo.mutation.SetCreatedAt(v) - } - if _, ok := muo.mutation.UpdatedAt(); !ok && !muo.mutation.UpdatedAtCleared() { + if _, ok := muo.mutation.UpdatedAt(); !ok { v := meta.UpdateDefaultUpdatedAt() muo.mutation.SetUpdatedAt(v) } } -// check runs all checks and user-defined validators on the builder. -func (muo *MetaUpdateOne) check() error { - if v, ok := muo.mutation.Value(); ok { - if err := meta.ValueValidator(v); err != nil { - return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "Meta.value": %w`, err)} - } - } - return nil -} - func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - Columns: meta.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(meta.Table, meta.Columns, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) id, ok := muo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Meta.id" for update`)} @@ -527,45 +313,8 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) } } } - if value, ok := muo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldCreatedAt, - }) - } - if muo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldCreatedAt, - }) - } if value, ok := muo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldUpdatedAt, - }) - } - if muo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldUpdatedAt, - }) - } - if value, ok := muo.mutation.Key(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: 
meta.FieldKey, - }) - } - if value, ok := muo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldValue, - }) + _spec.SetField(meta.FieldUpdatedAt, field.TypeTime, value) } if muo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -575,10 +324,7 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -591,10 +337,7 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -613,5 +356,6 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) } return nil, err } + muo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/metric.go b/pkg/database/ent/metric.go new file mode 100644 index 00000000000..47f3b4df4e5 --- /dev/null +++ b/pkg/database/ent/metric.go @@ -0,0 +1,154 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" +) + +// Metric is the model entity for the Metric schema. +type Metric struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // Type of the metrics source: LP=logprocessor, RC=remediation + GeneratedType metric.GeneratedType `json:"generated_type,omitempty"` + // Source of the metrics: machine id, bouncer name... + // It must come from the auth middleware. + GeneratedBy string `json:"generated_by,omitempty"` + // When the metrics are received by LAPI + ReceivedAt time.Time `json:"received_at,omitempty"` + // When the metrics are sent to the console + PushedAt *time.Time `json:"pushed_at,omitempty"` + // The actual metrics (item0) + Payload string `json:"payload,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*Metric) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case metric.FieldID: + values[i] = new(sql.NullInt64) + case metric.FieldGeneratedType, metric.FieldGeneratedBy, metric.FieldPayload: + values[i] = new(sql.NullString) + case metric.FieldReceivedAt, metric.FieldPushedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Metric fields. 
+func (m *Metric) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case metric.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + m.ID = int(value.Int64) + case metric.FieldGeneratedType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field generated_type", values[i]) + } else if value.Valid { + m.GeneratedType = metric.GeneratedType(value.String) + } + case metric.FieldGeneratedBy: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field generated_by", values[i]) + } else if value.Valid { + m.GeneratedBy = value.String + } + case metric.FieldReceivedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field received_at", values[i]) + } else if value.Valid { + m.ReceivedAt = value.Time + } + case metric.FieldPushedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field pushed_at", values[i]) + } else if value.Valid { + m.PushedAt = new(time.Time) + *m.PushedAt = value.Time + } + case metric.FieldPayload: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field payload", values[i]) + } else if value.Valid { + m.Payload = value.String + } + default: + m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Metric. +// This includes values selected through modifiers, order, etc. +func (m *Metric) Value(name string) (ent.Value, error) { + return m.selectValues.Get(name) +} + +// Update returns a builder for updating this Metric. +// Note that you need to call Metric.Unwrap() before calling this method if this Metric +// was returned from a transaction, and the transaction was committed or rolled back. +func (m *Metric) Update() *MetricUpdateOne { + return NewMetricClient(m.config).UpdateOne(m) +} + +// Unwrap unwraps the Metric entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (m *Metric) Unwrap() *Metric { + _tx, ok := m.config.driver.(*txDriver) + if !ok { + panic("ent: Metric is not a transactional entity") + } + m.config.driver = _tx.drv + return m +} + +// String implements the fmt.Stringer. +func (m *Metric) String() string { + var builder strings.Builder + builder.WriteString("Metric(") + builder.WriteString(fmt.Sprintf("id=%v, ", m.ID)) + builder.WriteString("generated_type=") + builder.WriteString(fmt.Sprintf("%v", m.GeneratedType)) + builder.WriteString(", ") + builder.WriteString("generated_by=") + builder.WriteString(m.GeneratedBy) + builder.WriteString(", ") + builder.WriteString("received_at=") + builder.WriteString(m.ReceivedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := m.PushedAt; v != nil { + builder.WriteString("pushed_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("payload=") + builder.WriteString(m.Payload) + builder.WriteByte(')') + return builder.String() +} + +// Metrics is a parsable slice of Metric. 
+type Metrics []*Metric diff --git a/pkg/database/ent/metric/metric.go b/pkg/database/ent/metric/metric.go new file mode 100644 index 00000000000..78e88982220 --- /dev/null +++ b/pkg/database/ent/metric/metric.go @@ -0,0 +1,104 @@ +// Code generated by ent, DO NOT EDIT. + +package metric + +import ( + "fmt" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the metric type in the database. + Label = "metric" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldGeneratedType holds the string denoting the generated_type field in the database. + FieldGeneratedType = "generated_type" + // FieldGeneratedBy holds the string denoting the generated_by field in the database. + FieldGeneratedBy = "generated_by" + // FieldReceivedAt holds the string denoting the received_at field in the database. + FieldReceivedAt = "received_at" + // FieldPushedAt holds the string denoting the pushed_at field in the database. + FieldPushedAt = "pushed_at" + // FieldPayload holds the string denoting the payload field in the database. + FieldPayload = "payload" + // Table holds the table name of the metric in the database. + Table = "metrics" +) + +// Columns holds all SQL columns for metric fields. +var Columns = []string{ + FieldID, + FieldGeneratedType, + FieldGeneratedBy, + FieldReceivedAt, + FieldPushedAt, + FieldPayload, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// GeneratedType defines the type for the "generated_type" enum field. +type GeneratedType string + +// GeneratedType values. +const ( + GeneratedTypeLP GeneratedType = "LP" + GeneratedTypeRC GeneratedType = "RC" +) + +func (gt GeneratedType) String() string { + return string(gt) +} + +// GeneratedTypeValidator is a validator for the "generated_type" field enum values. It is called by the builders before save. +func GeneratedTypeValidator(gt GeneratedType) error { + switch gt { + case GeneratedTypeLP, GeneratedTypeRC: + return nil + default: + return fmt.Errorf("metric: invalid enum value for generated_type field: %q", gt) + } +} + +// OrderOption defines the ordering options for the Metric queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByGeneratedType orders the results by the generated_type field. +func ByGeneratedType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldGeneratedType, opts...).ToFunc() +} + +// ByGeneratedBy orders the results by the generated_by field. +func ByGeneratedBy(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldGeneratedBy, opts...).ToFunc() +} + +// ByReceivedAt orders the results by the received_at field. +func ByReceivedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldReceivedAt, opts...).ToFunc() +} + +// ByPushedAt orders the results by the pushed_at field. +func ByPushedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPushedAt, opts...).ToFunc() +} + +// ByPayload orders the results by the payload field. 
+func ByPayload(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPayload, opts...).ToFunc() +} diff --git a/pkg/database/ent/metric/where.go b/pkg/database/ent/metric/where.go new file mode 100644 index 00000000000..72bd9d93cd7 --- /dev/null +++ b/pkg/database/ent/metric/where.go @@ -0,0 +1,330 @@ +// Code generated by ent, DO NOT EDIT. + +package metric + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Metric { + return predicate.Metric(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Metric { + return predicate.Metric(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Metric { + return predicate.Metric(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Metric { + return predicate.Metric(sql.FieldLTE(FieldID, id)) +} + +// GeneratedBy applies equality check predicate on the "generated_by" field. It's identical to GeneratedByEQ. +func GeneratedBy(v string) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldGeneratedBy, v)) +} + +// ReceivedAt applies equality check predicate on the "received_at" field. It's identical to ReceivedAtEQ. +func ReceivedAt(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldReceivedAt, v)) +} + +// PushedAt applies equality check predicate on the "pushed_at" field. It's identical to PushedAtEQ. +func PushedAt(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldPushedAt, v)) +} + +// Payload applies equality check predicate on the "payload" field. It's identical to PayloadEQ. +func Payload(v string) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldPayload, v)) +} + +// GeneratedTypeEQ applies the EQ predicate on the "generated_type" field. +func GeneratedTypeEQ(v GeneratedType) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldGeneratedType, v)) +} + +// GeneratedTypeNEQ applies the NEQ predicate on the "generated_type" field. +func GeneratedTypeNEQ(v GeneratedType) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldGeneratedType, v)) +} + +// GeneratedTypeIn applies the In predicate on the "generated_type" field. +func GeneratedTypeIn(vs ...GeneratedType) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldGeneratedType, vs...)) +} + +// GeneratedTypeNotIn applies the NotIn predicate on the "generated_type" field. 
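+// Illustrative use of these enum predicates (a sketch, not generated code;
+// assumes an initialized *ent.Client named "client"):
+//
+//	n, err := client.Metric.Query().
+//		Where(metric.GeneratedTypeNotIn(metric.GeneratedTypeRC)).
+//		Count(ctx)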
+func GeneratedTypeNotIn(vs ...GeneratedType) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldGeneratedType, vs...)) +} + +// GeneratedByEQ applies the EQ predicate on the "generated_by" field. +func GeneratedByEQ(v string) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldGeneratedBy, v)) +} + +// GeneratedByNEQ applies the NEQ predicate on the "generated_by" field. +func GeneratedByNEQ(v string) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldGeneratedBy, v)) +} + +// GeneratedByIn applies the In predicate on the "generated_by" field. +func GeneratedByIn(vs ...string) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldGeneratedBy, vs...)) +} + +// GeneratedByNotIn applies the NotIn predicate on the "generated_by" field. +func GeneratedByNotIn(vs ...string) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldGeneratedBy, vs...)) +} + +// GeneratedByGT applies the GT predicate on the "generated_by" field. +func GeneratedByGT(v string) predicate.Metric { + return predicate.Metric(sql.FieldGT(FieldGeneratedBy, v)) +} + +// GeneratedByGTE applies the GTE predicate on the "generated_by" field. +func GeneratedByGTE(v string) predicate.Metric { + return predicate.Metric(sql.FieldGTE(FieldGeneratedBy, v)) +} + +// GeneratedByLT applies the LT predicate on the "generated_by" field. +func GeneratedByLT(v string) predicate.Metric { + return predicate.Metric(sql.FieldLT(FieldGeneratedBy, v)) +} + +// GeneratedByLTE applies the LTE predicate on the "generated_by" field. +func GeneratedByLTE(v string) predicate.Metric { + return predicate.Metric(sql.FieldLTE(FieldGeneratedBy, v)) +} + +// GeneratedByContains applies the Contains predicate on the "generated_by" field. +func GeneratedByContains(v string) predicate.Metric { + return predicate.Metric(sql.FieldContains(FieldGeneratedBy, v)) +} + +// GeneratedByHasPrefix applies the HasPrefix predicate on the "generated_by" field. +func GeneratedByHasPrefix(v string) predicate.Metric { + return predicate.Metric(sql.FieldHasPrefix(FieldGeneratedBy, v)) +} + +// GeneratedByHasSuffix applies the HasSuffix predicate on the "generated_by" field. +func GeneratedByHasSuffix(v string) predicate.Metric { + return predicate.Metric(sql.FieldHasSuffix(FieldGeneratedBy, v)) +} + +// GeneratedByEqualFold applies the EqualFold predicate on the "generated_by" field. +func GeneratedByEqualFold(v string) predicate.Metric { + return predicate.Metric(sql.FieldEqualFold(FieldGeneratedBy, v)) +} + +// GeneratedByContainsFold applies the ContainsFold predicate on the "generated_by" field. +func GeneratedByContainsFold(v string) predicate.Metric { + return predicate.Metric(sql.FieldContainsFold(FieldGeneratedBy, v)) +} + +// ReceivedAtEQ applies the EQ predicate on the "received_at" field. +func ReceivedAtEQ(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldReceivedAt, v)) +} + +// ReceivedAtNEQ applies the NEQ predicate on the "received_at" field. +func ReceivedAtNEQ(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldReceivedAt, v)) +} + +// ReceivedAtIn applies the In predicate on the "received_at" field. +func ReceivedAtIn(vs ...time.Time) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldReceivedAt, vs...)) +} + +// ReceivedAtNotIn applies the NotIn predicate on the "received_at" field. 
+func ReceivedAtNotIn(vs ...time.Time) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldReceivedAt, vs...)) +} + +// ReceivedAtGT applies the GT predicate on the "received_at" field. +func ReceivedAtGT(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldGT(FieldReceivedAt, v)) +} + +// ReceivedAtGTE applies the GTE predicate on the "received_at" field. +func ReceivedAtGTE(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldGTE(FieldReceivedAt, v)) +} + +// ReceivedAtLT applies the LT predicate on the "received_at" field. +func ReceivedAtLT(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldLT(FieldReceivedAt, v)) +} + +// ReceivedAtLTE applies the LTE predicate on the "received_at" field. +func ReceivedAtLTE(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldLTE(FieldReceivedAt, v)) +} + +// PushedAtEQ applies the EQ predicate on the "pushed_at" field. +func PushedAtEQ(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldPushedAt, v)) +} + +// PushedAtNEQ applies the NEQ predicate on the "pushed_at" field. +func PushedAtNEQ(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldPushedAt, v)) +} + +// PushedAtIn applies the In predicate on the "pushed_at" field. +func PushedAtIn(vs ...time.Time) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldPushedAt, vs...)) +} + +// PushedAtNotIn applies the NotIn predicate on the "pushed_at" field. +func PushedAtNotIn(vs ...time.Time) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldPushedAt, vs...)) +} + +// PushedAtGT applies the GT predicate on the "pushed_at" field. +func PushedAtGT(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldGT(FieldPushedAt, v)) +} + +// PushedAtGTE applies the GTE predicate on the "pushed_at" field. +func PushedAtGTE(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldGTE(FieldPushedAt, v)) +} + +// PushedAtLT applies the LT predicate on the "pushed_at" field. +func PushedAtLT(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldLT(FieldPushedAt, v)) +} + +// PushedAtLTE applies the LTE predicate on the "pushed_at" field. +func PushedAtLTE(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldLTE(FieldPushedAt, v)) +} + +// PushedAtIsNil applies the IsNil predicate on the "pushed_at" field. +func PushedAtIsNil() predicate.Metric { + return predicate.Metric(sql.FieldIsNull(FieldPushedAt)) +} + +// PushedAtNotNil applies the NotNil predicate on the "pushed_at" field. +func PushedAtNotNil() predicate.Metric { + return predicate.Metric(sql.FieldNotNull(FieldPushedAt)) +} + +// PayloadEQ applies the EQ predicate on the "payload" field. +func PayloadEQ(v string) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldPayload, v)) +} + +// PayloadNEQ applies the NEQ predicate on the "payload" field. +func PayloadNEQ(v string) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldPayload, v)) +} + +// PayloadIn applies the In predicate on the "payload" field. +func PayloadIn(vs ...string) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldPayload, vs...)) +} + +// PayloadNotIn applies the NotIn predicate on the "payload" field. +func PayloadNotIn(vs ...string) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldPayload, vs...)) +} + +// PayloadGT applies the GT predicate on the "payload" field. 
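+// Like the other predicates in this file, the null-check helpers above
+// compose with a query; e.g. fetching metrics not yet pushed (a sketch,
+// not generated code; assumes "client" is an initialized *ent.Client):
+//
+//	unpushed, err := client.Metric.Query().
+//		Where(metric.PushedAtIsNil()).
+//		Order(metric.ByReceivedAt()).
+//		All(ctx)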
+func PayloadGT(v string) predicate.Metric { + return predicate.Metric(sql.FieldGT(FieldPayload, v)) +} + +// PayloadGTE applies the GTE predicate on the "payload" field. +func PayloadGTE(v string) predicate.Metric { + return predicate.Metric(sql.FieldGTE(FieldPayload, v)) +} + +// PayloadLT applies the LT predicate on the "payload" field. +func PayloadLT(v string) predicate.Metric { + return predicate.Metric(sql.FieldLT(FieldPayload, v)) +} + +// PayloadLTE applies the LTE predicate on the "payload" field. +func PayloadLTE(v string) predicate.Metric { + return predicate.Metric(sql.FieldLTE(FieldPayload, v)) +} + +// PayloadContains applies the Contains predicate on the "payload" field. +func PayloadContains(v string) predicate.Metric { + return predicate.Metric(sql.FieldContains(FieldPayload, v)) +} + +// PayloadHasPrefix applies the HasPrefix predicate on the "payload" field. +func PayloadHasPrefix(v string) predicate.Metric { + return predicate.Metric(sql.FieldHasPrefix(FieldPayload, v)) +} + +// PayloadHasSuffix applies the HasSuffix predicate on the "payload" field. +func PayloadHasSuffix(v string) predicate.Metric { + return predicate.Metric(sql.FieldHasSuffix(FieldPayload, v)) +} + +// PayloadEqualFold applies the EqualFold predicate on the "payload" field. +func PayloadEqualFold(v string) predicate.Metric { + return predicate.Metric(sql.FieldEqualFold(FieldPayload, v)) +} + +// PayloadContainsFold applies the ContainsFold predicate on the "payload" field. +func PayloadContainsFold(v string) predicate.Metric { + return predicate.Metric(sql.FieldContainsFold(FieldPayload, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Metric) predicate.Metric { + return predicate.Metric(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Metric) predicate.Metric { + return predicate.Metric(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Metric) predicate.Metric { + return predicate.Metric(sql.NotPredicates(p)) +} diff --git a/pkg/database/ent/metric_create.go b/pkg/database/ent/metric_create.go new file mode 100644 index 00000000000..973cddd41d0 --- /dev/null +++ b/pkg/database/ent/metric_create.go @@ -0,0 +1,246 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" +) + +// MetricCreate is the builder for creating a Metric entity. +type MetricCreate struct { + config + mutation *MetricMutation + hooks []Hook +} + +// SetGeneratedType sets the "generated_type" field. +func (mc *MetricCreate) SetGeneratedType(mt metric.GeneratedType) *MetricCreate { + mc.mutation.SetGeneratedType(mt) + return mc +} + +// SetGeneratedBy sets the "generated_by" field. +func (mc *MetricCreate) SetGeneratedBy(s string) *MetricCreate { + mc.mutation.SetGeneratedBy(s) + return mc +} + +// SetReceivedAt sets the "received_at" field. +func (mc *MetricCreate) SetReceivedAt(t time.Time) *MetricCreate { + mc.mutation.SetReceivedAt(t) + return mc +} + +// SetPushedAt sets the "pushed_at" field. +func (mc *MetricCreate) SetPushedAt(t time.Time) *MetricCreate { + mc.mutation.SetPushedAt(t) + return mc +} + +// SetNillablePushedAt sets the "pushed_at" field if the given value is not nil. 
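+// An illustrative create (a minimal sketch, not generated code; assumes an
+// initialized *ent.Client named "client"; the field values are placeholders):
+//
+//	m, err := client.Metric.Create().
+//		SetGeneratedType(metric.GeneratedTypeLP).
+//		SetGeneratedBy("some-machine-id").
+//		SetReceivedAt(time.Now()).
+//		SetNillablePushedAt(nil). // pushed_at stays NULL until pushed
+//		SetPayload("{}").
+//		Save(ctx)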
+func (mc *MetricCreate) SetNillablePushedAt(t *time.Time) *MetricCreate { + if t != nil { + mc.SetPushedAt(*t) + } + return mc +} + +// SetPayload sets the "payload" field. +func (mc *MetricCreate) SetPayload(s string) *MetricCreate { + mc.mutation.SetPayload(s) + return mc +} + +// Mutation returns the MetricMutation object of the builder. +func (mc *MetricCreate) Mutation() *MetricMutation { + return mc.mutation +} + +// Save creates the Metric in the database. +func (mc *MetricCreate) Save(ctx context.Context) (*Metric, error) { + return withHooks(ctx, mc.sqlSave, mc.mutation, mc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (mc *MetricCreate) SaveX(ctx context.Context) *Metric { + v, err := mc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (mc *MetricCreate) Exec(ctx context.Context) error { + _, err := mc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mc *MetricCreate) ExecX(ctx context.Context) { + if err := mc.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (mc *MetricCreate) check() error { + if _, ok := mc.mutation.GeneratedType(); !ok { + return &ValidationError{Name: "generated_type", err: errors.New(`ent: missing required field "Metric.generated_type"`)} + } + if v, ok := mc.mutation.GeneratedType(); ok { + if err := metric.GeneratedTypeValidator(v); err != nil { + return &ValidationError{Name: "generated_type", err: fmt.Errorf(`ent: validator failed for field "Metric.generated_type": %w`, err)} + } + } + if _, ok := mc.mutation.GeneratedBy(); !ok { + return &ValidationError{Name: "generated_by", err: errors.New(`ent: missing required field "Metric.generated_by"`)} + } + if _, ok := mc.mutation.ReceivedAt(); !ok { + return &ValidationError{Name: "received_at", err: errors.New(`ent: missing required field "Metric.received_at"`)} + } + if _, ok := mc.mutation.Payload(); !ok { + return &ValidationError{Name: "payload", err: errors.New(`ent: missing required field "Metric.payload"`)} + } + return nil +} + +func (mc *MetricCreate) sqlSave(ctx context.Context) (*Metric, error) { + if err := mc.check(); err != nil { + return nil, err + } + _node, _spec := mc.createSpec() + if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + mc.mutation.id = &_node.ID + mc.mutation.done = true + return _node, nil +} + +func (mc *MetricCreate) createSpec() (*Metric, *sqlgraph.CreateSpec) { + var ( + _node = &Metric{config: mc.config} + _spec = sqlgraph.NewCreateSpec(metric.Table, sqlgraph.NewFieldSpec(metric.FieldID, field.TypeInt)) + ) + if value, ok := mc.mutation.GeneratedType(); ok { + _spec.SetField(metric.FieldGeneratedType, field.TypeEnum, value) + _node.GeneratedType = value + } + if value, ok := mc.mutation.GeneratedBy(); ok { + _spec.SetField(metric.FieldGeneratedBy, field.TypeString, value) + _node.GeneratedBy = value + } + if value, ok := mc.mutation.ReceivedAt(); ok { + _spec.SetField(metric.FieldReceivedAt, field.TypeTime, value) + _node.ReceivedAt = value + } + if value, ok := mc.mutation.PushedAt(); ok { + _spec.SetField(metric.FieldPushedAt, field.TypeTime, value) + _node.PushedAt = &value + } + if value, ok := mc.mutation.Payload(); ok { + _spec.SetField(metric.FieldPayload, field.TypeString, value) + 
_node.Payload = value + } + return _node, _spec +} + +// MetricCreateBulk is the builder for creating many Metric entities in bulk. +type MetricCreateBulk struct { + config + err error + builders []*MetricCreate +} + +// Save creates the Metric entities in the database. +func (mcb *MetricCreateBulk) Save(ctx context.Context) ([]*Metric, error) { + if mcb.err != nil { + return nil, mcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(mcb.builders)) + nodes := make([]*Metric, len(mcb.builders)) + mutators := make([]Mutator, len(mcb.builders)) + for i := range mcb.builders { + func(i int, root context.Context) { + builder := mcb.builders[i] + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*MetricMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, mcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, mcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, mcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (mcb *MetricCreateBulk) SaveX(ctx context.Context) []*Metric { + v, err := mcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (mcb *MetricCreateBulk) Exec(ctx context.Context) error { + _, err := mcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mcb *MetricCreateBulk) ExecX(ctx context.Context) { + if err := mcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/metric_delete.go b/pkg/database/ent/metric_delete.go new file mode 100644 index 00000000000..d6606680a6a --- /dev/null +++ b/pkg/database/ent/metric_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// MetricDelete is the builder for deleting a Metric entity. +type MetricDelete struct { + config + hooks []Hook + mutation *MetricMutation +} + +// Where appends a list predicates to the MetricDelete builder. +func (md *MetricDelete) Where(ps ...predicate.Metric) *MetricDelete { + md.mutation.Where(ps...) + return md +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (md *MetricDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, md.sqlExec, md.mutation, md.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
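+// An illustrative bulk delete of metrics that were already pushed
+// (a sketch, not generated code; assumes "client" is an initialized
+// *ent.Client):
+//
+//	n, err := client.Metric.Delete().
+//		Where(metric.PushedAtNotNil()).
+//		Exec(ctx)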
+func (md *MetricDelete) ExecX(ctx context.Context) int { + n, err := md.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (md *MetricDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(metric.Table, sqlgraph.NewFieldSpec(metric.FieldID, field.TypeInt)) + if ps := md.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, md.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + md.mutation.done = true + return affected, err +} + +// MetricDeleteOne is the builder for deleting a single Metric entity. +type MetricDeleteOne struct { + md *MetricDelete +} + +// Where appends a list predicates to the MetricDelete builder. +func (mdo *MetricDeleteOne) Where(ps ...predicate.Metric) *MetricDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + +// Exec executes the deletion query. +func (mdo *MetricDeleteOne) Exec(ctx context.Context) error { + n, err := mdo.md.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{metric.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (mdo *MetricDeleteOne) ExecX(ctx context.Context) { + if err := mdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/metric_query.go b/pkg/database/ent/metric_query.go new file mode 100644 index 00000000000..6e1c6f08b4a --- /dev/null +++ b/pkg/database/ent/metric_query.go @@ -0,0 +1,526 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// MetricQuery is the builder for querying Metric entities. +type MetricQuery struct { + config + ctx *QueryContext + order []metric.OrderOption + inters []Interceptor + predicates []predicate.Metric + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the MetricQuery builder. +func (mq *MetricQuery) Where(ps ...predicate.Metric) *MetricQuery { + mq.predicates = append(mq.predicates, ps...) + return mq +} + +// Limit the number of records to be returned by this query. +func (mq *MetricQuery) Limit(limit int) *MetricQuery { + mq.ctx.Limit = &limit + return mq +} + +// Offset to start from. +func (mq *MetricQuery) Offset(offset int) *MetricQuery { + mq.ctx.Offset = &offset + return mq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (mq *MetricQuery) Unique(unique bool) *MetricQuery { + mq.ctx.Unique = &unique + return mq +} + +// Order specifies how the records should be ordered. +func (mq *MetricQuery) Order(o ...metric.OrderOption) *MetricQuery { + mq.order = append(mq.order, o...) + return mq +} + +// First returns the first Metric entity from the query. +// Returns a *NotFoundError when no Metric was found. 
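+// In calling code this is typically paired with the generated IsNotFound
+// helper (an illustrative sketch):
+//
+//	m, err := client.Metric.Query().First(ctx)
+//	if ent.IsNotFound(err) {
+//		// no metrics stored yet
+//	}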
+func (mq *MetricQuery) First(ctx context.Context) (*Metric, error) { + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{metric.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (mq *MetricQuery) FirstX(ctx context.Context) *Metric { + node, err := mq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Metric ID from the query. +// Returns a *NotFoundError when no Metric ID was found. +func (mq *MetricQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{metric.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (mq *MetricQuery) FirstIDX(ctx context.Context) int { + id, err := mq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Metric entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Metric entity is found. +// Returns a *NotFoundError when no Metric entities are found. +func (mq *MetricQuery) Only(ctx context.Context) (*Metric, error) { + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{metric.Label} + default: + return nil, &NotSingularError{metric.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (mq *MetricQuery) OnlyX(ctx context.Context) *Metric { + node, err := mq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Metric ID in the query. +// Returns a *NotSingularError when more than one Metric ID is found. +// Returns a *NotFoundError when no entities are found. +func (mq *MetricQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{metric.Label} + default: + err = &NotSingularError{metric.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (mq *MetricQuery) OnlyIDX(ctx context.Context) int { + id, err := mq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Metrics. +func (mq *MetricQuery) All(ctx context.Context) ([]*Metric, error) { + ctx = setContextOp(ctx, mq.ctx, "All") + if err := mq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Metric, *MetricQuery]() + return withInterceptors[[]*Metric](ctx, mq, qr, mq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (mq *MetricQuery) AllX(ctx context.Context) []*Metric { + nodes, err := mq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Metric IDs. 
+func (mq *MetricQuery) IDs(ctx context.Context) (ids []int, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(metric.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (mq *MetricQuery) IDsX(ctx context.Context) []int { + ids, err := mq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (mq *MetricQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") + if err := mq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, mq, querierCount[*MetricQuery](), mq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (mq *MetricQuery) CountX(ctx context.Context) int { + count, err := mq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (mq *MetricQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (mq *MetricQuery) ExistX(ctx context.Context) bool { + exist, err := mq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the MetricQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (mq *MetricQuery) Clone() *MetricQuery { + if mq == nil { + return nil + } + return &MetricQuery{ + config: mq.config, + ctx: mq.ctx.Clone(), + order: append([]metric.OrderOption{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), + predicates: append([]predicate.Metric{}, mq.predicates...), + // clone intermediate query. + sql: mq.sql.Clone(), + path: mq.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// GeneratedType metric.GeneratedType `json:"generated_type,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Metric.Query(). +// GroupBy(metric.FieldGeneratedType). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (mq *MetricQuery) GroupBy(field string, fields ...string) *MetricGroupBy { + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MetricGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields + grbuild.label = metric.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// GeneratedType metric.GeneratedType `json:"generated_type,omitempty"` +// } +// +// client.Metric.Query(). +// Select(metric.FieldGeneratedType). +// Scan(ctx, &v) +func (mq *MetricQuery) Select(fields ...string) *MetricSelect { + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MetricSelect{MetricQuery: mq} + sbuild.label = metric.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MetricSelect configured with the given aggregations. 
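+// For example, counting rows via the package-level aggregation helpers
+// (a sketch; ent.Count is part of the generated ent package):
+//
+//	total, err := client.Metric.Query().
+//		Aggregate(ent.Count()).
+//		Int(ctx)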
+func (mq *MetricQuery) Aggregate(fns ...AggregateFunc) *MetricSelect { + return mq.Select().Aggregate(fns...) +} + +func (mq *MetricQuery) prepareQuery(ctx context.Context) error { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { + if !metric.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if mq.path != nil { + prev, err := mq.path(ctx) + if err != nil { + return err + } + mq.sql = prev + } + return nil +} + +func (mq *MetricQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Metric, error) { + var ( + nodes = []*Metric{} + _spec = mq.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Metric).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Metric{config: mq.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, mq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (mq *MetricQuery) sqlCount(ctx context.Context) (int, error) { + _spec := mq.querySpec() + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, mq.driver, _spec) +} + +func (mq *MetricQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(metric.Table, metric.Columns, sqlgraph.NewFieldSpec(metric.FieldID, field.TypeInt)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true + } + if fields := mq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, metric.FieldID) + for i := range fields { + if fields[i] != metric.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := mq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := mq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := mq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := mq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (mq *MetricQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(mq.driver.Dialect()) + t1 := builder.Table(metric.Table) + columns := mq.ctx.Fields + if len(columns) == 0 { + columns = metric.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if mq.sql != nil { + selector = mq.sql + selector.Select(selector.Columns(columns...)...) + } + if mq.ctx.Unique != nil && *mq.ctx.Unique { + selector.Distinct() + } + for _, p := range mq.predicates { + p(selector) + } + for _, p := range mq.order { + p(selector) + } + if offset := mq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. 
+ selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := mq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// MetricGroupBy is the group-by builder for Metric entities. +type MetricGroupBy struct { + selector + build *MetricQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (mgb *MetricGroupBy) Aggregate(fns ...AggregateFunc) *MetricGroupBy { + mgb.fns = append(mgb.fns, fns...) + return mgb +} + +// Scan applies the selector query and scans the result into the given value. +func (mgb *MetricGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*MetricQuery, *MetricGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) +} + +func (mgb *MetricGroupBy) sqlScan(ctx context.Context, root *MetricQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*mgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// MetricSelect is the builder for selecting fields of Metric entities. +type MetricSelect struct { + *MetricQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MetricSelect) Aggregate(fns ...AggregateFunc) *MetricSelect { + ms.fns = append(ms.fns, fns...) + return ms +} + +// Scan applies the selector query and scans the result into the given value. +func (ms *MetricSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") + if err := ms.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*MetricQuery, *MetricSelect](ctx, ms.MetricQuery, ms, ms.inters, v) +} + +func (ms *MetricSelect) sqlScan(ctx context.Context, root *MetricQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ms.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/pkg/database/ent/metric_update.go b/pkg/database/ent/metric_update.go new file mode 100644 index 00000000000..4da33dd6ce9 --- /dev/null +++ b/pkg/database/ent/metric_update.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// MetricUpdate is the builder for updating Metric entities. +type MetricUpdate struct { + config + hooks []Hook + mutation *MetricMutation +} + +// Where appends a list predicates to the MetricUpdate builder. +func (mu *MetricUpdate) Where(ps ...predicate.Metric) *MetricUpdate { + mu.mutation.Where(ps...) + return mu +} + +// SetPushedAt sets the "pushed_at" field. +func (mu *MetricUpdate) SetPushedAt(t time.Time) *MetricUpdate { + mu.mutation.SetPushedAt(t) + return mu +} + +// SetNillablePushedAt sets the "pushed_at" field if the given value is not nil. +func (mu *MetricUpdate) SetNillablePushedAt(t *time.Time) *MetricUpdate { + if t != nil { + mu.SetPushedAt(*t) + } + return mu +} + +// ClearPushedAt clears the value of the "pushed_at" field. +func (mu *MetricUpdate) ClearPushedAt() *MetricUpdate { + mu.mutation.ClearPushedAt() + return mu +} + +// Mutation returns the MetricMutation object of the builder. +func (mu *MetricUpdate) Mutation() *MetricMutation { + return mu.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (mu *MetricUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, mu.sqlSave, mu.mutation, mu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (mu *MetricUpdate) SaveX(ctx context.Context) int { + affected, err := mu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (mu *MetricUpdate) Exec(ctx context.Context) error { + _, err := mu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mu *MetricUpdate) ExecX(ctx context.Context) { + if err := mu.Exec(ctx); err != nil { + panic(err) + } +} + +func (mu *MetricUpdate) sqlSave(ctx context.Context) (n int, err error) { + _spec := sqlgraph.NewUpdateSpec(metric.Table, metric.Columns, sqlgraph.NewFieldSpec(metric.FieldID, field.TypeInt)) + if ps := mu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := mu.mutation.PushedAt(); ok { + _spec.SetField(metric.FieldPushedAt, field.TypeTime, value) + } + if mu.mutation.PushedAtCleared() { + _spec.ClearField(metric.FieldPushedAt, field.TypeTime) + } + if n, err = sqlgraph.UpdateNodes(ctx, mu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{metric.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + mu.mutation.done = true + return n, nil +} + +// MetricUpdateOne is the builder for updating a single Metric entity. +type MetricUpdateOne struct { + config + fields []string + hooks []Hook + mutation *MetricMutation +} + +// SetPushedAt sets the "pushed_at" field. +func (muo *MetricUpdateOne) SetPushedAt(t time.Time) *MetricUpdateOne { + muo.mutation.SetPushedAt(t) + return muo +} + +// SetNillablePushedAt sets the "pushed_at" field if the given value is not nil. +func (muo *MetricUpdateOne) SetNillablePushedAt(t *time.Time) *MetricUpdateOne { + if t != nil { + muo.SetPushedAt(*t) + } + return muo +} + +// ClearPushedAt clears the value of the "pushed_at" field. 
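+// An illustrative bulk update marking unpushed metrics as pushed
+// (a sketch, not generated code; assumes "client" is an initialized
+// *ent.Client):
+//
+//	n, err := client.Metric.Update().
+//		Where(metric.PushedAtIsNil()).
+//		SetPushedAt(time.Now()).
+//		Save(ctx)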
+func (muo *MetricUpdateOne) ClearPushedAt() *MetricUpdateOne { + muo.mutation.ClearPushedAt() + return muo +} + +// Mutation returns the MetricMutation object of the builder. +func (muo *MetricUpdateOne) Mutation() *MetricMutation { + return muo.mutation +} + +// Where appends a list predicates to the MetricUpdate builder. +func (muo *MetricUpdateOne) Where(ps ...predicate.Metric) *MetricUpdateOne { + muo.mutation.Where(ps...) + return muo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (muo *MetricUpdateOne) Select(field string, fields ...string) *MetricUpdateOne { + muo.fields = append([]string{field}, fields...) + return muo +} + +// Save executes the query and returns the updated Metric entity. +func (muo *MetricUpdateOne) Save(ctx context.Context) (*Metric, error) { + return withHooks(ctx, muo.sqlSave, muo.mutation, muo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (muo *MetricUpdateOne) SaveX(ctx context.Context) *Metric { + node, err := muo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (muo *MetricUpdateOne) Exec(ctx context.Context) error { + _, err := muo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (muo *MetricUpdateOne) ExecX(ctx context.Context) { + if err := muo.Exec(ctx); err != nil { + panic(err) + } +} + +func (muo *MetricUpdateOne) sqlSave(ctx context.Context) (_node *Metric, err error) { + _spec := sqlgraph.NewUpdateSpec(metric.Table, metric.Columns, sqlgraph.NewFieldSpec(metric.FieldID, field.TypeInt)) + id, ok := muo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Metric.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := muo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, metric.FieldID) + for _, f := range fields { + if !metric.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != metric.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := muo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := muo.mutation.PushedAt(); ok { + _spec.SetField(metric.FieldPushedAt, field.TypeTime, value) + } + if muo.mutation.PushedAtCleared() { + _spec.ClearField(metric.FieldPushedAt, field.TypeTime) + } + _node = &Metric{config: muo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, muo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{metric.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + muo.mutation.done = true + return _node, nil +} diff --git a/pkg/database/ent/migrate/schema.go b/pkg/database/ent/migrate/schema.go index 375fd4e784a..986f5bc8c67 100644 --- a/pkg/database/ent/migrate/schema.go +++ b/pkg/database/ent/migrate/schema.go @@ -11,8 +11,8 @@ var ( // AlertsColumns holds the columns for the "alerts" table. 
AlertsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "scenario", Type: field.TypeString}, {Name: "bucket_id", Type: field.TypeString, Nullable: true, Default: ""}, {Name: "message", Type: field.TypeString, Nullable: true, Default: ""}, @@ -34,6 +34,7 @@ var ( {Name: "scenario_hash", Type: field.TypeString, Nullable: true}, {Name: "simulated", Type: field.TypeBool, Default: false}, {Name: "uuid", Type: field.TypeString, Nullable: true}, + {Name: "remediation", Type: field.TypeBool, Nullable: true}, {Name: "machine_alerts", Type: field.TypeInt, Nullable: true}, } // AlertsTable holds the schema information for the "alerts" table. @@ -44,7 +45,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "alerts_machines_alerts", - Columns: []*schema.Column{AlertsColumns[24]}, + Columns: []*schema.Column{AlertsColumns[25]}, RefColumns: []*schema.Column{MachinesColumns[0]}, OnDelete: schema.SetNull, }, @@ -60,17 +61,19 @@ var ( // BouncersColumns holds the columns for the "bouncers" table. BouncersColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "name", Type: field.TypeString, Unique: true}, {Name: "api_key", Type: field.TypeString}, {Name: "revoked", Type: field.TypeBool}, {Name: "ip_address", Type: field.TypeString, Nullable: true, Default: ""}, {Name: "type", Type: field.TypeString, Nullable: true}, {Name: "version", Type: field.TypeString, Nullable: true}, - {Name: "until", Type: field.TypeTime, Nullable: true}, - {Name: "last_pull", Type: field.TypeTime}, + {Name: "last_pull", Type: field.TypeTime, Nullable: true}, {Name: "auth_type", Type: field.TypeString, Default: "api-key"}, + {Name: "osname", Type: field.TypeString, Nullable: true}, + {Name: "osversion", Type: field.TypeString, Nullable: true}, + {Name: "featureflags", Type: field.TypeString, Nullable: true}, } // BouncersTable holds the schema information for the "bouncers" table. BouncersTable = &schema.Table{ @@ -81,8 +84,8 @@ var ( // ConfigItemsColumns holds the columns for the "config_items" table. ConfigItemsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "name", Type: field.TypeString, Unique: true}, {Name: "value", Type: field.TypeString}, } @@ -95,8 +98,8 @@ var ( // DecisionsColumns holds the columns for the "decisions" table. 
DecisionsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, {Name: "scenario", Type: field.TypeString}, {Name: "type", Type: field.TypeString}, @@ -151,8 +154,8 @@ var ( // EventsColumns holds the columns for the "events" table. EventsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "time", Type: field.TypeTime}, {Name: "serialized", Type: field.TypeString, Size: 8191}, {Name: "alert_events", Type: field.TypeInt, Nullable: true}, @@ -178,11 +181,23 @@ var ( }, }, } + // LocksColumns holds the columns for the "locks" table. + LocksColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "name", Type: field.TypeString, Unique: true}, + {Name: "created_at", Type: field.TypeTime}, + } + // LocksTable holds the schema information for the "locks" table. + LocksTable = &schema.Table{ + Name: "locks", + Columns: LocksColumns, + PrimaryKey: []*schema.Column{LocksColumns[0]}, + } // MachinesColumns holds the columns for the "machines" table. MachinesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "last_push", Type: field.TypeTime, Nullable: true}, {Name: "last_heartbeat", Type: field.TypeTime, Nullable: true}, {Name: "machine_id", Type: field.TypeString, Unique: true}, @@ -191,8 +206,12 @@ var ( {Name: "scenarios", Type: field.TypeString, Nullable: true, Size: 100000}, {Name: "version", Type: field.TypeString, Nullable: true}, {Name: "is_validated", Type: field.TypeBool, Default: false}, - {Name: "status", Type: field.TypeString, Nullable: true}, {Name: "auth_type", Type: field.TypeString, Default: "password"}, + {Name: "osname", Type: field.TypeString, Nullable: true}, + {Name: "osversion", Type: field.TypeString, Nullable: true}, + {Name: "featureflags", Type: field.TypeString, Nullable: true}, + {Name: "hubstate", Type: field.TypeJSON, Nullable: true}, + {Name: "datasources", Type: field.TypeJSON, Nullable: true}, } // MachinesTable holds the schema information for the "machines" table. MachinesTable = &schema.Table{ @@ -203,8 +222,8 @@ var ( // MetaColumns holds the columns for the "meta" table. MetaColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "key", Type: field.TypeString}, {Name: "value", Type: field.TypeString, Size: 4095}, {Name: "alert_metas", Type: field.TypeInt, Nullable: true}, @@ -230,6 +249,21 @@ var ( }, }, } + // MetricsColumns holds the columns for the "metrics" table. 
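+	// For reference, roughly the sqlite DDL this column set implies
+	// (an illustrative sketch only; the actual statements are emitted
+	// by the ent migration engine):
+	//
+	//	CREATE TABLE `metrics` (
+	//		`id` integer PRIMARY KEY AUTOINCREMENT,
+	//		`generated_type` text NOT NULL,
+	//		`generated_by` text NOT NULL,
+	//		`received_at` datetime NOT NULL,
+	//		`pushed_at` datetime NULL,
+	//		`payload` text NOT NULL
+	//	);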
+ MetricsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "generated_type", Type: field.TypeEnum, Enums: []string{"LP", "RC"}}, + {Name: "generated_by", Type: field.TypeString}, + {Name: "received_at", Type: field.TypeTime}, + {Name: "pushed_at", Type: field.TypeTime, Nullable: true}, + {Name: "payload", Type: field.TypeString, Size: 2147483647}, + } + // MetricsTable holds the schema information for the "metrics" table. + MetricsTable = &schema.Table{ + Name: "metrics", + Columns: MetricsColumns, + PrimaryKey: []*schema.Column{MetricsColumns[0]}, + } // Tables holds all the tables in the schema. Tables = []*schema.Table{ AlertsTable, @@ -237,8 +271,10 @@ var ( ConfigItemsTable, DecisionsTable, EventsTable, + LocksTable, MachinesTable, MetaTable, + MetricsTable, } ) diff --git a/pkg/database/ent/mutation.go b/pkg/database/ent/mutation.go index 907c1ef015e..5c6596f3db4 100644 --- a/pkg/database/ent/mutation.go +++ b/pkg/database/ent/mutation.go @@ -9,16 +9,19 @@ import ( "sync" "time" + "entgo.io/ent" + "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" - - "entgo.io/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" ) const ( @@ -35,8 +38,10 @@ const ( TypeConfigItem = "ConfigItem" TypeDecision = "Decision" TypeEvent = "Event" + TypeLock = "Lock" TypeMachine = "Machine" TypeMeta = "Meta" + TypeMetric = "Metric" ) // AlertMutation represents an operation that mutates the Alert nodes in the graph. @@ -72,6 +77,7 @@ type AlertMutation struct { scenarioHash *string simulated *bool uuid *string + remediation *bool clearedFields map[string]struct{} owner *int clearedowner bool @@ -204,7 +210,7 @@ func (m *AlertMutation) CreatedAt() (r time.Time, exists bool) { // OldCreatedAt returns the old "created_at" field's value of the Alert entity. // If the Alert object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *AlertMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *AlertMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -218,22 +224,9 @@ func (m *AlertMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err err return oldValue.CreatedAt, nil } -// ClearCreatedAt clears the value of the "created_at" field. -func (m *AlertMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[alert.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *AlertMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[alert.FieldCreatedAt] - return ok -} - // ResetCreatedAt resets all changes to the "created_at" field. 
func (m *AlertMutation) ResetCreatedAt() { m.created_at = nil - delete(m.clearedFields, alert.FieldCreatedAt) } // SetUpdatedAt sets the "updated_at" field. @@ -253,7 +246,7 @@ func (m *AlertMutation) UpdatedAt() (r time.Time, exists bool) { // OldUpdatedAt returns the old "updated_at" field's value of the Alert entity. // If the Alert object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *AlertMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *AlertMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -267,22 +260,9 @@ func (m *AlertMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err err return oldValue.UpdatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *AlertMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[alert.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *AlertMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[alert.FieldUpdatedAt] - return ok -} - // ResetUpdatedAt resets all changes to the "updated_at" field. func (m *AlertMutation) ResetUpdatedAt() { m.updated_at = nil - delete(m.clearedFields, alert.FieldUpdatedAt) } // SetScenario sets the "scenario" field. @@ -1372,6 +1352,55 @@ func (m *AlertMutation) ResetUUID() { delete(m.clearedFields, alert.FieldUUID) } +// SetRemediation sets the "remediation" field. +func (m *AlertMutation) SetRemediation(b bool) { + m.remediation = &b +} + +// Remediation returns the value of the "remediation" field in the mutation. +func (m *AlertMutation) Remediation() (r bool, exists bool) { + v := m.remediation + if v == nil { + return + } + return *v, true +} + +// OldRemediation returns the old "remediation" field's value of the Alert entity. +// If the Alert object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AlertMutation) OldRemediation(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRemediation is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRemediation requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRemediation: %w", err) + } + return oldValue.Remediation, nil +} + +// ClearRemediation clears the value of the "remediation" field. +func (m *AlertMutation) ClearRemediation() { + m.remediation = nil + m.clearedFields[alert.FieldRemediation] = struct{}{} +} + +// RemediationCleared returns if the "remediation" field was cleared in this mutation. +func (m *AlertMutation) RemediationCleared() bool { + _, ok := m.clearedFields[alert.FieldRemediation] + return ok +} + +// ResetRemediation resets all changes to the "remediation" field. +func (m *AlertMutation) ResetRemediation() { + m.remediation = nil + delete(m.clearedFields, alert.FieldRemediation) +} + // SetOwnerID sets the "owner" edge to the Machine entity by id. 
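+// An illustrative use of the new field through the regenerated update
+// builders (a sketch, assuming AlertUpdateOne gained SetRemediation as
+// part of this change; that file is not shown in this hunk):
+//
+//	err := client.Alert.UpdateOneID(id).
+//		SetRemediation(true).
+//		Exec(ctx)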
func (m *AlertMutation) SetOwnerID(id int) { m.owner = &id @@ -1578,11 +1607,26 @@ func (m *AlertMutation) Where(ps ...predicate.Alert) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the AlertMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AlertMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Alert, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *AlertMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *AlertMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Alert). func (m *AlertMutation) Type() string { return m.typ @@ -1592,7 +1636,7 @@ func (m *AlertMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AlertMutation) Fields() []string { - fields := make([]string, 0, 23) + fields := make([]string, 0, 24) if m.created_at != nil { fields = append(fields, alert.FieldCreatedAt) } @@ -1662,6 +1706,9 @@ func (m *AlertMutation) Fields() []string { if m.uuid != nil { fields = append(fields, alert.FieldUUID) } + if m.remediation != nil { + fields = append(fields, alert.FieldRemediation) + } return fields } @@ -1716,6 +1763,8 @@ func (m *AlertMutation) Field(name string) (ent.Value, bool) { return m.Simulated() case alert.FieldUUID: return m.UUID() + case alert.FieldRemediation: + return m.Remediation() } return nil, false } @@ -1771,6 +1820,8 @@ func (m *AlertMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldSimulated(ctx) case alert.FieldUUID: return m.OldUUID(ctx) + case alert.FieldRemediation: + return m.OldRemediation(ctx) } return nil, fmt.Errorf("unknown Alert field %s", name) } @@ -1941,6 +1992,13 @@ func (m *AlertMutation) SetField(name string, value ent.Value) error { } m.SetUUID(v) return nil + case alert.FieldRemediation: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRemediation(v) + return nil } return fmt.Errorf("unknown Alert field %s", name) } @@ -2022,12 +2080,6 @@ func (m *AlertMutation) AddField(name string, value ent.Value) error { // mutation. func (m *AlertMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(alert.FieldCreatedAt) { - fields = append(fields, alert.FieldCreatedAt) - } - if m.FieldCleared(alert.FieldUpdatedAt) { - fields = append(fields, alert.FieldUpdatedAt) - } if m.FieldCleared(alert.FieldBucketId) { fields = append(fields, alert.FieldBucketId) } @@ -2085,6 +2137,9 @@ func (m *AlertMutation) ClearedFields() []string { if m.FieldCleared(alert.FieldUUID) { fields = append(fields, alert.FieldUUID) } + if m.FieldCleared(alert.FieldRemediation) { + fields = append(fields, alert.FieldRemediation) + } return fields } @@ -2099,12 +2154,6 @@ func (m *AlertMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. 
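// NOTE (editor): WhereP, added above, is how ent applies untyped,
// storage-level predicates (e.g. from interceptors or entql filters) to a
// typed mutation; each func(*sql.Selector) is wrapped into a
// predicate.Alert before being appended, so it composes with the generated
// predicates. A sketch, assuming "m" is an *AlertMutation obtained inside a
// hook (hypothetical):
//
//	m.WhereP(func(s *sql.Selector) {
//		s.Where(sql.EQ(s.C(alert.FieldSimulated), false))
//	})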
func (m *AlertMutation) ClearField(name string) error { switch name { - case alert.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case alert.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil case alert.FieldBucketId: m.ClearBucketId() return nil @@ -2162,6 +2211,9 @@ func (m *AlertMutation) ClearField(name string) error { case alert.FieldUUID: m.ClearUUID() return nil + case alert.FieldRemediation: + m.ClearRemediation() + return nil } return fmt.Errorf("unknown Alert nullable field %s", name) } @@ -2239,6 +2291,9 @@ func (m *AlertMutation) ResetField(name string) error { case alert.FieldUUID: m.ResetUUID() return nil + case alert.FieldRemediation: + m.ResetRemediation() + return nil } return fmt.Errorf("unknown Alert field %s", name) } @@ -2411,9 +2466,11 @@ type BouncerMutation struct { ip_address *string _type *string version *string - until *time.Time last_pull *time.Time auth_type *string + osname *string + osversion *string + featureflags *string clearedFields map[string]struct{} done bool oldValue func(context.Context) (*Bouncer, error) @@ -2535,7 +2592,7 @@ func (m *BouncerMutation) CreatedAt() (r time.Time, exists bool) { // OldCreatedAt returns the old "created_at" field's value of the Bouncer entity. // If the Bouncer object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *BouncerMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *BouncerMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -2549,22 +2606,9 @@ func (m *BouncerMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err e return oldValue.CreatedAt, nil } -// ClearCreatedAt clears the value of the "created_at" field. -func (m *BouncerMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[bouncer.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *BouncerMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[bouncer.FieldCreatedAt] - return ok -} - // ResetCreatedAt resets all changes to the "created_at" field. func (m *BouncerMutation) ResetCreatedAt() { m.created_at = nil - delete(m.clearedFields, bouncer.FieldCreatedAt) } // SetUpdatedAt sets the "updated_at" field. @@ -2584,7 +2628,7 @@ func (m *BouncerMutation) UpdatedAt() (r time.Time, exists bool) { // OldUpdatedAt returns the old "updated_at" field's value of the Bouncer entity. // If the Bouncer object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *BouncerMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *BouncerMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -2598,22 +2642,9 @@ func (m *BouncerMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err e return oldValue.UpdatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *BouncerMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[bouncer.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. 
-func (m *BouncerMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[bouncer.FieldUpdatedAt] - return ok -} - // ResetUpdatedAt resets all changes to the "updated_at" field. func (m *BouncerMutation) ResetUpdatedAt() { m.updated_at = nil - delete(m.clearedFields, bouncer.FieldUpdatedAt) } // SetName sets the "name" field. @@ -2871,55 +2902,6 @@ func (m *BouncerMutation) ResetVersion() { delete(m.clearedFields, bouncer.FieldVersion) } -// SetUntil sets the "until" field. -func (m *BouncerMutation) SetUntil(t time.Time) { - m.until = &t -} - -// Until returns the value of the "until" field in the mutation. -func (m *BouncerMutation) Until() (r time.Time, exists bool) { - v := m.until - if v == nil { - return - } - return *v, true -} - -// OldUntil returns the old "until" field's value of the Bouncer entity. -// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *BouncerMutation) OldUntil(ctx context.Context) (v time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUntil is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUntil requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldUntil: %w", err) - } - return oldValue.Until, nil -} - -// ClearUntil clears the value of the "until" field. -func (m *BouncerMutation) ClearUntil() { - m.until = nil - m.clearedFields[bouncer.FieldUntil] = struct{}{} -} - -// UntilCleared returns if the "until" field was cleared in this mutation. -func (m *BouncerMutation) UntilCleared() bool { - _, ok := m.clearedFields[bouncer.FieldUntil] - return ok -} - -// ResetUntil resets all changes to the "until" field. -func (m *BouncerMutation) ResetUntil() { - m.until = nil - delete(m.clearedFields, bouncer.FieldUntil) -} - // SetLastPull sets the "last_pull" field. func (m *BouncerMutation) SetLastPull(t time.Time) { m.last_pull = &t @@ -2937,7 +2919,7 @@ func (m *BouncerMutation) LastPull() (r time.Time, exists bool) { // OldLastPull returns the old "last_pull" field's value of the Bouncer entity. // If the Bouncer object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *BouncerMutation) OldLastPull(ctx context.Context) (v time.Time, err error) { +func (m *BouncerMutation) OldLastPull(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldLastPull is only allowed on UpdateOne operations") } @@ -2951,9 +2933,22 @@ func (m *BouncerMutation) OldLastPull(ctx context.Context) (v time.Time, err err return oldValue.LastPull, nil } +// ClearLastPull clears the value of the "last_pull" field. +func (m *BouncerMutation) ClearLastPull() { + m.last_pull = nil + m.clearedFields[bouncer.FieldLastPull] = struct{}{} +} + +// LastPullCleared returns if the "last_pull" field was cleared in this mutation. +func (m *BouncerMutation) LastPullCleared() bool { + _, ok := m.clearedFields[bouncer.FieldLastPull] + return ok +} + // ResetLastPull resets all changes to the "last_pull" field. func (m *BouncerMutation) ResetLastPull() { m.last_pull = nil + delete(m.clearedFields, bouncer.FieldLastPull) } // SetAuthType sets the "auth_type" field. 
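// NOTE (editor): "last_pull" is now optional and nillable (OldLastPull
// returns *time.Time and Clear/Cleared accessors were generated), while the
// "until" field was dropped from the Bouncer entity altogether. A sketch of
// resetting the pull marker, assuming a hypothetical client, ctx and
// bouncer ID:
//
//	err := client.Bouncer.UpdateOneID(bouncerID).
//		ClearLastPull(). // stores NULL instead of a zero time.Time
//		Exec(ctx)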
@@ -2992,16 +2987,178 @@ func (m *BouncerMutation) ResetAuthType() { m.auth_type = nil } +// SetOsname sets the "osname" field. +func (m *BouncerMutation) SetOsname(s string) { + m.osname = &s +} + +// Osname returns the value of the "osname" field in the mutation. +func (m *BouncerMutation) Osname() (r string, exists bool) { + v := m.osname + if v == nil { + return + } + return *v, true +} + +// OldOsname returns the old "osname" field's value of the Bouncer entity. +// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BouncerMutation) OldOsname(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOsname is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOsname requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOsname: %w", err) + } + return oldValue.Osname, nil +} + +// ClearOsname clears the value of the "osname" field. +func (m *BouncerMutation) ClearOsname() { + m.osname = nil + m.clearedFields[bouncer.FieldOsname] = struct{}{} +} + +// OsnameCleared returns if the "osname" field was cleared in this mutation. +func (m *BouncerMutation) OsnameCleared() bool { + _, ok := m.clearedFields[bouncer.FieldOsname] + return ok +} + +// ResetOsname resets all changes to the "osname" field. +func (m *BouncerMutation) ResetOsname() { + m.osname = nil + delete(m.clearedFields, bouncer.FieldOsname) +} + +// SetOsversion sets the "osversion" field. +func (m *BouncerMutation) SetOsversion(s string) { + m.osversion = &s +} + +// Osversion returns the value of the "osversion" field in the mutation. +func (m *BouncerMutation) Osversion() (r string, exists bool) { + v := m.osversion + if v == nil { + return + } + return *v, true +} + +// OldOsversion returns the old "osversion" field's value of the Bouncer entity. +// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BouncerMutation) OldOsversion(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOsversion is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOsversion requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOsversion: %w", err) + } + return oldValue.Osversion, nil +} + +// ClearOsversion clears the value of the "osversion" field. +func (m *BouncerMutation) ClearOsversion() { + m.osversion = nil + m.clearedFields[bouncer.FieldOsversion] = struct{}{} +} + +// OsversionCleared returns if the "osversion" field was cleared in this mutation. +func (m *BouncerMutation) OsversionCleared() bool { + _, ok := m.clearedFields[bouncer.FieldOsversion] + return ok +} + +// ResetOsversion resets all changes to the "osversion" field. +func (m *BouncerMutation) ResetOsversion() { + m.osversion = nil + delete(m.clearedFields, bouncer.FieldOsversion) +} + +// SetFeatureflags sets the "featureflags" field. 
+func (m *BouncerMutation) SetFeatureflags(s string) { + m.featureflags = &s +} + +// Featureflags returns the value of the "featureflags" field in the mutation. +func (m *BouncerMutation) Featureflags() (r string, exists bool) { + v := m.featureflags + if v == nil { + return + } + return *v, true +} + +// OldFeatureflags returns the old "featureflags" field's value of the Bouncer entity. +// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BouncerMutation) OldFeatureflags(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFeatureflags is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFeatureflags requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFeatureflags: %w", err) + } + return oldValue.Featureflags, nil +} + +// ClearFeatureflags clears the value of the "featureflags" field. +func (m *BouncerMutation) ClearFeatureflags() { + m.featureflags = nil + m.clearedFields[bouncer.FieldFeatureflags] = struct{}{} +} + +// FeatureflagsCleared returns if the "featureflags" field was cleared in this mutation. +func (m *BouncerMutation) FeatureflagsCleared() bool { + _, ok := m.clearedFields[bouncer.FieldFeatureflags] + return ok +} + +// ResetFeatureflags resets all changes to the "featureflags" field. +func (m *BouncerMutation) ResetFeatureflags() { + m.featureflags = nil + delete(m.clearedFields, bouncer.FieldFeatureflags) +} + // Where appends a list predicates to the BouncerMutation builder. func (m *BouncerMutation) Where(ps ...predicate.Bouncer) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the BouncerMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *BouncerMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Bouncer, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *BouncerMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *BouncerMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Bouncer). func (m *BouncerMutation) Type() string { return m.typ @@ -3011,7 +3168,7 @@ func (m *BouncerMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
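// NOTE (editor): osname, osversion and featureflags are plain optional
// strings, presumably holding OS metadata and capabilities reported by the
// bouncer. A sketch, assuming a hypothetical client, ctx and bouncer ID:
//
//	err := client.Bouncer.UpdateOneID(bouncerID).
//		SetOsname("Ubuntu").
//		SetOsversion("24.04").
//		SetFeatureflags("flag_a,flag_b"). // free-form string; format is up to the caller
//		Exec(ctx)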
func (m *BouncerMutation) Fields() []string { - fields := make([]string, 0, 11) + fields := make([]string, 0, 13) if m.created_at != nil { fields = append(fields, bouncer.FieldCreatedAt) } @@ -3036,15 +3193,21 @@ func (m *BouncerMutation) Fields() []string { if m.version != nil { fields = append(fields, bouncer.FieldVersion) } - if m.until != nil { - fields = append(fields, bouncer.FieldUntil) - } if m.last_pull != nil { fields = append(fields, bouncer.FieldLastPull) } if m.auth_type != nil { fields = append(fields, bouncer.FieldAuthType) } + if m.osname != nil { + fields = append(fields, bouncer.FieldOsname) + } + if m.osversion != nil { + fields = append(fields, bouncer.FieldOsversion) + } + if m.featureflags != nil { + fields = append(fields, bouncer.FieldFeatureflags) + } return fields } @@ -3069,12 +3232,16 @@ func (m *BouncerMutation) Field(name string) (ent.Value, bool) { return m.GetType() case bouncer.FieldVersion: return m.Version() - case bouncer.FieldUntil: - return m.Until() case bouncer.FieldLastPull: return m.LastPull() case bouncer.FieldAuthType: return m.AuthType() + case bouncer.FieldOsname: + return m.Osname() + case bouncer.FieldOsversion: + return m.Osversion() + case bouncer.FieldFeatureflags: + return m.Featureflags() } return nil, false } @@ -3100,12 +3267,16 @@ func (m *BouncerMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldType(ctx) case bouncer.FieldVersion: return m.OldVersion(ctx) - case bouncer.FieldUntil: - return m.OldUntil(ctx) case bouncer.FieldLastPull: return m.OldLastPull(ctx) case bouncer.FieldAuthType: return m.OldAuthType(ctx) + case bouncer.FieldOsname: + return m.OldOsname(ctx) + case bouncer.FieldOsversion: + return m.OldOsversion(ctx) + case bouncer.FieldFeatureflags: + return m.OldFeatureflags(ctx) } return nil, fmt.Errorf("unknown Bouncer field %s", name) } @@ -3171,13 +3342,6 @@ func (m *BouncerMutation) SetField(name string, value ent.Value) error { } m.SetVersion(v) return nil - case bouncer.FieldUntil: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUntil(v) - return nil case bouncer.FieldLastPull: v, ok := value.(time.Time) if !ok { @@ -3192,6 +3356,27 @@ func (m *BouncerMutation) SetField(name string, value ent.Value) error { } m.SetAuthType(v) return nil + case bouncer.FieldOsname: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOsname(v) + return nil + case bouncer.FieldOsversion: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOsversion(v) + return nil + case bouncer.FieldFeatureflags: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFeatureflags(v) + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } @@ -3222,12 +3407,6 @@ func (m *BouncerMutation) AddField(name string, value ent.Value) error { // mutation. 
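// NOTE (editor): Fields/Field/OldField/SetField form ent's generic,
// name-based mutation API (used by hooks and generic middleware); the
// switches above were regenerated to cover the 13 current Bouncer fields.
// A sketch, assuming "m" is a *BouncerMutation (hypothetical):
//
//	if err := m.SetField(bouncer.FieldOsname, "freebsd"); err != nil {
//		// wrong value type or unknown field name
//	}
//	v, ok := m.Field(bouncer.FieldOsname) // "freebsd", true
//	_, _ = v, ok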
func (m *BouncerMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(bouncer.FieldCreatedAt) { - fields = append(fields, bouncer.FieldCreatedAt) - } - if m.FieldCleared(bouncer.FieldUpdatedAt) { - fields = append(fields, bouncer.FieldUpdatedAt) - } if m.FieldCleared(bouncer.FieldIPAddress) { fields = append(fields, bouncer.FieldIPAddress) } @@ -3237,8 +3416,17 @@ func (m *BouncerMutation) ClearedFields() []string { if m.FieldCleared(bouncer.FieldVersion) { fields = append(fields, bouncer.FieldVersion) } - if m.FieldCleared(bouncer.FieldUntil) { - fields = append(fields, bouncer.FieldUntil) + if m.FieldCleared(bouncer.FieldLastPull) { + fields = append(fields, bouncer.FieldLastPull) + } + if m.FieldCleared(bouncer.FieldOsname) { + fields = append(fields, bouncer.FieldOsname) + } + if m.FieldCleared(bouncer.FieldOsversion) { + fields = append(fields, bouncer.FieldOsversion) + } + if m.FieldCleared(bouncer.FieldFeatureflags) { + fields = append(fields, bouncer.FieldFeatureflags) } return fields } @@ -3254,12 +3442,6 @@ func (m *BouncerMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *BouncerMutation) ClearField(name string) error { switch name { - case bouncer.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case bouncer.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil case bouncer.FieldIPAddress: m.ClearIPAddress() return nil @@ -3269,8 +3451,17 @@ func (m *BouncerMutation) ClearField(name string) error { case bouncer.FieldVersion: m.ClearVersion() return nil - case bouncer.FieldUntil: - m.ClearUntil() + case bouncer.FieldLastPull: + m.ClearLastPull() + return nil + case bouncer.FieldOsname: + m.ClearOsname() + return nil + case bouncer.FieldOsversion: + m.ClearOsversion() + return nil + case bouncer.FieldFeatureflags: + m.ClearFeatureflags() return nil } return fmt.Errorf("unknown Bouncer nullable field %s", name) @@ -3304,15 +3495,21 @@ func (m *BouncerMutation) ResetField(name string) error { case bouncer.FieldVersion: m.ResetVersion() return nil - case bouncer.FieldUntil: - m.ResetUntil() - return nil case bouncer.FieldLastPull: m.ResetLastPull() return nil case bouncer.FieldAuthType: m.ResetAuthType() return nil + case bouncer.FieldOsname: + m.ResetOsname() + return nil + case bouncer.FieldOsversion: + m.ResetOsversion() + return nil + case bouncer.FieldFeatureflags: + m.ResetFeatureflags() + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } @@ -3496,7 +3693,7 @@ func (m *ConfigItemMutation) CreatedAt() (r time.Time, exists bool) { // OldCreatedAt returns the old "created_at" field's value of the ConfigItem entity. // If the ConfigItem object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ConfigItemMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *ConfigItemMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -3510,22 +3707,9 @@ func (m *ConfigItemMutation) OldCreatedAt(ctx context.Context) (v *time.Time, er return oldValue.CreatedAt, nil } -// ClearCreatedAt clears the value of the "created_at" field. -func (m *ConfigItemMutation) ClearCreatedAt() { +// ResetCreatedAt resets all changes to the "created_at" field. 
+func (m *ConfigItemMutation) ResetCreatedAt() { m.created_at = nil - m.clearedFields[configitem.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *ConfigItemMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[configitem.FieldCreatedAt] - return ok -} - -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *ConfigItemMutation) ResetCreatedAt() { - m.created_at = nil - delete(m.clearedFields, configitem.FieldCreatedAt) } // SetUpdatedAt sets the "updated_at" field. @@ -3545,7 +3729,7 @@ func (m *ConfigItemMutation) UpdatedAt() (r time.Time, exists bool) { // OldUpdatedAt returns the old "updated_at" field's value of the ConfigItem entity. // If the ConfigItem object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ConfigItemMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *ConfigItemMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -3559,22 +3743,9 @@ func (m *ConfigItemMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, er return oldValue.UpdatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *ConfigItemMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[configitem.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *ConfigItemMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[configitem.FieldUpdatedAt] - return ok -} - // ResetUpdatedAt resets all changes to the "updated_at" field. func (m *ConfigItemMutation) ResetUpdatedAt() { m.updated_at = nil - delete(m.clearedFields, configitem.FieldUpdatedAt) } // SetName sets the "name" field. @@ -3654,11 +3825,26 @@ func (m *ConfigItemMutation) Where(ps ...predicate.ConfigItem) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the ConfigItemMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *ConfigItemMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ConfigItem, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *ConfigItemMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *ConfigItemMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (ConfigItem). func (m *ConfigItemMutation) Type() string { return m.typ @@ -3780,14 +3966,7 @@ func (m *ConfigItemMutation) AddField(name string, value ent.Value) error { // ClearedFields returns all nullable fields that were cleared during this // mutation. 
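// NOTE (editor): created_at and updated_at are no longer nullable on
// ConfigItem: the Clear*/Cleared* accessors were removed and the Old*
// getters now return time.Time instead of *time.Time. Clearing them through
// the generic API is therefore rejected:
//
//	err := m.ClearField(configitem.FieldCreatedAt)
//	// err: "unknown ConfigItem nullable field created_at"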
func (m *ConfigItemMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(configitem.FieldCreatedAt) { - fields = append(fields, configitem.FieldCreatedAt) - } - if m.FieldCleared(configitem.FieldUpdatedAt) { - fields = append(fields, configitem.FieldUpdatedAt) - } - return fields + return nil } // FieldCleared returns a boolean indicating if a field with the given name was @@ -3800,14 +3979,6 @@ func (m *ConfigItemMutation) FieldCleared(name string) bool { // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. func (m *ConfigItemMutation) ClearField(name string) error { - switch name { - case configitem.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case configitem.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil - } return fmt.Errorf("unknown ConfigItem nullable field %s", name) } @@ -4028,7 +4199,7 @@ func (m *DecisionMutation) CreatedAt() (r time.Time, exists bool) { // OldCreatedAt returns the old "created_at" field's value of the Decision entity. // If the Decision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *DecisionMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *DecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -4042,22 +4213,9 @@ func (m *DecisionMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err return oldValue.CreatedAt, nil } -// ClearCreatedAt clears the value of the "created_at" field. -func (m *DecisionMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[decision.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *DecisionMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[decision.FieldCreatedAt] - return ok -} - // ResetCreatedAt resets all changes to the "created_at" field. func (m *DecisionMutation) ResetCreatedAt() { m.created_at = nil - delete(m.clearedFields, decision.FieldCreatedAt) } // SetUpdatedAt sets the "updated_at" field. @@ -4077,7 +4235,7 @@ func (m *DecisionMutation) UpdatedAt() (r time.Time, exists bool) { // OldUpdatedAt returns the old "updated_at" field's value of the Decision entity. // If the Decision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *DecisionMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *DecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -4091,22 +4249,9 @@ func (m *DecisionMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err return oldValue.UpdatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *DecisionMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[decision.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. 
-func (m *DecisionMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[decision.FieldUpdatedAt] - return ok -} - // ResetUpdatedAt resets all changes to the "updated_at" field. func (m *DecisionMutation) ResetUpdatedAt() { m.updated_at = nil - delete(m.clearedFields, decision.FieldUpdatedAt) } // SetUntil sets the "until" field. @@ -4830,6 +4975,7 @@ func (m *DecisionMutation) SetOwnerID(id int) { // ClearOwner clears the "owner" edge to the Alert entity. func (m *DecisionMutation) ClearOwner() { m.clearedowner = true + m.clearedFields[decision.FieldAlertDecisions] = struct{}{} } // OwnerCleared reports if the "owner" edge to the Alert entity was cleared. @@ -4866,11 +5012,26 @@ func (m *DecisionMutation) Where(ps ...predicate.Decision) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the DecisionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *DecisionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Decision, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *DecisionMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *DecisionMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Decision). func (m *DecisionMutation) Type() string { return m.typ @@ -5224,12 +5385,6 @@ func (m *DecisionMutation) AddField(name string, value ent.Value) error { // mutation. func (m *DecisionMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(decision.FieldCreatedAt) { - fields = append(fields, decision.FieldCreatedAt) - } - if m.FieldCleared(decision.FieldUpdatedAt) { - fields = append(fields, decision.FieldUpdatedAt) - } if m.FieldCleared(decision.FieldUntil) { fields = append(fields, decision.FieldUntil) } @@ -5268,12 +5423,6 @@ func (m *DecisionMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *DecisionMutation) ClearField(name string) error { switch name { - case decision.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case decision.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil case decision.FieldUntil: m.ClearUntil() return nil @@ -5565,7 +5714,7 @@ func (m *EventMutation) CreatedAt() (r time.Time, exists bool) { // OldCreatedAt returns the old "created_at" field's value of the Event entity. // If the Event object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *EventMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *EventMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -5579,22 +5728,9 @@ func (m *EventMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err err return oldValue.CreatedAt, nil } -// ClearCreatedAt clears the value of the "created_at" field. -func (m *EventMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[event.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. 
-func (m *EventMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[event.FieldCreatedAt] - return ok -} - // ResetCreatedAt resets all changes to the "created_at" field. func (m *EventMutation) ResetCreatedAt() { m.created_at = nil - delete(m.clearedFields, event.FieldCreatedAt) } // SetUpdatedAt sets the "updated_at" field. @@ -5614,7 +5750,7 @@ func (m *EventMutation) UpdatedAt() (r time.Time, exists bool) { // OldUpdatedAt returns the old "updated_at" field's value of the Event entity. // If the Event object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *EventMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *EventMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -5628,22 +5764,9 @@ func (m *EventMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err err return oldValue.UpdatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *EventMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[event.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *EventMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[event.FieldUpdatedAt] - return ok -} - // ResetUpdatedAt resets all changes to the "updated_at" field. func (m *EventMutation) ResetUpdatedAt() { m.updated_at = nil - delete(m.clearedFields, event.FieldUpdatedAt) } // SetTime sets the "time" field. @@ -5775,6 +5898,7 @@ func (m *EventMutation) SetOwnerID(id int) { // ClearOwner clears the "owner" edge to the Alert entity. func (m *EventMutation) ClearOwner() { m.clearedowner = true + m.clearedFields[event.FieldAlertEvents] = struct{}{} } // OwnerCleared reports if the "owner" edge to the Alert entity was cleared. @@ -5811,11 +5935,26 @@ func (m *EventMutation) Where(ps ...predicate.Event) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the EventMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *EventMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Event, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *EventMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *EventMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Event). func (m *EventMutation) Type() string { return m.typ @@ -5955,12 +6094,6 @@ func (m *EventMutation) AddField(name string, value ent.Value) error { // mutation. func (m *EventMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(event.FieldCreatedAt) { - fields = append(fields, event.FieldCreatedAt) - } - if m.FieldCleared(event.FieldUpdatedAt) { - fields = append(fields, event.FieldUpdatedAt) - } if m.FieldCleared(event.FieldAlertEvents) { fields = append(fields, event.FieldAlertEvents) } @@ -5978,12 +6111,6 @@ func (m *EventMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. 
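// NOTE (editor): ClearOwner on Event (and, symmetrically, on Decision
// above) now also records the foreign-key column in clearedFields, so the
// SQL layer nulls alert_events when the edge is cleared:
//
//	m.ClearOwner()
//	cleared := m.FieldCleared(event.FieldAlertEvents) // true after the call above
//	_ = cleared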
func (m *EventMutation) ClearField(name string) error { switch name { - case event.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case event.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil case event.FieldAlertEvents: m.ClearAlertEvents() return nil @@ -6088,44 +6215,31 @@ func (m *EventMutation) ResetEdge(name string) error { return fmt.Errorf("unknown Event edge %s", name) } -// MachineMutation represents an operation that mutates the Machine nodes in the graph. -type MachineMutation struct { +// LockMutation represents an operation that mutates the Lock nodes in the graph. +type LockMutation struct { config - op Op - typ string - id *int - created_at *time.Time - updated_at *time.Time - last_push *time.Time - last_heartbeat *time.Time - machineId *string - password *string - ipAddress *string - scenarios *string - version *string - isValidated *bool - status *string - auth_type *string - clearedFields map[string]struct{} - alerts map[int]struct{} - removedalerts map[int]struct{} - clearedalerts bool - done bool - oldValue func(context.Context) (*Machine, error) - predicates []predicate.Machine + op Op + typ string + id *int + name *string + created_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*Lock, error) + predicates []predicate.Lock } -var _ ent.Mutation = (*MachineMutation)(nil) +var _ ent.Mutation = (*LockMutation)(nil) -// machineOption allows management of the mutation configuration using functional options. -type machineOption func(*MachineMutation) +// lockOption allows management of the mutation configuration using functional options. +type lockOption func(*LockMutation) -// newMachineMutation creates new mutation for the Machine entity. -func newMachineMutation(c config, op Op, opts ...machineOption) *MachineMutation { - m := &MachineMutation{ +// newLockMutation creates new mutation for the Lock entity. +func newLockMutation(c config, op Op, opts ...lockOption) *LockMutation { + m := &LockMutation{ config: c, op: op, - typ: TypeMachine, + typ: TypeLock, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -6134,20 +6248,20 @@ func newMachineMutation(c config, op Op, opts ...machineOption) *MachineMutation return m } -// withMachineID sets the ID field of the mutation. -func withMachineID(id int) machineOption { - return func(m *MachineMutation) { +// withLockID sets the ID field of the mutation. +func withLockID(id int) lockOption { + return func(m *LockMutation) { var ( err error once sync.Once - value *Machine + value *Lock ) - m.oldValue = func(ctx context.Context) (*Machine, error) { + m.oldValue = func(ctx context.Context) (*Lock, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().Machine.Get(ctx, id) + value, err = m.Client().Lock.Get(ctx, id) } }) return value, err @@ -6156,10 +6270,10 @@ func withMachineID(id int) machineOption { } } -// withMachine sets the old Machine of the mutation. -func withMachine(node *Machine) machineOption { - return func(m *MachineMutation) { - m.oldValue = func(context.Context) (*Machine, error) { +// withLock sets the old Lock of the mutation. +func withLock(node *Lock) lockOption { + return func(m *LockMutation) { + m.oldValue = func(context.Context) (*Lock, error) { return node, nil } m.id = &node.ID @@ -6168,7 +6282,7 @@ func withMachine(node *Machine) machineOption { // Client returns a new `ent.Client` from the mutation. 
If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m MachineMutation) Client() *Client { +func (m LockMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -6176,7 +6290,7 @@ func (m MachineMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m MachineMutation) Tx() (*Tx, error) { +func (m LockMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -6187,7 +6301,7 @@ func (m MachineMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *MachineMutation) ID() (id int, exists bool) { +func (m *LockMutation) ID() (id int, exists bool) { if m.id == nil { return } @@ -6198,7 +6312,7 @@ func (m *MachineMutation) ID() (id int, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *MachineMutation) IDs(ctx context.Context) ([]int, error) { +func (m *LockMutation) IDs(ctx context.Context) ([]int, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -6207,228 +6321,599 @@ func (m *MachineMutation) IDs(ctx context.Context) ([]int, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().Machine.Query().Where(m.predicates...).IDs(ctx) + return m.Client().Lock.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } -// SetCreatedAt sets the "created_at" field. -func (m *MachineMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// SetName sets the "name" field. +func (m *LockMutation) SetName(s string) { + m.name = &s } -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *MachineMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// Name returns the value of the "name" field in the mutation. +func (m *LockMutation) Name() (r string, exists bool) { + v := m.name if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldName returns the old "name" field's value of the Lock entity. +// If the Lock object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *MachineMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *LockMutation) OldName(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldName is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldName requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldName: %w", err) } - return oldValue.CreatedAt, nil -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (m *MachineMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[machine.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *MachineMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[machine.FieldCreatedAt] - return ok + return oldValue.Name, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *MachineMutation) ResetCreatedAt() { - m.created_at = nil - delete(m.clearedFields, machine.FieldCreatedAt) +// ResetName resets all changes to the "name" field. +func (m *LockMutation) ResetName() { + m.name = nil } -// SetUpdatedAt sets the "updated_at" field. -func (m *MachineMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t +// SetCreatedAt sets the "created_at" field. +func (m *LockMutation) SetCreatedAt(t time.Time) { + m.created_at = &t } -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *MachineMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *LockMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the Lock entity. +// If the Lock object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *LockMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.UpdatedAt, nil + return oldValue.CreatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. 
-func (m *MachineMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[machine.FieldUpdatedAt] = struct{}{} +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *LockMutation) ResetCreatedAt() { + m.created_at = nil } -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *MachineMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[machine.FieldUpdatedAt] - return ok +// Where appends a list predicates to the LockMutation builder. +func (m *LockMutation) Where(ps ...predicate.Lock) { + m.predicates = append(m.predicates, ps...) } -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *MachineMutation) ResetUpdatedAt() { - m.updated_at = nil - delete(m.clearedFields, machine.FieldUpdatedAt) +// WhereP appends storage-level predicates to the LockMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *LockMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Lock, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) } -// SetLastPush sets the "last_push" field. -func (m *MachineMutation) SetLastPush(t time.Time) { - m.last_push = &t +// Op returns the operation name. +func (m *LockMutation) Op() Op { + return m.op } -// LastPush returns the value of the "last_push" field in the mutation. -func (m *MachineMutation) LastPush() (r time.Time, exists bool) { - v := m.last_push - if v == nil { - return - } - return *v, true +// SetOp allows setting the mutation operation. +func (m *LockMutation) SetOp(op Op) { + m.op = op } -// OldLastPush returns the old "last_push" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldLastPush(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldLastPush is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldLastPush requires an ID field in the mutation") +// Type returns the node type of this mutation (Lock). +func (m *LockMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *LockMutation) Fields() []string { + fields := make([]string, 0, 2) + if m.name != nil { + fields = append(fields, lock.FieldName) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldLastPush: %w", err) + if m.created_at != nil { + fields = append(fields, lock.FieldCreatedAt) } - return oldValue.LastPush, nil + return fields } -// ClearLastPush clears the value of the "last_push" field. -func (m *MachineMutation) ClearLastPush() { - m.last_push = nil - m.clearedFields[machine.FieldLastPush] = struct{}{} +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *LockMutation) Field(name string) (ent.Value, bool) { + switch name { + case lock.FieldName: + return m.Name() + case lock.FieldCreatedAt: + return m.CreatedAt() + } + return nil, false } -// LastPushCleared returns if the "last_push" field was cleared in this mutation. -func (m *MachineMutation) LastPushCleared() bool { - _, ok := m.clearedFields[machine.FieldLastPush] - return ok +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *LockMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case lock.FieldName: + return m.OldName(ctx) + case lock.FieldCreatedAt: + return m.OldCreatedAt(ctx) + } + return nil, fmt.Errorf("unknown Lock field %s", name) } -// ResetLastPush resets all changes to the "last_push" field. -func (m *MachineMutation) ResetLastPush() { - m.last_push = nil - delete(m.clearedFields, machine.FieldLastPush) +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *LockMutation) SetField(name string, value ent.Value) error { + switch name { + case lock.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case lock.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + } + return fmt.Errorf("unknown Lock field %s", name) } -// SetLastHeartbeat sets the "last_heartbeat" field. -func (m *MachineMutation) SetLastHeartbeat(t time.Time) { - m.last_heartbeat = &t +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *LockMutation) AddedFields() []string { + return nil } -// LastHeartbeat returns the value of the "last_heartbeat" field in the mutation. -func (m *MachineMutation) LastHeartbeat() (r time.Time, exists bool) { - v := m.last_heartbeat - if v == nil { - return +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *LockMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *LockMutation) AddField(name string, value ent.Value) error { + switch name { } - return *v, true + return fmt.Errorf("unknown Lock numeric field %s", name) } -// OldLastHeartbeat returns the old "last_heartbeat" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldLastHeartbeat(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldLastHeartbeat is only allowed on UpdateOne operations") +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
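// NOTE (editor): Lock is a new entity with only "name" and "created_at",
// apparently meant to back simple cross-process advisory locks in the
// database. A creation sketch, assuming a hypothetical client and ctx
// ("pushMetrics" is an invented lock name):
//
//	l, err := client.Lock.Create().
//		SetName("pushMetrics").
//		SetCreatedAt(time.Now().UTC()).
//		Save(ctx)
//	_ = l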
+func (m *LockMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *LockMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *LockMutation) ClearField(name string) error { + return fmt.Errorf("unknown Lock nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *LockMutation) ResetField(name string) error { + switch name { + case lock.FieldName: + m.ResetName() + return nil + case lock.FieldCreatedAt: + m.ResetCreatedAt() + return nil } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldLastHeartbeat requires an ID field in the mutation") + return fmt.Errorf("unknown Lock field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *LockMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *LockMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *LockMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *LockMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *LockMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *LockMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *LockMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown Lock unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *LockMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown Lock edge %s", name) +} + +// MachineMutation represents an operation that mutates the Machine nodes in the graph. 
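// NOTE (editor): compared to the removed revision, the rebuilt
// MachineMutation below drops "status", keeps "isValidated", and adds
// osname, osversion and featureflags plus two JSON-backed fields, hubstate
// (map[string][]schema.ItemState) and datasources (map[string]int64);
// hence the new schema import at the top of the file. A sketch, with every
// concrete value hypothetical:
//
//	err := client.Machine.UpdateOneID(machineID).
//		SetHubstate(map[string][]schema.ItemState{}).
//		SetDatasources(map[string]int64{"file": 2}).
//		Exec(ctx)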
+type MachineMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + last_push *time.Time + last_heartbeat *time.Time + machineId *string + password *string + ipAddress *string + scenarios *string + version *string + isValidated *bool + auth_type *string + osname *string + osversion *string + featureflags *string + hubstate *map[string][]schema.ItemState + datasources *map[string]int64 + clearedFields map[string]struct{} + alerts map[int]struct{} + removedalerts map[int]struct{} + clearedalerts bool + done bool + oldValue func(context.Context) (*Machine, error) + predicates []predicate.Machine +} + +var _ ent.Mutation = (*MachineMutation)(nil) + +// machineOption allows management of the mutation configuration using functional options. +type machineOption func(*MachineMutation) + +// newMachineMutation creates new mutation for the Machine entity. +func newMachineMutation(c config, op Op, opts ...machineOption) *MachineMutation { + m := &MachineMutation{ + config: c, + op: op, + typ: TypeMachine, + clearedFields: make(map[string]struct{}), } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldLastHeartbeat: %w", err) + for _, opt := range opts { + opt(m) } - return oldValue.LastHeartbeat, nil + return m } -// ClearLastHeartbeat clears the value of the "last_heartbeat" field. -func (m *MachineMutation) ClearLastHeartbeat() { - m.last_heartbeat = nil - m.clearedFields[machine.FieldLastHeartbeat] = struct{}{} +// withMachineID sets the ID field of the mutation. +func withMachineID(id int) machineOption { + return func(m *MachineMutation) { + var ( + err error + once sync.Once + value *Machine + ) + m.oldValue = func(ctx context.Context) (*Machine, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Machine.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } } -// LastHeartbeatCleared returns if the "last_heartbeat" field was cleared in this mutation. -func (m *MachineMutation) LastHeartbeatCleared() bool { - _, ok := m.clearedFields[machine.FieldLastHeartbeat] - return ok +// withMachine sets the old Machine of the mutation. +func withMachine(node *Machine) machineOption { + return func(m *MachineMutation) { + m.oldValue = func(context.Context) (*Machine, error) { + return node, nil + } + m.id = &node.ID + } } -// ResetLastHeartbeat resets all changes to the "last_heartbeat" field. -func (m *MachineMutation) ResetLastHeartbeat() { - m.last_heartbeat = nil - delete(m.clearedFields, machine.FieldLastHeartbeat) +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m MachineMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client } -// SetMachineId sets the "machineId" field. -func (m *MachineMutation) SetMachineId(s string) { - m.machineId = &s +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m MachineMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil } -// MachineId returns the value of the "machineId" field in the mutation. 
-func (m *MachineMutation) MachineId() (r string, exists bool) {
-	v := m.machineId
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *MachineMutation) ID() (id int, exists bool) {
+	if m.id == nil {
+		return
+	}
+	return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *MachineMutation) IDs(ctx context.Context) ([]int, error) {
+	switch {
+	case m.op.Is(OpUpdateOne | OpDeleteOne):
+		id, exists := m.ID()
+		if exists {
+			return []int{id}, nil
+		}
+		fallthrough
+	case m.op.Is(OpUpdate | OpDelete):
+		return m.Client().Machine.Query().Where(m.predicates...).IDs(ctx)
+	default:
+		return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+	}
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *MachineMutation) SetCreatedAt(t time.Time) {
+	m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *MachineMutation) CreatedAt() (r time.Time, exists bool) {
+	v := m.created_at
 	if v == nil {
 		return
 	}
 	return *v, true
 }
 
-// OldMachineId returns the old "machineId" field's value of the Machine entity.
+// OldCreatedAt returns the old "created_at" field's value of the Machine entity.
 // If the Machine object wasn't provided to the builder, the object is fetched from the database.
 // An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *MachineMutation) OldMachineId(ctx context.Context) (v string, err error) {
+func (m *MachineMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
 	if !m.op.Is(OpUpdateOne) {
-		return v, errors.New("OldMachineId is only allowed on UpdateOne operations")
+		return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+	}
+	return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *MachineMutation) ResetCreatedAt() {
+	m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *MachineMutation) SetUpdatedAt(t time.Time) {
+	m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *MachineMutation) UpdatedAt() (r time.Time, exists bool) {
+	v := m.updated_at
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the Machine entity.
+// If the Machine object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
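
ID is only populated for single-row operations, while IDs falls back to re-running the mutation's predicates as a query for bulk updates and deletes. A hedged sketch of how that difference plays out in a hook, written as if inside this package (the hook name is hypothetical):

	import (
		"context"
		"log"

		"entgo.io/ent"
	)

	func auditMachineWrites(next ent.Mutator) ent.Mutator {
		return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
			if mm, ok := m.(*MachineMutation); ok && mm.Op().Is(OpUpdate | OpDelete) {
				// For bulk operations this re-runs the predicates as a query;
				// for OpUpdateOne/OpDeleteOne it returns the already-known id.
				ids, err := mm.IDs(ctx)
				if err != nil {
					return nil, err
				}
				log.Printf("%s touches machine ids %v", mm.Op(), ids)
			}
			return next.Mutate(ctx, m)
		})
	}
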
+func (m *MachineMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *MachineMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetLastPush sets the "last_push" field. +func (m *MachineMutation) SetLastPush(t time.Time) { + m.last_push = &t +} + +// LastPush returns the value of the "last_push" field in the mutation. +func (m *MachineMutation) LastPush() (r time.Time, exists bool) { + v := m.last_push + if v == nil { + return + } + return *v, true +} + +// OldLastPush returns the old "last_push" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldLastPush(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastPush is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastPush requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastPush: %w", err) + } + return oldValue.LastPush, nil +} + +// ClearLastPush clears the value of the "last_push" field. +func (m *MachineMutation) ClearLastPush() { + m.last_push = nil + m.clearedFields[machine.FieldLastPush] = struct{}{} +} + +// LastPushCleared returns if the "last_push" field was cleared in this mutation. +func (m *MachineMutation) LastPushCleared() bool { + _, ok := m.clearedFields[machine.FieldLastPush] + return ok +} + +// ResetLastPush resets all changes to the "last_push" field. +func (m *MachineMutation) ResetLastPush() { + m.last_push = nil + delete(m.clearedFields, machine.FieldLastPush) +} + +// SetLastHeartbeat sets the "last_heartbeat" field. +func (m *MachineMutation) SetLastHeartbeat(t time.Time) { + m.last_heartbeat = &t +} + +// LastHeartbeat returns the value of the "last_heartbeat" field in the mutation. +func (m *MachineMutation) LastHeartbeat() (r time.Time, exists bool) { + v := m.last_heartbeat + if v == nil { + return + } + return *v, true +} + +// OldLastHeartbeat returns the old "last_heartbeat" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldLastHeartbeat(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastHeartbeat is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastHeartbeat requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastHeartbeat: %w", err) + } + return oldValue.LastHeartbeat, nil +} + +// ClearLastHeartbeat clears the value of the "last_heartbeat" field. 
+func (m *MachineMutation) ClearLastHeartbeat() { + m.last_heartbeat = nil + m.clearedFields[machine.FieldLastHeartbeat] = struct{}{} +} + +// LastHeartbeatCleared returns if the "last_heartbeat" field was cleared in this mutation. +func (m *MachineMutation) LastHeartbeatCleared() bool { + _, ok := m.clearedFields[machine.FieldLastHeartbeat] + return ok +} + +// ResetLastHeartbeat resets all changes to the "last_heartbeat" field. +func (m *MachineMutation) ResetLastHeartbeat() { + m.last_heartbeat = nil + delete(m.clearedFields, machine.FieldLastHeartbeat) +} + +// SetMachineId sets the "machineId" field. +func (m *MachineMutation) SetMachineId(s string) { + m.machineId = &s +} + +// MachineId returns the value of the "machineId" field in the mutation. +func (m *MachineMutation) MachineId() (r string, exists bool) { + v := m.machineId + if v == nil { + return + } + return *v, true +} + +// OldMachineId returns the old "machineId" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldMachineId(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMachineId is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { return v, errors.New("OldMachineId requires an ID field in the mutation") @@ -6456,395 +6941,1461 @@ func (m *MachineMutation) Password() (r string, exists bool) { if v == nil { return } - return *v, true + return *v, true +} + +// OldPassword returns the old "password" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldPassword(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPassword is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPassword requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPassword: %w", err) + } + return oldValue.Password, nil +} + +// ResetPassword resets all changes to the "password" field. +func (m *MachineMutation) ResetPassword() { + m.password = nil +} + +// SetIpAddress sets the "ipAddress" field. +func (m *MachineMutation) SetIpAddress(s string) { + m.ipAddress = &s +} + +// IpAddress returns the value of the "ipAddress" field in the mutation. +func (m *MachineMutation) IpAddress() (r string, exists bool) { + v := m.ipAddress + if v == nil { + return + } + return *v, true +} + +// OldIpAddress returns the old "ipAddress" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
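
Clearing and resetting are different operations: ClearLastHeartbeat stages a NULL write and records it in clearedFields, while ResetLastHeartbeat discards both the staged value and the staged clear. A small sketch of that bookkeeping (the literal construction is for illustration only; real mutations come from newMachineMutation):

	import "fmt"

	func clearVsReset() {
		m := &MachineMutation{clearedFields: make(map[string]struct{})}

		m.ClearLastHeartbeat()
		fmt.Println(m.LastHeartbeatCleared()) // true: a NULL write is staged

		m.ResetLastHeartbeat()
		fmt.Println(m.LastHeartbeatCleared()) // false: the staged clear is discarded
	}
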
+func (m *MachineMutation) OldIpAddress(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIpAddress is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIpAddress requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIpAddress: %w", err) + } + return oldValue.IpAddress, nil +} + +// ResetIpAddress resets all changes to the "ipAddress" field. +func (m *MachineMutation) ResetIpAddress() { + m.ipAddress = nil +} + +// SetScenarios sets the "scenarios" field. +func (m *MachineMutation) SetScenarios(s string) { + m.scenarios = &s +} + +// Scenarios returns the value of the "scenarios" field in the mutation. +func (m *MachineMutation) Scenarios() (r string, exists bool) { + v := m.scenarios + if v == nil { + return + } + return *v, true +} + +// OldScenarios returns the old "scenarios" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldScenarios(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldScenarios is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldScenarios requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldScenarios: %w", err) + } + return oldValue.Scenarios, nil +} + +// ClearScenarios clears the value of the "scenarios" field. +func (m *MachineMutation) ClearScenarios() { + m.scenarios = nil + m.clearedFields[machine.FieldScenarios] = struct{}{} +} + +// ScenariosCleared returns if the "scenarios" field was cleared in this mutation. +func (m *MachineMutation) ScenariosCleared() bool { + _, ok := m.clearedFields[machine.FieldScenarios] + return ok +} + +// ResetScenarios resets all changes to the "scenarios" field. +func (m *MachineMutation) ResetScenarios() { + m.scenarios = nil + delete(m.clearedFields, machine.FieldScenarios) +} + +// SetVersion sets the "version" field. +func (m *MachineMutation) SetVersion(s string) { + m.version = &s +} + +// Version returns the value of the "version" field in the mutation. +func (m *MachineMutation) Version() (r string, exists bool) { + v := m.version + if v == nil { + return + } + return *v, true +} + +// OldVersion returns the old "version" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldVersion(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldVersion is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldVersion requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldVersion: %w", err) + } + return oldValue.Version, nil +} + +// ClearVersion clears the value of the "version" field. 
+func (m *MachineMutation) ClearVersion() { + m.version = nil + m.clearedFields[machine.FieldVersion] = struct{}{} +} + +// VersionCleared returns if the "version" field was cleared in this mutation. +func (m *MachineMutation) VersionCleared() bool { + _, ok := m.clearedFields[machine.FieldVersion] + return ok +} + +// ResetVersion resets all changes to the "version" field. +func (m *MachineMutation) ResetVersion() { + m.version = nil + delete(m.clearedFields, machine.FieldVersion) +} + +// SetIsValidated sets the "isValidated" field. +func (m *MachineMutation) SetIsValidated(b bool) { + m.isValidated = &b +} + +// IsValidated returns the value of the "isValidated" field in the mutation. +func (m *MachineMutation) IsValidated() (r bool, exists bool) { + v := m.isValidated + if v == nil { + return + } + return *v, true +} + +// OldIsValidated returns the old "isValidated" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldIsValidated(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIsValidated is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIsValidated requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIsValidated: %w", err) + } + return oldValue.IsValidated, nil +} + +// ResetIsValidated resets all changes to the "isValidated" field. +func (m *MachineMutation) ResetIsValidated() { + m.isValidated = nil +} + +// SetAuthType sets the "auth_type" field. +func (m *MachineMutation) SetAuthType(s string) { + m.auth_type = &s +} + +// AuthType returns the value of the "auth_type" field in the mutation. +func (m *MachineMutation) AuthType() (r string, exists bool) { + v := m.auth_type + if v == nil { + return + } + return *v, true +} + +// OldAuthType returns the old "auth_type" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldAuthType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAuthType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAuthType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAuthType: %w", err) + } + return oldValue.AuthType, nil +} + +// ResetAuthType resets all changes to the "auth_type" field. +func (m *MachineMutation) ResetAuthType() { + m.auth_type = nil +} + +// SetOsname sets the "osname" field. +func (m *MachineMutation) SetOsname(s string) { + m.osname = &s +} + +// Osname returns the value of the "osname" field in the mutation. +func (m *MachineMutation) Osname() (r string, exists bool) { + v := m.osname + if v == nil { + return + } + return *v, true +} + +// OldOsname returns the old "osname" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *MachineMutation) OldOsname(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOsname is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOsname requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOsname: %w", err) + } + return oldValue.Osname, nil +} + +// ClearOsname clears the value of the "osname" field. +func (m *MachineMutation) ClearOsname() { + m.osname = nil + m.clearedFields[machine.FieldOsname] = struct{}{} +} + +// OsnameCleared returns if the "osname" field was cleared in this mutation. +func (m *MachineMutation) OsnameCleared() bool { + _, ok := m.clearedFields[machine.FieldOsname] + return ok +} + +// ResetOsname resets all changes to the "osname" field. +func (m *MachineMutation) ResetOsname() { + m.osname = nil + delete(m.clearedFields, machine.FieldOsname) +} + +// SetOsversion sets the "osversion" field. +func (m *MachineMutation) SetOsversion(s string) { + m.osversion = &s +} + +// Osversion returns the value of the "osversion" field in the mutation. +func (m *MachineMutation) Osversion() (r string, exists bool) { + v := m.osversion + if v == nil { + return + } + return *v, true +} + +// OldOsversion returns the old "osversion" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldOsversion(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOsversion is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOsversion requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOsversion: %w", err) + } + return oldValue.Osversion, nil +} + +// ClearOsversion clears the value of the "osversion" field. +func (m *MachineMutation) ClearOsversion() { + m.osversion = nil + m.clearedFields[machine.FieldOsversion] = struct{}{} +} + +// OsversionCleared returns if the "osversion" field was cleared in this mutation. +func (m *MachineMutation) OsversionCleared() bool { + _, ok := m.clearedFields[machine.FieldOsversion] + return ok +} + +// ResetOsversion resets all changes to the "osversion" field. +func (m *MachineMutation) ResetOsversion() { + m.osversion = nil + delete(m.clearedFields, machine.FieldOsversion) +} + +// SetFeatureflags sets the "featureflags" field. +func (m *MachineMutation) SetFeatureflags(s string) { + m.featureflags = &s +} + +// Featureflags returns the value of the "featureflags" field in the mutation. +func (m *MachineMutation) Featureflags() (r string, exists bool) { + v := m.featureflags + if v == nil { + return + } + return *v, true +} + +// OldFeatureflags returns the old "featureflags" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *MachineMutation) OldFeatureflags(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFeatureflags is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFeatureflags requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFeatureflags: %w", err) + } + return oldValue.Featureflags, nil +} + +// ClearFeatureflags clears the value of the "featureflags" field. +func (m *MachineMutation) ClearFeatureflags() { + m.featureflags = nil + m.clearedFields[machine.FieldFeatureflags] = struct{}{} +} + +// FeatureflagsCleared returns if the "featureflags" field was cleared in this mutation. +func (m *MachineMutation) FeatureflagsCleared() bool { + _, ok := m.clearedFields[machine.FieldFeatureflags] + return ok +} + +// ResetFeatureflags resets all changes to the "featureflags" field. +func (m *MachineMutation) ResetFeatureflags() { + m.featureflags = nil + delete(m.clearedFields, machine.FieldFeatureflags) +} + +// SetHubstate sets the "hubstate" field. +func (m *MachineMutation) SetHubstate(ms map[string][]schema.ItemState) { + m.hubstate = &ms +} + +// Hubstate returns the value of the "hubstate" field in the mutation. +func (m *MachineMutation) Hubstate() (r map[string][]schema.ItemState, exists bool) { + v := m.hubstate + if v == nil { + return + } + return *v, true +} + +// OldHubstate returns the old "hubstate" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldHubstate(ctx context.Context) (v map[string][]schema.ItemState, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldHubstate is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldHubstate requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldHubstate: %w", err) + } + return oldValue.Hubstate, nil +} + +// ClearHubstate clears the value of the "hubstate" field. +func (m *MachineMutation) ClearHubstate() { + m.hubstate = nil + m.clearedFields[machine.FieldHubstate] = struct{}{} +} + +// HubstateCleared returns if the "hubstate" field was cleared in this mutation. +func (m *MachineMutation) HubstateCleared() bool { + _, ok := m.clearedFields[machine.FieldHubstate] + return ok +} + +// ResetHubstate resets all changes to the "hubstate" field. +func (m *MachineMutation) ResetHubstate() { + m.hubstate = nil + delete(m.clearedFields, machine.FieldHubstate) +} + +// SetDatasources sets the "datasources" field. +func (m *MachineMutation) SetDatasources(value map[string]int64) { + m.datasources = &value +} + +// Datasources returns the value of the "datasources" field in the mutation. +func (m *MachineMutation) Datasources() (r map[string]int64, exists bool) { + v := m.datasources + if v == nil { + return + } + return *v, true +} + +// OldDatasources returns the old "datasources" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
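
hubstate and datasources are JSON-backed map columns, so the setters stage a pointer to the caller's map rather than a copy. A brief sketch (illustrative; the map key is made up):

	import "fmt"

	func stageDatasources() {
		m := &MachineMutation{clearedFields: make(map[string]struct{})}

		m.SetDatasources(map[string]int64{"file:/var/log/auth.log": 42})
		if ds, ok := m.Datasources(); ok {
			fmt.Println(ds["file:/var/log/auth.log"]) // 42
		}

		m.ClearDatasources()
		fmt.Println(m.DatasourcesCleared()) // true
	}

Because only the pointer is staged, mutating the original map after SetDatasources would also change what gets written.
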
+func (m *MachineMutation) OldDatasources(ctx context.Context) (v map[string]int64, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldDatasources is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldDatasources requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldDatasources: %w", err)
+	}
+	return oldValue.Datasources, nil
+}
+
+// ClearDatasources clears the value of the "datasources" field.
+func (m *MachineMutation) ClearDatasources() {
+	m.datasources = nil
+	m.clearedFields[machine.FieldDatasources] = struct{}{}
+}
+
+// DatasourcesCleared returns if the "datasources" field was cleared in this mutation.
+func (m *MachineMutation) DatasourcesCleared() bool {
+	_, ok := m.clearedFields[machine.FieldDatasources]
+	return ok
+}
+
+// ResetDatasources resets all changes to the "datasources" field.
+func (m *MachineMutation) ResetDatasources() {
+	m.datasources = nil
+	delete(m.clearedFields, machine.FieldDatasources)
+}
+
+// AddAlertIDs adds the "alerts" edge to the Alert entity by ids.
+func (m *MachineMutation) AddAlertIDs(ids ...int) {
+	if m.alerts == nil {
+		m.alerts = make(map[int]struct{})
+	}
+	for i := range ids {
+		m.alerts[ids[i]] = struct{}{}
+	}
+}
+
+// ClearAlerts clears the "alerts" edge to the Alert entity.
+func (m *MachineMutation) ClearAlerts() {
+	m.clearedalerts = true
+}
+
+// AlertsCleared reports if the "alerts" edge to the Alert entity was cleared.
+func (m *MachineMutation) AlertsCleared() bool {
+	return m.clearedalerts
+}
+
+// RemoveAlertIDs removes the "alerts" edge to the Alert entity by IDs.
+func (m *MachineMutation) RemoveAlertIDs(ids ...int) {
+	if m.removedalerts == nil {
+		m.removedalerts = make(map[int]struct{})
+	}
+	for i := range ids {
+		delete(m.alerts, ids[i])
+		m.removedalerts[ids[i]] = struct{}{}
+	}
+}
+
+// RemovedAlertsIDs returns the removed IDs of the "alerts" edge to the Alert entity.
+func (m *MachineMutation) RemovedAlertsIDs() (ids []int) {
+	for id := range m.removedalerts {
+		ids = append(ids, id)
+	}
+	return
+}
+
+// AlertsIDs returns the "alerts" edge IDs in the mutation.
+func (m *MachineMutation) AlertsIDs() (ids []int) {
+	for id := range m.alerts {
+		ids = append(ids, id)
+	}
+	return
+}
+
+// ResetAlerts resets all changes to the "alerts" edge.
+func (m *MachineMutation) ResetAlerts() {
+	m.alerts = nil
+	m.clearedalerts = false
+	m.removedalerts = nil
+}
+
+// Where appends a list of predicates to the MachineMutation builder.
+func (m *MachineMutation) Where(ps ...predicate.Machine) {
+	m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the MachineMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *MachineMutation) WhereP(ps ...func(*sql.Selector)) {
+	p := make([]predicate.Machine, len(ps))
+	for i := range ps {
+		p[i] = ps[i]
+	}
+	m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *MachineMutation) Op() Op {
+	return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *MachineMutation) SetOp(op Op) {
+	m.op = op
+}
+
+// Type returns the node type of this mutation (Machine).
+func (m *MachineMutation) Type() string {
+	return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation.
Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *MachineMutation) Fields() []string { + fields := make([]string, 0, 16) + if m.created_at != nil { + fields = append(fields, machine.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, machine.FieldUpdatedAt) + } + if m.last_push != nil { + fields = append(fields, machine.FieldLastPush) + } + if m.last_heartbeat != nil { + fields = append(fields, machine.FieldLastHeartbeat) + } + if m.machineId != nil { + fields = append(fields, machine.FieldMachineId) + } + if m.password != nil { + fields = append(fields, machine.FieldPassword) + } + if m.ipAddress != nil { + fields = append(fields, machine.FieldIpAddress) + } + if m.scenarios != nil { + fields = append(fields, machine.FieldScenarios) + } + if m.version != nil { + fields = append(fields, machine.FieldVersion) + } + if m.isValidated != nil { + fields = append(fields, machine.FieldIsValidated) + } + if m.auth_type != nil { + fields = append(fields, machine.FieldAuthType) + } + if m.osname != nil { + fields = append(fields, machine.FieldOsname) + } + if m.osversion != nil { + fields = append(fields, machine.FieldOsversion) + } + if m.featureflags != nil { + fields = append(fields, machine.FieldFeatureflags) + } + if m.hubstate != nil { + fields = append(fields, machine.FieldHubstate) + } + if m.datasources != nil { + fields = append(fields, machine.FieldDatasources) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *MachineMutation) Field(name string) (ent.Value, bool) { + switch name { + case machine.FieldCreatedAt: + return m.CreatedAt() + case machine.FieldUpdatedAt: + return m.UpdatedAt() + case machine.FieldLastPush: + return m.LastPush() + case machine.FieldLastHeartbeat: + return m.LastHeartbeat() + case machine.FieldMachineId: + return m.MachineId() + case machine.FieldPassword: + return m.Password() + case machine.FieldIpAddress: + return m.IpAddress() + case machine.FieldScenarios: + return m.Scenarios() + case machine.FieldVersion: + return m.Version() + case machine.FieldIsValidated: + return m.IsValidated() + case machine.FieldAuthType: + return m.AuthType() + case machine.FieldOsname: + return m.Osname() + case machine.FieldOsversion: + return m.Osversion() + case machine.FieldFeatureflags: + return m.Featureflags() + case machine.FieldHubstate: + return m.Hubstate() + case machine.FieldDatasources: + return m.Datasources() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
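
Fields and Field together give a name-indexed view of everything staged on the mutation, which is useful for diffing or auditing without touching the typed getters. A sketch, assuming it lives in this package:

	// snapshotMachine collects every staged field value by name, using the
	// generated name-based dispatch above.
	func snapshotMachine(m *MachineMutation) map[string]ent.Value {
		out := make(map[string]ent.Value, len(m.Fields()))
		for _, name := range m.Fields() {
			if v, ok := m.Field(name); ok {
				out[name] = v
			}
		}
		return out
	}
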
+func (m *MachineMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case machine.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case machine.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case machine.FieldLastPush: + return m.OldLastPush(ctx) + case machine.FieldLastHeartbeat: + return m.OldLastHeartbeat(ctx) + case machine.FieldMachineId: + return m.OldMachineId(ctx) + case machine.FieldPassword: + return m.OldPassword(ctx) + case machine.FieldIpAddress: + return m.OldIpAddress(ctx) + case machine.FieldScenarios: + return m.OldScenarios(ctx) + case machine.FieldVersion: + return m.OldVersion(ctx) + case machine.FieldIsValidated: + return m.OldIsValidated(ctx) + case machine.FieldAuthType: + return m.OldAuthType(ctx) + case machine.FieldOsname: + return m.OldOsname(ctx) + case machine.FieldOsversion: + return m.OldOsversion(ctx) + case machine.FieldFeatureflags: + return m.OldFeatureflags(ctx) + case machine.FieldHubstate: + return m.OldHubstate(ctx) + case machine.FieldDatasources: + return m.OldDatasources(ctx) + } + return nil, fmt.Errorf("unknown Machine field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *MachineMutation) SetField(name string, value ent.Value) error { + switch name { + case machine.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case machine.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case machine.FieldLastPush: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastPush(v) + return nil + case machine.FieldLastHeartbeat: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastHeartbeat(v) + return nil + case machine.FieldMachineId: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMachineId(v) + return nil + case machine.FieldPassword: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassword(v) + return nil + case machine.FieldIpAddress: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIpAddress(v) + return nil + case machine.FieldScenarios: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetScenarios(v) + return nil + case machine.FieldVersion: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetVersion(v) + return nil + case machine.FieldIsValidated: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIsValidated(v) + return nil + case machine.FieldAuthType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAuthType(v) + return nil + case machine.FieldOsname: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOsname(v) + return nil + case machine.FieldOsversion: + v, ok := value.(string) + if !ok { + 
return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOsversion(v) + return nil + case machine.FieldFeatureflags: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFeatureflags(v) + return nil + case machine.FieldHubstate: + v, ok := value.(map[string][]schema.ItemState) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetHubstate(v) + return nil + case machine.FieldDatasources: + v, ok := value.(map[string]int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDatasources(v) + return nil + } + return fmt.Errorf("unknown Machine field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *MachineMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *MachineMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *MachineMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown Machine numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *MachineMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(machine.FieldLastPush) { + fields = append(fields, machine.FieldLastPush) + } + if m.FieldCleared(machine.FieldLastHeartbeat) { + fields = append(fields, machine.FieldLastHeartbeat) + } + if m.FieldCleared(machine.FieldScenarios) { + fields = append(fields, machine.FieldScenarios) + } + if m.FieldCleared(machine.FieldVersion) { + fields = append(fields, machine.FieldVersion) + } + if m.FieldCleared(machine.FieldOsname) { + fields = append(fields, machine.FieldOsname) + } + if m.FieldCleared(machine.FieldOsversion) { + fields = append(fields, machine.FieldOsversion) + } + if m.FieldCleared(machine.FieldFeatureflags) { + fields = append(fields, machine.FieldFeatureflags) + } + if m.FieldCleared(machine.FieldHubstate) { + fields = append(fields, machine.FieldHubstate) + } + if m.FieldCleared(machine.FieldDatasources) { + fields = append(fields, machine.FieldDatasources) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *MachineMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. 
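
ClearField is the name-based entry point to the same nullable-field bookkeeping shown earlier; names that are not nullable fall through to the error return. A short sketch (illustrative only):

	import "fmt"

	func clearByName() {
		m := &MachineMutation{clearedFields: make(map[string]struct{})}

		if err := m.ClearField(machine.FieldScenarios); err == nil {
			fmt.Println(m.ClearedFields()) // [scenarios]
		}

		// Required columns cannot be cleared by name:
		if err := m.ClearField(machine.FieldMachineId); err != nil {
			fmt.Println("not nullable:", err)
		}
	}
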
+func (m *MachineMutation) ClearField(name string) error { + switch name { + case machine.FieldLastPush: + m.ClearLastPush() + return nil + case machine.FieldLastHeartbeat: + m.ClearLastHeartbeat() + return nil + case machine.FieldScenarios: + m.ClearScenarios() + return nil + case machine.FieldVersion: + m.ClearVersion() + return nil + case machine.FieldOsname: + m.ClearOsname() + return nil + case machine.FieldOsversion: + m.ClearOsversion() + return nil + case machine.FieldFeatureflags: + m.ClearFeatureflags() + return nil + case machine.FieldHubstate: + m.ClearHubstate() + return nil + case machine.FieldDatasources: + m.ClearDatasources() + return nil + } + return fmt.Errorf("unknown Machine nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *MachineMutation) ResetField(name string) error { + switch name { + case machine.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case machine.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case machine.FieldLastPush: + m.ResetLastPush() + return nil + case machine.FieldLastHeartbeat: + m.ResetLastHeartbeat() + return nil + case machine.FieldMachineId: + m.ResetMachineId() + return nil + case machine.FieldPassword: + m.ResetPassword() + return nil + case machine.FieldIpAddress: + m.ResetIpAddress() + return nil + case machine.FieldScenarios: + m.ResetScenarios() + return nil + case machine.FieldVersion: + m.ResetVersion() + return nil + case machine.FieldIsValidated: + m.ResetIsValidated() + return nil + case machine.FieldAuthType: + m.ResetAuthType() + return nil + case machine.FieldOsname: + m.ResetOsname() + return nil + case machine.FieldOsversion: + m.ResetOsversion() + return nil + case machine.FieldFeatureflags: + m.ResetFeatureflags() + return nil + case machine.FieldHubstate: + m.ResetHubstate() + return nil + case machine.FieldDatasources: + m.ResetDatasources() + return nil + } + return fmt.Errorf("unknown Machine field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *MachineMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.alerts != nil { + edges = append(edges, machine.EdgeAlerts) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *MachineMutation) AddedIDs(name string) []ent.Value { + switch name { + case machine.EdgeAlerts: + ids := make([]ent.Value, 0, len(m.alerts)) + for id := range m.alerts { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *MachineMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + if m.removedalerts != nil { + edges = append(edges, machine.EdgeAlerts) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *MachineMutation) RemovedIDs(name string) []ent.Value { + switch name { + case machine.EdgeAlerts: + ids := make([]ent.Value, 0, len(m.removedalerts)) + for id := range m.removedalerts { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. 
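
Edge changes are staged as two id sets, and RemoveAlertIDs also withdraws ids that were added earlier in the same mutation. A sketch of the resulting views (map iteration order is not guaranteed):

	import "fmt"

	func stageAlertEdges() {
		m := &MachineMutation{}

		m.AddAlertIDs(1, 2, 3)
		m.RemoveAlertIDs(2) // 2 is dropped from the additions and recorded as removed

		fmt.Println(m.AddedIDs(machine.EdgeAlerts))   // [1 3] in some order
		fmt.Println(m.RemovedIDs(machine.EdgeAlerts)) // [2]
	}
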
+func (m *MachineMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedalerts { + edges = append(edges, machine.EdgeAlerts) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *MachineMutation) EdgeCleared(name string) bool { + switch name { + case machine.EdgeAlerts: + return m.clearedalerts + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *MachineMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown Machine unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *MachineMutation) ResetEdge(name string) error { + switch name { + case machine.EdgeAlerts: + m.ResetAlerts() + return nil + } + return fmt.Errorf("unknown Machine edge %s", name) +} + +// MetaMutation represents an operation that mutates the Meta nodes in the graph. +type MetaMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + key *string + value *string + clearedFields map[string]struct{} + owner *int + clearedowner bool + done bool + oldValue func(context.Context) (*Meta, error) + predicates []predicate.Meta +} + +var _ ent.Mutation = (*MetaMutation)(nil) + +// metaOption allows management of the mutation configuration using functional options. +type metaOption func(*MetaMutation) + +// newMetaMutation creates new mutation for the Meta entity. +func newMetaMutation(c config, op Op, opts ...metaOption) *MetaMutation { + m := &MetaMutation{ + config: c, + op: op, + typ: TypeMeta, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withMetaID sets the ID field of the mutation. +func withMetaID(id int) metaOption { + return func(m *MetaMutation) { + var ( + err error + once sync.Once + value *Meta + ) + m.oldValue = func(ctx context.Context) (*Meta, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Meta.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } } -// OldPassword returns the old "password" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldPassword(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPassword is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPassword requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPassword: %w", err) +// withMeta sets the old Meta of the mutation. +func withMeta(node *Meta) metaOption { + return func(m *MetaMutation) { + m.oldValue = func(context.Context) (*Meta, error) { + return node, nil + } + m.id = &node.ID } - return oldValue.Password, nil } -// ResetPassword resets all changes to the "password" field. -func (m *MachineMutation) ResetPassword() { - m.password = nil +// Client returns a new `ent.Client` from the mutation. 
 If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m MetaMutation) Client() *Client {
+	client := &Client{config: m.config}
+	client.init()
+	return client
 }
 
-// SetIpAddress sets the "ipAddress" field.
-func (m *MachineMutation) SetIpAddress(s string) {
-	m.ipAddress = &s
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m MetaMutation) Tx() (*Tx, error) {
+	if _, ok := m.driver.(*txDriver); !ok {
+		return nil, errors.New("ent: mutation is not running in a transaction")
+	}
+	tx := &Tx{config: m.config}
+	tx.init()
+	return tx, nil
 }
 
-// IpAddress returns the value of the "ipAddress" field in the mutation.
-func (m *MachineMutation) IpAddress() (r string, exists bool) {
-	v := m.ipAddress
-	if v == nil {
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *MetaMutation) ID() (id int, exists bool) {
+	if m.id == nil {
 		return
 	}
-	return *v, true
+	return *m.id, true
 }
 
-// OldIpAddress returns the old "ipAddress" field's value of the Machine entity.
-// If the Machine object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *MachineMutation) OldIpAddress(ctx context.Context) (v string, err error) {
-	if !m.op.Is(OpUpdateOne) {
-		return v, errors.New("OldIpAddress is only allowed on UpdateOne operations")
-	}
-	if m.id == nil || m.oldValue == nil {
-		return v, errors.New("OldIpAddress requires an ID field in the mutation")
-	}
-	oldValue, err := m.oldValue(ctx)
-	if err != nil {
-		return v, fmt.Errorf("querying old value for OldIpAddress: %w", err)
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *MetaMutation) IDs(ctx context.Context) ([]int, error) {
+	switch {
+	case m.op.Is(OpUpdateOne | OpDeleteOne):
+		id, exists := m.ID()
+		if exists {
+			return []int{id}, nil
+		}
+		fallthrough
+	case m.op.Is(OpUpdate | OpDelete):
+		return m.Client().Meta.Query().Where(m.predicates...).IDs(ctx)
+	default:
+		return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
 	}
-	return oldValue.IpAddress, nil
-}
-
-// ResetIpAddress resets all changes to the "ipAddress" field.
-func (m *MachineMutation) ResetIpAddress() {
-	m.ipAddress = nil
 }
 
-// SetScenarios sets the "scenarios" field.
-func (m *MachineMutation) SetScenarios(s string) {
-	m.scenarios = &s
+// SetCreatedAt sets the "created_at" field.
+func (m *MetaMutation) SetCreatedAt(t time.Time) {
+	m.created_at = &t
 }
 
-// Scenarios returns the value of the "scenarios" field in the mutation.
-func (m *MachineMutation) Scenarios() (r string, exists bool) {
-	v := m.scenarios
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *MetaMutation) CreatedAt() (r time.Time, exists bool) {
+	v := m.created_at
 	if v == nil {
 		return
 	}
 	return *v, true
 }
 
-// OldScenarios returns the old "scenarios" field's value of the Machine entity.
-// If the Machine object wasn't provided to the builder, the object is fetched from the database.
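
Client and Tx let hook code run queries with the same configuration, and the same transaction, as the mutation itself. A hedged sketch (the helper name is an assumption):

	import "context"

	// countMetasInSameTx prefers the surrounding transaction and falls back
	// to a plain client when the mutation is not transactional.
	func countMetasInSameTx(ctx context.Context, m *MetaMutation) (int, error) {
		if tx, err := m.Tx(); err == nil {
			return tx.Meta.Query().Count(ctx)
		}
		return m.Client().Meta.Query().Count(ctx)
	}
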
+// OldCreatedAt returns the old "created_at" field's value of the Meta entity. +// If the Meta object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldScenarios(ctx context.Context) (v string, err error) { +func (m *MetaMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldScenarios is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldScenarios requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldScenarios: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.Scenarios, nil -} - -// ClearScenarios clears the value of the "scenarios" field. -func (m *MachineMutation) ClearScenarios() { - m.scenarios = nil - m.clearedFields[machine.FieldScenarios] = struct{}{} -} - -// ScenariosCleared returns if the "scenarios" field was cleared in this mutation. -func (m *MachineMutation) ScenariosCleared() bool { - _, ok := m.clearedFields[machine.FieldScenarios] - return ok + return oldValue.CreatedAt, nil } -// ResetScenarios resets all changes to the "scenarios" field. -func (m *MachineMutation) ResetScenarios() { - m.scenarios = nil - delete(m.clearedFields, machine.FieldScenarios) +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *MetaMutation) ResetCreatedAt() { + m.created_at = nil } -// SetVersion sets the "version" field. -func (m *MachineMutation) SetVersion(s string) { - m.version = &s +// SetUpdatedAt sets the "updated_at" field. +func (m *MetaMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t } -// Version returns the value of the "version" field in the mutation. -func (m *MachineMutation) Version() (r string, exists bool) { - v := m.version +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *MetaMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at if v == nil { return } return *v, true } -// OldVersion returns the old "version" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the Meta entity. +// If the Meta object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *MachineMutation) OldVersion(ctx context.Context) (v string, err error) { +func (m *MetaMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldVersion is only allowed on UpdateOne operations") + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldVersion requires an ID field in the mutation") + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldVersion: %w", err) + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return oldValue.Version, nil -} - -// ClearVersion clears the value of the "version" field. -func (m *MachineMutation) ClearVersion() { - m.version = nil - m.clearedFields[machine.FieldVersion] = struct{}{} -} - -// VersionCleared returns if the "version" field was cleared in this mutation. -func (m *MachineMutation) VersionCleared() bool { - _, ok := m.clearedFields[machine.FieldVersion] - return ok + return oldValue.UpdatedAt, nil } -// ResetVersion resets all changes to the "version" field. -func (m *MachineMutation) ResetVersion() { - m.version = nil - delete(m.clearedFields, machine.FieldVersion) +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *MetaMutation) ResetUpdatedAt() { + m.updated_at = nil } -// SetIsValidated sets the "isValidated" field. -func (m *MachineMutation) SetIsValidated(b bool) { - m.isValidated = &b +// SetKey sets the "key" field. +func (m *MetaMutation) SetKey(s string) { + m.key = &s } -// IsValidated returns the value of the "isValidated" field in the mutation. -func (m *MachineMutation) IsValidated() (r bool, exists bool) { - v := m.isValidated +// Key returns the value of the "key" field in the mutation. +func (m *MetaMutation) Key() (r string, exists bool) { + v := m.key if v == nil { return } return *v, true } -// OldIsValidated returns the old "isValidated" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldKey returns the old "key" field's value of the Meta entity. +// If the Meta object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldIsValidated(ctx context.Context) (v bool, err error) { +func (m *MetaMutation) OldKey(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldIsValidated is only allowed on UpdateOne operations") + return v, errors.New("OldKey is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldIsValidated requires an ID field in the mutation") + return v, errors.New("OldKey requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldIsValidated: %w", err) + return v, fmt.Errorf("querying old value for OldKey: %w", err) } - return oldValue.IsValidated, nil + return oldValue.Key, nil } -// ResetIsValidated resets all changes to the "isValidated" field. -func (m *MachineMutation) ResetIsValidated() { - m.isValidated = nil +// ResetKey resets all changes to the "key" field. +func (m *MetaMutation) ResetKey() { + m.key = nil } -// SetStatus sets the "status" field. 
-func (m *MachineMutation) SetStatus(s string) { - m.status = &s +// SetValue sets the "value" field. +func (m *MetaMutation) SetValue(s string) { + m.value = &s } -// Status returns the value of the "status" field in the mutation. -func (m *MachineMutation) Status() (r string, exists bool) { - v := m.status +// Value returns the value of the "value" field in the mutation. +func (m *MetaMutation) Value() (r string, exists bool) { + v := m.value if v == nil { return } return *v, true } -// OldStatus returns the old "status" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldValue returns the old "value" field's value of the Meta entity. +// If the Meta object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldStatus(ctx context.Context) (v string, err error) { +func (m *MetaMutation) OldValue(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldStatus is only allowed on UpdateOne operations") + return v, errors.New("OldValue is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldStatus requires an ID field in the mutation") + return v, errors.New("OldValue requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldStatus: %w", err) + return v, fmt.Errorf("querying old value for OldValue: %w", err) } - return oldValue.Status, nil -} - -// ClearStatus clears the value of the "status" field. -func (m *MachineMutation) ClearStatus() { - m.status = nil - m.clearedFields[machine.FieldStatus] = struct{}{} -} - -// StatusCleared returns if the "status" field was cleared in this mutation. -func (m *MachineMutation) StatusCleared() bool { - _, ok := m.clearedFields[machine.FieldStatus] - return ok + return oldValue.Value, nil } -// ResetStatus resets all changes to the "status" field. -func (m *MachineMutation) ResetStatus() { - m.status = nil - delete(m.clearedFields, machine.FieldStatus) +// ResetValue resets all changes to the "value" field. +func (m *MetaMutation) ResetValue() { + m.value = nil } -// SetAuthType sets the "auth_type" field. -func (m *MachineMutation) SetAuthType(s string) { - m.auth_type = &s +// SetAlertMetas sets the "alert_metas" field. +func (m *MetaMutation) SetAlertMetas(i int) { + m.owner = &i } -// AuthType returns the value of the "auth_type" field in the mutation. -func (m *MachineMutation) AuthType() (r string, exists bool) { - v := m.auth_type +// AlertMetas returns the value of the "alert_metas" field in the mutation. +func (m *MetaMutation) AlertMetas() (r int, exists bool) { + v := m.owner if v == nil { return } return *v, true } -// OldAuthType returns the old "auth_type" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldAlertMetas returns the old "alert_metas" field's value of the Meta entity. +// If the Meta object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
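
alert_metas doubles as the foreign-key column behind the "owner" edge, so the mutation stores both in the single m.owner pointer; the edge setters just below make this explicit. A tiny sketch (illustrative only):

	import "fmt"

	func stageOwner() {
		m := &MetaMutation{clearedFields: make(map[string]struct{})}

		m.SetOwnerID(7) // defined below
		if fk, ok := m.AlertMetas(); ok {
			fmt.Println(fk) // 7: staging the edge also stages the FK column
		}
	}
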
-func (m *MachineMutation) OldAuthType(ctx context.Context) (v string, err error) { +func (m *MetaMutation) OldAlertMetas(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAuthType is only allowed on UpdateOne operations") + return v, errors.New("OldAlertMetas is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAuthType requires an ID field in the mutation") + return v, errors.New("OldAlertMetas requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAuthType: %w", err) + return v, fmt.Errorf("querying old value for OldAlertMetas: %w", err) } - return oldValue.AuthType, nil + return oldValue.AlertMetas, nil } -// ResetAuthType resets all changes to the "auth_type" field. -func (m *MachineMutation) ResetAuthType() { - m.auth_type = nil +// ClearAlertMetas clears the value of the "alert_metas" field. +func (m *MetaMutation) ClearAlertMetas() { + m.owner = nil + m.clearedFields[meta.FieldAlertMetas] = struct{}{} } -// AddAlertIDs adds the "alerts" edge to the Alert entity by ids. -func (m *MachineMutation) AddAlertIDs(ids ...int) { - if m.alerts == nil { - m.alerts = make(map[int]struct{}) - } - for i := range ids { - m.alerts[ids[i]] = struct{}{} - } +// AlertMetasCleared returns if the "alert_metas" field was cleared in this mutation. +func (m *MetaMutation) AlertMetasCleared() bool { + _, ok := m.clearedFields[meta.FieldAlertMetas] + return ok } -// ClearAlerts clears the "alerts" edge to the Alert entity. -func (m *MachineMutation) ClearAlerts() { - m.clearedalerts = true +// ResetAlertMetas resets all changes to the "alert_metas" field. +func (m *MetaMutation) ResetAlertMetas() { + m.owner = nil + delete(m.clearedFields, meta.FieldAlertMetas) } -// AlertsCleared reports if the "alerts" edge to the Alert entity was cleared. -func (m *MachineMutation) AlertsCleared() bool { - return m.clearedalerts +// SetOwnerID sets the "owner" edge to the Alert entity by id. +func (m *MetaMutation) SetOwnerID(id int) { + m.owner = &id } -// RemoveAlertIDs removes the "alerts" edge to the Alert entity by IDs. -func (m *MachineMutation) RemoveAlertIDs(ids ...int) { - if m.removedalerts == nil { - m.removedalerts = make(map[int]struct{}) - } - for i := range ids { - delete(m.alerts, ids[i]) - m.removedalerts[ids[i]] = struct{}{} - } +// ClearOwner clears the "owner" edge to the Alert entity. +func (m *MetaMutation) ClearOwner() { + m.clearedowner = true + m.clearedFields[meta.FieldAlertMetas] = struct{}{} } -// RemovedAlerts returns the removed IDs of the "alerts" edge to the Alert entity. -func (m *MachineMutation) RemovedAlertsIDs() (ids []int) { - for id := range m.removedalerts { - ids = append(ids, id) +// OwnerCleared reports if the "owner" edge to the Alert entity was cleared. +func (m *MetaMutation) OwnerCleared() bool { + return m.AlertMetasCleared() || m.clearedowner +} + +// OwnerID returns the "owner" edge ID in the mutation. +func (m *MetaMutation) OwnerID() (id int, exists bool) { + if m.owner != nil { + return *m.owner, true } return } -// AlertsIDs returns the "alerts" edge IDs in the mutation. -func (m *MachineMutation) AlertsIDs() (ids []int) { - for id := range m.alerts { - ids = append(ids, id) +// OwnerIDs returns the "owner" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// OwnerID instead. 
It exists only for internal usage by the builders. +func (m *MetaMutation) OwnerIDs() (ids []int) { + if id := m.owner; id != nil { + ids = append(ids, *id) } return } -// ResetAlerts resets all changes to the "alerts" edge. -func (m *MachineMutation) ResetAlerts() { - m.alerts = nil - m.clearedalerts = false - m.removedalerts = nil +// ResetOwner resets all changes to the "owner" edge. +func (m *MetaMutation) ResetOwner() { + m.owner = nil + m.clearedowner = false } -// Where appends a list predicates to the MachineMutation builder. -func (m *MachineMutation) Where(ps ...predicate.Machine) { +// Where appends a list predicates to the MetaMutation builder. +func (m *MetaMutation) Where(ps ...predicate.Meta) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MetaMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MetaMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Meta, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. -func (m *MachineMutation) Op() Op { +func (m *MetaMutation) Op() Op { return m.op } -// Type returns the node type of this mutation (Machine). -func (m *MachineMutation) Type() string { +// SetOp allows setting the mutation operation. +func (m *MetaMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Meta). +func (m *MetaMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). -func (m *MachineMutation) Fields() []string { - fields := make([]string, 0, 12) +func (m *MetaMutation) Fields() []string { + fields := make([]string, 0, 5) if m.created_at != nil { - fields = append(fields, machine.FieldCreatedAt) + fields = append(fields, meta.FieldCreatedAt) } if m.updated_at != nil { - fields = append(fields, machine.FieldUpdatedAt) - } - if m.last_push != nil { - fields = append(fields, machine.FieldLastPush) - } - if m.last_heartbeat != nil { - fields = append(fields, machine.FieldLastHeartbeat) - } - if m.machineId != nil { - fields = append(fields, machine.FieldMachineId) - } - if m.password != nil { - fields = append(fields, machine.FieldPassword) - } - if m.ipAddress != nil { - fields = append(fields, machine.FieldIpAddress) - } - if m.scenarios != nil { - fields = append(fields, machine.FieldScenarios) - } - if m.version != nil { - fields = append(fields, machine.FieldVersion) + fields = append(fields, meta.FieldUpdatedAt) } - if m.isValidated != nil { - fields = append(fields, machine.FieldIsValidated) + if m.key != nil { + fields = append(fields, meta.FieldKey) } - if m.status != nil { - fields = append(fields, machine.FieldStatus) + if m.value != nil { + fields = append(fields, meta.FieldValue) } - if m.auth_type != nil { - fields = append(fields, machine.FieldAuthType) + if m.owner != nil { + fields = append(fields, meta.FieldAlertMetas) } return fields } @@ -6852,32 +8403,18 @@ func (m *MachineMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. 
-func (m *MachineMutation) Field(name string) (ent.Value, bool) { +func (m *MetaMutation) Field(name string) (ent.Value, bool) { switch name { - case machine.FieldCreatedAt: + case meta.FieldCreatedAt: return m.CreatedAt() - case machine.FieldUpdatedAt: + case meta.FieldUpdatedAt: return m.UpdatedAt() - case machine.FieldLastPush: - return m.LastPush() - case machine.FieldLastHeartbeat: - return m.LastHeartbeat() - case machine.FieldMachineId: - return m.MachineId() - case machine.FieldPassword: - return m.Password() - case machine.FieldIpAddress: - return m.IpAddress() - case machine.FieldScenarios: - return m.Scenarios() - case machine.FieldVersion: - return m.Version() - case machine.FieldIsValidated: - return m.IsValidated() - case machine.FieldStatus: - return m.Status() - case machine.FieldAuthType: - return m.AuthType() + case meta.FieldKey: + return m.Key() + case meta.FieldValue: + return m.Value() + case meta.FieldAlertMetas: + return m.AlertMetas() } return nil, false } @@ -6885,372 +8422,244 @@ func (m *MachineMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *MachineMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *MetaMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case machine.FieldCreatedAt: + case meta.FieldCreatedAt: return m.OldCreatedAt(ctx) - case machine.FieldUpdatedAt: + case meta.FieldUpdatedAt: return m.OldUpdatedAt(ctx) - case machine.FieldLastPush: - return m.OldLastPush(ctx) - case machine.FieldLastHeartbeat: - return m.OldLastHeartbeat(ctx) - case machine.FieldMachineId: - return m.OldMachineId(ctx) - case machine.FieldPassword: - return m.OldPassword(ctx) - case machine.FieldIpAddress: - return m.OldIpAddress(ctx) - case machine.FieldScenarios: - return m.OldScenarios(ctx) - case machine.FieldVersion: - return m.OldVersion(ctx) - case machine.FieldIsValidated: - return m.OldIsValidated(ctx) - case machine.FieldStatus: - return m.OldStatus(ctx) - case machine.FieldAuthType: - return m.OldAuthType(ctx) + case meta.FieldKey: + return m.OldKey(ctx) + case meta.FieldValue: + return m.OldValue(ctx) + case meta.FieldAlertMetas: + return m.OldAlertMetas(ctx) } - return nil, fmt.Errorf("unknown Machine field %s", name) + return nil, fmt.Errorf("unknown Meta field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. 
-func (m *MachineMutation) SetField(name string, value ent.Value) error { +func (m *MetaMutation) SetField(name string, value ent.Value) error { switch name { - case machine.FieldCreatedAt: + case meta.FieldCreatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetCreatedAt(v) return nil - case machine.FieldUpdatedAt: + case meta.FieldUpdatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetUpdatedAt(v) return nil - case machine.FieldLastPush: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetLastPush(v) - return nil - case machine.FieldLastHeartbeat: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetLastHeartbeat(v) - return nil - case machine.FieldMachineId: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetMachineId(v) - return nil - case machine.FieldPassword: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPassword(v) - return nil - case machine.FieldIpAddress: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetIpAddress(v) - return nil - case machine.FieldScenarios: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetScenarios(v) - return nil - case machine.FieldVersion: + case meta.FieldKey: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetVersion(v) - return nil - case machine.FieldIsValidated: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetIsValidated(v) + m.SetKey(v) return nil - case machine.FieldStatus: + case meta.FieldValue: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetStatus(v) + m.SetValue(v) return nil - case machine.FieldAuthType: - v, ok := value.(string) + case meta.FieldAlertMetas: + v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetAuthType(v) + m.SetAlertMetas(v) return nil } - return fmt.Errorf("unknown Machine field %s", name) + return fmt.Errorf("unknown Meta field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *MachineMutation) AddedFields() []string { - return nil +func (m *MetaMutation) AddedFields() []string { + var fields []string + return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *MachineMutation) AddedField(name string) (ent.Value, bool) { +func (m *MetaMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } return nil, false } // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. 
-func (m *MachineMutation) AddField(name string, value ent.Value) error { +func (m *MetaMutation) AddField(name string, value ent.Value) error { switch name { } - return fmt.Errorf("unknown Machine numeric field %s", name) + return fmt.Errorf("unknown Meta numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. -func (m *MachineMutation) ClearedFields() []string { +func (m *MetaMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(machine.FieldCreatedAt) { - fields = append(fields, machine.FieldCreatedAt) - } - if m.FieldCleared(machine.FieldUpdatedAt) { - fields = append(fields, machine.FieldUpdatedAt) - } - if m.FieldCleared(machine.FieldLastPush) { - fields = append(fields, machine.FieldLastPush) - } - if m.FieldCleared(machine.FieldLastHeartbeat) { - fields = append(fields, machine.FieldLastHeartbeat) - } - if m.FieldCleared(machine.FieldScenarios) { - fields = append(fields, machine.FieldScenarios) - } - if m.FieldCleared(machine.FieldVersion) { - fields = append(fields, machine.FieldVersion) - } - if m.FieldCleared(machine.FieldStatus) { - fields = append(fields, machine.FieldStatus) + if m.FieldCleared(meta.FieldAlertMetas) { + fields = append(fields, meta.FieldAlertMetas) } return fields } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *MachineMutation) FieldCleared(name string) bool { +func (m *MetaMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. -func (m *MachineMutation) ClearField(name string) error { +func (m *MetaMutation) ClearField(name string) error { switch name { - case machine.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case machine.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil - case machine.FieldLastPush: - m.ClearLastPush() - return nil - case machine.FieldLastHeartbeat: - m.ClearLastHeartbeat() - return nil - case machine.FieldScenarios: - m.ClearScenarios() - return nil - case machine.FieldVersion: - m.ClearVersion() - return nil - case machine.FieldStatus: - m.ClearStatus() + case meta.FieldAlertMetas: + m.ClearAlertMetas() return nil } - return fmt.Errorf("unknown Machine nullable field %s", name) + return fmt.Errorf("unknown Meta nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. 
-func (m *MachineMutation) ResetField(name string) error { +func (m *MetaMutation) ResetField(name string) error { switch name { - case machine.FieldCreatedAt: + case meta.FieldCreatedAt: m.ResetCreatedAt() return nil - case machine.FieldUpdatedAt: + case meta.FieldUpdatedAt: m.ResetUpdatedAt() return nil - case machine.FieldLastPush: - m.ResetLastPush() - return nil - case machine.FieldLastHeartbeat: - m.ResetLastHeartbeat() - return nil - case machine.FieldMachineId: - m.ResetMachineId() - return nil - case machine.FieldPassword: - m.ResetPassword() - return nil - case machine.FieldIpAddress: - m.ResetIpAddress() - return nil - case machine.FieldScenarios: - m.ResetScenarios() - return nil - case machine.FieldVersion: - m.ResetVersion() - return nil - case machine.FieldIsValidated: - m.ResetIsValidated() + case meta.FieldKey: + m.ResetKey() return nil - case machine.FieldStatus: - m.ResetStatus() + case meta.FieldValue: + m.ResetValue() return nil - case machine.FieldAuthType: - m.ResetAuthType() + case meta.FieldAlertMetas: + m.ResetAlertMetas() return nil } - return fmt.Errorf("unknown Machine field %s", name) + return fmt.Errorf("unknown Meta field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *MachineMutation) AddedEdges() []string { +func (m *MetaMutation) AddedEdges() []string { edges := make([]string, 0, 1) - if m.alerts != nil { - edges = append(edges, machine.EdgeAlerts) + if m.owner != nil { + edges = append(edges, meta.EdgeOwner) } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *MachineMutation) AddedIDs(name string) []ent.Value { +func (m *MetaMutation) AddedIDs(name string) []ent.Value { switch name { - case machine.EdgeAlerts: - ids := make([]ent.Value, 0, len(m.alerts)) - for id := range m.alerts { - ids = append(ids, id) + case meta.EdgeOwner: + if id := m.owner; id != nil { + return []ent.Value{*id} } - return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *MachineMutation) RemovedEdges() []string { +func (m *MetaMutation) RemovedEdges() []string { edges := make([]string, 0, 1) - if m.removedalerts != nil { - edges = append(edges, machine.EdgeAlerts) - } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *MachineMutation) RemovedIDs(name string) []ent.Value { - switch name { - case machine.EdgeAlerts: - ids := make([]ent.Value, 0, len(m.removedalerts)) - for id := range m.removedalerts { - ids = append(ids, id) - } - return ids - } +func (m *MetaMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *MachineMutation) ClearedEdges() []string { +func (m *MetaMutation) ClearedEdges() []string { edges := make([]string, 0, 1) - if m.clearedalerts { - edges = append(edges, machine.EdgeAlerts) + if m.clearedowner { + edges = append(edges, meta.EdgeOwner) } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *MachineMutation) EdgeCleared(name string) bool { +func (m *MetaMutation) EdgeCleared(name string) bool { switch name { - case machine.EdgeAlerts: - return m.clearedalerts + case meta.EdgeOwner: + return m.clearedowner } return false } // ClearEdge clears the value of the edge with the given name. 
It returns an error // if that edge is not defined in the schema. -func (m *MachineMutation) ClearEdge(name string) error { +func (m *MetaMutation) ClearEdge(name string) error { switch name { + case meta.EdgeOwner: + m.ClearOwner() + return nil } - return fmt.Errorf("unknown Machine unique edge %s", name) + return fmt.Errorf("unknown Meta unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *MachineMutation) ResetEdge(name string) error { +func (m *MetaMutation) ResetEdge(name string) error { switch name { - case machine.EdgeAlerts: - m.ResetAlerts() + case meta.EdgeOwner: + m.ResetOwner() return nil } - return fmt.Errorf("unknown Machine edge %s", name) + return fmt.Errorf("unknown Meta edge %s", name) } -// MetaMutation represents an operation that mutates the Meta nodes in the graph. -type MetaMutation struct { +// MetricMutation represents an operation that mutates the Metric nodes in the graph. +type MetricMutation struct { config - op Op - typ string - id *int - created_at *time.Time - updated_at *time.Time - key *string - value *string - clearedFields map[string]struct{} - owner *int - clearedowner bool - done bool - oldValue func(context.Context) (*Meta, error) - predicates []predicate.Meta + op Op + typ string + id *int + generated_type *metric.GeneratedType + generated_by *string + received_at *time.Time + pushed_at *time.Time + payload *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*Metric, error) + predicates []predicate.Metric } -var _ ent.Mutation = (*MetaMutation)(nil) +var _ ent.Mutation = (*MetricMutation)(nil) -// metaOption allows management of the mutation configuration using functional options. -type metaOption func(*MetaMutation) +// metricOption allows management of the mutation configuration using functional options. +type metricOption func(*MetricMutation) -// newMetaMutation creates new mutation for the Meta entity. -func newMetaMutation(c config, op Op, opts ...metaOption) *MetaMutation { - m := &MetaMutation{ +// newMetricMutation creates new mutation for the Metric entity. +func newMetricMutation(c config, op Op, opts ...metricOption) *MetricMutation { + m := &MetricMutation{ config: c, op: op, - typ: TypeMeta, + typ: TypeMetric, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -7259,20 +8668,20 @@ func newMetaMutation(c config, op Op, opts ...metaOption) *MetaMutation { return m } -// withMetaID sets the ID field of the mutation. -func withMetaID(id int) metaOption { - return func(m *MetaMutation) { +// withMetricID sets the ID field of the mutation. +func withMetricID(id int) metricOption { + return func(m *MetricMutation) { var ( err error once sync.Once - value *Meta + value *Metric ) - m.oldValue = func(ctx context.Context) (*Meta, error) { + m.oldValue = func(ctx context.Context) (*Metric, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().Meta.Get(ctx, id) + value, err = m.Client().Metric.Get(ctx, id) } }) return value, err @@ -7281,10 +8690,10 @@ func withMetaID(id int) metaOption { } } -// withMeta sets the old Meta of the mutation. -func withMeta(node *Meta) metaOption { - return func(m *MetaMutation) { - m.oldValue = func(context.Context) (*Meta, error) { +// withMetric sets the old Metric of the mutation. 
+func withMetric(node *Metric) metricOption { + return func(m *MetricMutation) { + m.oldValue = func(context.Context) (*Metric, error) { return node, nil } m.id = &node.ID @@ -7293,7 +8702,7 @@ func withMeta(node *Meta) metaOption { // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m MetaMutation) Client() *Client { +func (m MetricMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -7301,7 +8710,7 @@ func (m MetaMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m MetaMutation) Tx() (*Tx, error) { +func (m MetricMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -7312,7 +8721,7 @@ func (m MetaMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *MetaMutation) ID() (id int, exists bool) { +func (m *MetricMutation) ID() (id int, exists bool) { if m.id == nil { return } @@ -7323,7 +8732,7 @@ func (m *MetaMutation) ID() (id int, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *MetaMutation) IDs(ctx context.Context) ([]int, error) { +func (m *MetricMutation) IDs(ctx context.Context) ([]int, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -7332,304 +8741,254 @@ func (m *MetaMutation) IDs(ctx context.Context) ([]int, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().Meta.Query().Where(m.predicates...).IDs(ctx) + return m.Client().Metric.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } -// SetCreatedAt sets the "created_at" field. -func (m *MetaMutation) SetCreatedAt(t time.Time) { - m.created_at = &t -} - -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *MetaMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at - if v == nil { - return - } - return *v, true -} - -// OldCreatedAt returns the old "created_at" field's value of the Meta entity. -// If the Meta object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MetaMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) - } - return oldValue.CreatedAt, nil -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (m *MetaMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[meta.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. 
-func (m *MetaMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[meta.FieldCreatedAt] - return ok -} - -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *MetaMutation) ResetCreatedAt() { - m.created_at = nil - delete(m.clearedFields, meta.FieldCreatedAt) -} - -// SetUpdatedAt sets the "updated_at" field. -func (m *MetaMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t -} - -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *MetaMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// SetGeneratedType sets the "generated_type" field. +func (m *MetricMutation) SetGeneratedType(mt metric.GeneratedType) { + m.generated_type = &mt +} + +// GeneratedType returns the value of the "generated_type" field in the mutation. +func (m *MetricMutation) GeneratedType() (r metric.GeneratedType, exists bool) { + v := m.generated_type if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the Meta entity. -// If the Meta object wasn't provided to the builder, the object is fetched from the database. +// OldGeneratedType returns the old "generated_type" field's value of the Metric entity. +// If the Metric object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MetaMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *MetricMutation) OldGeneratedType(ctx context.Context) (v metric.GeneratedType, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldGeneratedType is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldGeneratedType requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldGeneratedType: %w", err) } - return oldValue.UpdatedAt, nil + return oldValue.GeneratedType, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *MetaMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[meta.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *MetaMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[meta.FieldUpdatedAt] - return ok -} - -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *MetaMutation) ResetUpdatedAt() { - m.updated_at = nil - delete(m.clearedFields, meta.FieldUpdatedAt) +// ResetGeneratedType resets all changes to the "generated_type" field. +func (m *MetricMutation) ResetGeneratedType() { + m.generated_type = nil } -// SetKey sets the "key" field. -func (m *MetaMutation) SetKey(s string) { - m.key = &s +// SetGeneratedBy sets the "generated_by" field. +func (m *MetricMutation) SetGeneratedBy(s string) { + m.generated_by = &s } -// Key returns the value of the "key" field in the mutation. -func (m *MetaMutation) Key() (r string, exists bool) { - v := m.key +// GeneratedBy returns the value of the "generated_by" field in the mutation. 
+func (m *MetricMutation) GeneratedBy() (r string, exists bool) { + v := m.generated_by if v == nil { return } return *v, true } -// OldKey returns the old "key" field's value of the Meta entity. -// If the Meta object wasn't provided to the builder, the object is fetched from the database. +// OldGeneratedBy returns the old "generated_by" field's value of the Metric entity. +// If the Metric object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MetaMutation) OldKey(ctx context.Context) (v string, err error) { +func (m *MetricMutation) OldGeneratedBy(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldKey is only allowed on UpdateOne operations") + return v, errors.New("OldGeneratedBy is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldKey requires an ID field in the mutation") + return v, errors.New("OldGeneratedBy requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldKey: %w", err) + return v, fmt.Errorf("querying old value for OldGeneratedBy: %w", err) } - return oldValue.Key, nil + return oldValue.GeneratedBy, nil } -// ResetKey resets all changes to the "key" field. -func (m *MetaMutation) ResetKey() { - m.key = nil +// ResetGeneratedBy resets all changes to the "generated_by" field. +func (m *MetricMutation) ResetGeneratedBy() { + m.generated_by = nil } -// SetValue sets the "value" field. -func (m *MetaMutation) SetValue(s string) { - m.value = &s +// SetReceivedAt sets the "received_at" field. +func (m *MetricMutation) SetReceivedAt(t time.Time) { + m.received_at = &t } -// Value returns the value of the "value" field in the mutation. -func (m *MetaMutation) Value() (r string, exists bool) { - v := m.value +// ReceivedAt returns the value of the "received_at" field in the mutation. +func (m *MetricMutation) ReceivedAt() (r time.Time, exists bool) { + v := m.received_at if v == nil { return } return *v, true } -// OldValue returns the old "value" field's value of the Meta entity. -// If the Meta object wasn't provided to the builder, the object is fetched from the database. +// OldReceivedAt returns the old "received_at" field's value of the Metric entity. +// If the Metric object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MetaMutation) OldValue(ctx context.Context) (v string, err error) { +func (m *MetricMutation) OldReceivedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldValue is only allowed on UpdateOne operations") + return v, errors.New("OldReceivedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldValue requires an ID field in the mutation") + return v, errors.New("OldReceivedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldValue: %w", err) + return v, fmt.Errorf("querying old value for OldReceivedAt: %w", err) } - return oldValue.Value, nil + return oldValue.ReceivedAt, nil } -// ResetValue resets all changes to the "value" field. 
-func (m *MetaMutation) ResetValue() { - m.value = nil +// ResetReceivedAt resets all changes to the "received_at" field. +func (m *MetricMutation) ResetReceivedAt() { + m.received_at = nil } -// SetAlertMetas sets the "alert_metas" field. -func (m *MetaMutation) SetAlertMetas(i int) { - m.owner = &i +// SetPushedAt sets the "pushed_at" field. +func (m *MetricMutation) SetPushedAt(t time.Time) { + m.pushed_at = &t } -// AlertMetas returns the value of the "alert_metas" field in the mutation. -func (m *MetaMutation) AlertMetas() (r int, exists bool) { - v := m.owner +// PushedAt returns the value of the "pushed_at" field in the mutation. +func (m *MetricMutation) PushedAt() (r time.Time, exists bool) { + v := m.pushed_at if v == nil { return } return *v, true } -// OldAlertMetas returns the old "alert_metas" field's value of the Meta entity. -// If the Meta object wasn't provided to the builder, the object is fetched from the database. +// OldPushedAt returns the old "pushed_at" field's value of the Metric entity. +// If the Metric object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MetaMutation) OldAlertMetas(ctx context.Context) (v int, err error) { +func (m *MetricMutation) OldPushedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAlertMetas is only allowed on UpdateOne operations") + return v, errors.New("OldPushedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAlertMetas requires an ID field in the mutation") + return v, errors.New("OldPushedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAlertMetas: %w", err) + return v, fmt.Errorf("querying old value for OldPushedAt: %w", err) } - return oldValue.AlertMetas, nil + return oldValue.PushedAt, nil } -// ClearAlertMetas clears the value of the "alert_metas" field. -func (m *MetaMutation) ClearAlertMetas() { - m.owner = nil - m.clearedFields[meta.FieldAlertMetas] = struct{}{} +// ClearPushedAt clears the value of the "pushed_at" field. +func (m *MetricMutation) ClearPushedAt() { + m.pushed_at = nil + m.clearedFields[metric.FieldPushedAt] = struct{}{} } -// AlertMetasCleared returns if the "alert_metas" field was cleared in this mutation. -func (m *MetaMutation) AlertMetasCleared() bool { - _, ok := m.clearedFields[meta.FieldAlertMetas] +// PushedAtCleared returns if the "pushed_at" field was cleared in this mutation. +func (m *MetricMutation) PushedAtCleared() bool { + _, ok := m.clearedFields[metric.FieldPushedAt] return ok } -// ResetAlertMetas resets all changes to the "alert_metas" field. -func (m *MetaMutation) ResetAlertMetas() { - m.owner = nil - delete(m.clearedFields, meta.FieldAlertMetas) -} - -// SetOwnerID sets the "owner" edge to the Alert entity by id. -func (m *MetaMutation) SetOwnerID(id int) { - m.owner = &id -} - -// ClearOwner clears the "owner" edge to the Alert entity. -func (m *MetaMutation) ClearOwner() { - m.clearedowner = true +// ResetPushedAt resets all changes to the "pushed_at" field. +func (m *MetricMutation) ResetPushedAt() { + m.pushed_at = nil + delete(m.clearedFields, metric.FieldPushedAt) } -// OwnerCleared reports if the "owner" edge to the Alert entity was cleared. 
-func (m *MetaMutation) OwnerCleared() bool { - return m.AlertMetasCleared() || m.clearedowner +// SetPayload sets the "payload" field. +func (m *MetricMutation) SetPayload(s string) { + m.payload = &s } -// OwnerID returns the "owner" edge ID in the mutation. -func (m *MetaMutation) OwnerID() (id int, exists bool) { - if m.owner != nil { - return *m.owner, true +// Payload returns the value of the "payload" field in the mutation. +func (m *MetricMutation) Payload() (r string, exists bool) { + v := m.payload + if v == nil { + return } - return + return *v, true } -// OwnerIDs returns the "owner" edge IDs in the mutation. -// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use -// OwnerID instead. It exists only for internal usage by the builders. -func (m *MetaMutation) OwnerIDs() (ids []int) { - if id := m.owner; id != nil { - ids = append(ids, *id) +// OldPayload returns the old "payload" field's value of the Metric entity. +// If the Metric object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MetricMutation) OldPayload(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPayload is only allowed on UpdateOne operations") } - return + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPayload requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPayload: %w", err) + } + return oldValue.Payload, nil } -// ResetOwner resets all changes to the "owner" edge. -func (m *MetaMutation) ResetOwner() { - m.owner = nil - m.clearedowner = false +// ResetPayload resets all changes to the "payload" field. +func (m *MetricMutation) ResetPayload() { + m.payload = nil } -// Where appends a list predicates to the MetaMutation builder. -func (m *MetaMutation) Where(ps ...predicate.Meta) { +// Where appends a list predicates to the MetricMutation builder. +func (m *MetricMutation) Where(ps ...predicate.Metric) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MetricMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MetricMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Metric, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. -func (m *MetaMutation) Op() Op { +func (m *MetricMutation) Op() Op { return m.op } -// Type returns the node type of this mutation (Meta). -func (m *MetaMutation) Type() string { +// SetOp allows setting the mutation operation. +func (m *MetricMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Metric). +func (m *MetricMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
-func (m *MetaMutation) Fields() []string { +func (m *MetricMutation) Fields() []string { fields := make([]string, 0, 5) - if m.created_at != nil { - fields = append(fields, meta.FieldCreatedAt) + if m.generated_type != nil { + fields = append(fields, metric.FieldGeneratedType) } - if m.updated_at != nil { - fields = append(fields, meta.FieldUpdatedAt) + if m.generated_by != nil { + fields = append(fields, metric.FieldGeneratedBy) } - if m.key != nil { - fields = append(fields, meta.FieldKey) + if m.received_at != nil { + fields = append(fields, metric.FieldReceivedAt) } - if m.value != nil { - fields = append(fields, meta.FieldValue) + if m.pushed_at != nil { + fields = append(fields, metric.FieldPushedAt) } - if m.owner != nil { - fields = append(fields, meta.FieldAlertMetas) + if m.payload != nil { + fields = append(fields, metric.FieldPayload) } return fields } @@ -7637,18 +8996,18 @@ func (m *MetaMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. -func (m *MetaMutation) Field(name string) (ent.Value, bool) { +func (m *MetricMutation) Field(name string) (ent.Value, bool) { switch name { - case meta.FieldCreatedAt: - return m.CreatedAt() - case meta.FieldUpdatedAt: - return m.UpdatedAt() - case meta.FieldKey: - return m.Key() - case meta.FieldValue: - return m.Value() - case meta.FieldAlertMetas: - return m.AlertMetas() + case metric.FieldGeneratedType: + return m.GeneratedType() + case metric.FieldGeneratedBy: + return m.GeneratedBy() + case metric.FieldReceivedAt: + return m.ReceivedAt() + case metric.FieldPushedAt: + return m.PushedAt() + case metric.FieldPayload: + return m.Payload() } return nil, false } @@ -7656,224 +9015,183 @@ func (m *MetaMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *MetaMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *MetricMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case meta.FieldCreatedAt: - return m.OldCreatedAt(ctx) - case meta.FieldUpdatedAt: - return m.OldUpdatedAt(ctx) - case meta.FieldKey: - return m.OldKey(ctx) - case meta.FieldValue: - return m.OldValue(ctx) - case meta.FieldAlertMetas: - return m.OldAlertMetas(ctx) + case metric.FieldGeneratedType: + return m.OldGeneratedType(ctx) + case metric.FieldGeneratedBy: + return m.OldGeneratedBy(ctx) + case metric.FieldReceivedAt: + return m.OldReceivedAt(ctx) + case metric.FieldPushedAt: + return m.OldPushedAt(ctx) + case metric.FieldPayload: + return m.OldPayload(ctx) } - return nil, fmt.Errorf("unknown Meta field %s", name) + return nil, fmt.Errorf("unknown Metric field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. 
-func (m *MetaMutation) SetField(name string, value ent.Value) error { +func (m *MetricMutation) SetField(name string, value ent.Value) error { switch name { - case meta.FieldCreatedAt: - v, ok := value.(time.Time) + case metric.FieldGeneratedType: + v, ok := value.(metric.GeneratedType) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetCreatedAt(v) + m.SetGeneratedType(v) return nil - case meta.FieldUpdatedAt: - v, ok := value.(time.Time) + case metric.FieldGeneratedBy: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetUpdatedAt(v) + m.SetGeneratedBy(v) return nil - case meta.FieldKey: - v, ok := value.(string) + case metric.FieldReceivedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetKey(v) + m.SetReceivedAt(v) return nil - case meta.FieldValue: - v, ok := value.(string) + case metric.FieldPushedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetValue(v) + m.SetPushedAt(v) return nil - case meta.FieldAlertMetas: - v, ok := value.(int) + case metric.FieldPayload: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetAlertMetas(v) + m.SetPayload(v) return nil } - return fmt.Errorf("unknown Meta field %s", name) + return fmt.Errorf("unknown Metric field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *MetaMutation) AddedFields() []string { - var fields []string - return fields +func (m *MetricMutation) AddedFields() []string { + return nil } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *MetaMutation) AddedField(name string) (ent.Value, bool) { - switch name { - } +func (m *MetricMutation) AddedField(name string) (ent.Value, bool) { return nil, false } // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *MetaMutation) AddField(name string, value ent.Value) error { +func (m *MetricMutation) AddField(name string, value ent.Value) error { switch name { } - return fmt.Errorf("unknown Meta numeric field %s", name) + return fmt.Errorf("unknown Metric numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. -func (m *MetaMutation) ClearedFields() []string { +func (m *MetricMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(meta.FieldCreatedAt) { - fields = append(fields, meta.FieldCreatedAt) - } - if m.FieldCleared(meta.FieldUpdatedAt) { - fields = append(fields, meta.FieldUpdatedAt) - } - if m.FieldCleared(meta.FieldAlertMetas) { - fields = append(fields, meta.FieldAlertMetas) + if m.FieldCleared(metric.FieldPushedAt) { + fields = append(fields, metric.FieldPushedAt) } return fields } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *MetaMutation) FieldCleared(name string) bool { +func (m *MetricMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. 
It returns an // error if the field is not defined in the schema. -func (m *MetaMutation) ClearField(name string) error { +func (m *MetricMutation) ClearField(name string) error { switch name { - case meta.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case meta.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil - case meta.FieldAlertMetas: - m.ClearAlertMetas() + case metric.FieldPushedAt: + m.ClearPushedAt() return nil } - return fmt.Errorf("unknown Meta nullable field %s", name) + return fmt.Errorf("unknown Metric nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *MetaMutation) ResetField(name string) error { +func (m *MetricMutation) ResetField(name string) error { switch name { - case meta.FieldCreatedAt: - m.ResetCreatedAt() + case metric.FieldGeneratedType: + m.ResetGeneratedType() return nil - case meta.FieldUpdatedAt: - m.ResetUpdatedAt() + case metric.FieldGeneratedBy: + m.ResetGeneratedBy() return nil - case meta.FieldKey: - m.ResetKey() + case metric.FieldReceivedAt: + m.ResetReceivedAt() return nil - case meta.FieldValue: - m.ResetValue() + case metric.FieldPushedAt: + m.ResetPushedAt() return nil - case meta.FieldAlertMetas: - m.ResetAlertMetas() + case metric.FieldPayload: + m.ResetPayload() return nil } - return fmt.Errorf("unknown Meta field %s", name) + return fmt.Errorf("unknown Metric field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *MetaMutation) AddedEdges() []string { - edges := make([]string, 0, 1) - if m.owner != nil { - edges = append(edges, meta.EdgeOwner) - } +func (m *MetricMutation) AddedEdges() []string { + edges := make([]string, 0, 0) return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *MetaMutation) AddedIDs(name string) []ent.Value { - switch name { - case meta.EdgeOwner: - if id := m.owner; id != nil { - return []ent.Value{*id} - } - } +func (m *MetricMutation) AddedIDs(name string) []ent.Value { return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *MetaMutation) RemovedEdges() []string { - edges := make([]string, 0, 1) +func (m *MetricMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *MetaMutation) RemovedIDs(name string) []ent.Value { +func (m *MetricMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *MetaMutation) ClearedEdges() []string { - edges := make([]string, 0, 1) - if m.clearedowner { - edges = append(edges, meta.EdgeOwner) - } +func (m *MetricMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *MetaMutation) EdgeCleared(name string) bool { - switch name { - case meta.EdgeOwner: - return m.clearedowner - } +func (m *MetricMutation) EdgeCleared(name string) bool { return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. 
-func (m *MetaMutation) ClearEdge(name string) error { - switch name { - case meta.EdgeOwner: - m.ClearOwner() - return nil - } - return fmt.Errorf("unknown Meta unique edge %s", name) +func (m *MetricMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown Metric unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *MetaMutation) ResetEdge(name string) error { - switch name { - case meta.EdgeOwner: - m.ResetOwner() - return nil - } - return fmt.Errorf("unknown Meta edge %s", name) +func (m *MetricMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown Metric edge %s", name) } diff --git a/pkg/database/ent/predicate/predicate.go b/pkg/database/ent/predicate/predicate.go index e95abcec343..8ad03e2fc48 100644 --- a/pkg/database/ent/predicate/predicate.go +++ b/pkg/database/ent/predicate/predicate.go @@ -21,8 +21,14 @@ type Decision func(*sql.Selector) // Event is the predicate function for event builders. type Event func(*sql.Selector) +// Lock is the predicate function for lock builders. +type Lock func(*sql.Selector) + // Machine is the predicate function for machine builders. type Machine func(*sql.Selector) // Meta is the predicate function for meta builders. type Meta func(*sql.Selector) + +// Metric is the predicate function for metric builders. +type Metric func(*sql.Selector) diff --git a/pkg/database/ent/runtime.go b/pkg/database/ent/runtime.go index bceea37b3a7..15413490633 100644 --- a/pkg/database/ent/runtime.go +++ b/pkg/database/ent/runtime.go @@ -10,6 +10,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" @@ -25,8 +26,6 @@ func init() { alertDescCreatedAt := alertFields[0].Descriptor() // alert.DefaultCreatedAt holds the default value on creation for the created_at field. alert.DefaultCreatedAt = alertDescCreatedAt.Default.(func() time.Time) - // alert.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - alert.UpdateDefaultCreatedAt = alertDescCreatedAt.UpdateDefault.(func() time.Time) // alertDescUpdatedAt is the schema descriptor for updated_at field. alertDescUpdatedAt := alertFields[1].Descriptor() // alert.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -63,8 +62,6 @@ func init() { bouncerDescCreatedAt := bouncerFields[0].Descriptor() // bouncer.DefaultCreatedAt holds the default value on creation for the created_at field. bouncer.DefaultCreatedAt = bouncerDescCreatedAt.Default.(func() time.Time) - // bouncer.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - bouncer.UpdateDefaultCreatedAt = bouncerDescCreatedAt.UpdateDefault.(func() time.Time) // bouncerDescUpdatedAt is the schema descriptor for updated_at field. bouncerDescUpdatedAt := bouncerFields[1].Descriptor() // bouncer.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -75,16 +72,8 @@ func init() { bouncerDescIPAddress := bouncerFields[5].Descriptor() // bouncer.DefaultIPAddress holds the default value on creation for the ip_address field. 
bouncer.DefaultIPAddress = bouncerDescIPAddress.Default.(string) - // bouncerDescUntil is the schema descriptor for until field. - bouncerDescUntil := bouncerFields[8].Descriptor() - // bouncer.DefaultUntil holds the default value on creation for the until field. - bouncer.DefaultUntil = bouncerDescUntil.Default.(func() time.Time) - // bouncerDescLastPull is the schema descriptor for last_pull field. - bouncerDescLastPull := bouncerFields[9].Descriptor() - // bouncer.DefaultLastPull holds the default value on creation for the last_pull field. - bouncer.DefaultLastPull = bouncerDescLastPull.Default.(func() time.Time) // bouncerDescAuthType is the schema descriptor for auth_type field. - bouncerDescAuthType := bouncerFields[10].Descriptor() + bouncerDescAuthType := bouncerFields[9].Descriptor() // bouncer.DefaultAuthType holds the default value on creation for the auth_type field. bouncer.DefaultAuthType = bouncerDescAuthType.Default.(string) configitemFields := schema.ConfigItem{}.Fields() @@ -93,8 +82,6 @@ func init() { configitemDescCreatedAt := configitemFields[0].Descriptor() // configitem.DefaultCreatedAt holds the default value on creation for the created_at field. configitem.DefaultCreatedAt = configitemDescCreatedAt.Default.(func() time.Time) - // configitem.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - configitem.UpdateDefaultCreatedAt = configitemDescCreatedAt.UpdateDefault.(func() time.Time) // configitemDescUpdatedAt is the schema descriptor for updated_at field. configitemDescUpdatedAt := configitemFields[1].Descriptor() // configitem.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -107,8 +94,6 @@ func init() { decisionDescCreatedAt := decisionFields[0].Descriptor() // decision.DefaultCreatedAt holds the default value on creation for the created_at field. decision.DefaultCreatedAt = decisionDescCreatedAt.Default.(func() time.Time) - // decision.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - decision.UpdateDefaultCreatedAt = decisionDescCreatedAt.UpdateDefault.(func() time.Time) // decisionDescUpdatedAt is the schema descriptor for updated_at field. decisionDescUpdatedAt := decisionFields[1].Descriptor() // decision.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -125,8 +110,6 @@ func init() { eventDescCreatedAt := eventFields[0].Descriptor() // event.DefaultCreatedAt holds the default value on creation for the created_at field. event.DefaultCreatedAt = eventDescCreatedAt.Default.(func() time.Time) - // event.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - event.UpdateDefaultCreatedAt = eventDescCreatedAt.UpdateDefault.(func() time.Time) // eventDescUpdatedAt is the schema descriptor for updated_at field. eventDescUpdatedAt := eventFields[1].Descriptor() // event.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -137,14 +120,18 @@ func init() { eventDescSerialized := eventFields[3].Descriptor() // event.SerializedValidator is a validator for the "serialized" field. It is called by the builders before save. event.SerializedValidator = eventDescSerialized.Validators[0].(func(string) error) + lockFields := schema.Lock{}.Fields() + _ = lockFields + // lockDescCreatedAt is the schema descriptor for created_at field. + lockDescCreatedAt := lockFields[1].Descriptor() + // lock.DefaultCreatedAt holds the default value on creation for the created_at field. 
+ lock.DefaultCreatedAt = lockDescCreatedAt.Default.(func() time.Time) machineFields := schema.Machine{}.Fields() _ = machineFields // machineDescCreatedAt is the schema descriptor for created_at field. machineDescCreatedAt := machineFields[0].Descriptor() // machine.DefaultCreatedAt holds the default value on creation for the created_at field. machine.DefaultCreatedAt = machineDescCreatedAt.Default.(func() time.Time) - // machine.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - machine.UpdateDefaultCreatedAt = machineDescCreatedAt.UpdateDefault.(func() time.Time) // machineDescUpdatedAt is the schema descriptor for updated_at field. machineDescUpdatedAt := machineFields[1].Descriptor() // machine.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -155,14 +142,6 @@ func init() { machineDescLastPush := machineFields[2].Descriptor() // machine.DefaultLastPush holds the default value on creation for the last_push field. machine.DefaultLastPush = machineDescLastPush.Default.(func() time.Time) - // machine.UpdateDefaultLastPush holds the default value on update for the last_push field. - machine.UpdateDefaultLastPush = machineDescLastPush.UpdateDefault.(func() time.Time) - // machineDescLastHeartbeat is the schema descriptor for last_heartbeat field. - machineDescLastHeartbeat := machineFields[3].Descriptor() - // machine.DefaultLastHeartbeat holds the default value on creation for the last_heartbeat field. - machine.DefaultLastHeartbeat = machineDescLastHeartbeat.Default.(func() time.Time) - // machine.UpdateDefaultLastHeartbeat holds the default value on update for the last_heartbeat field. - machine.UpdateDefaultLastHeartbeat = machineDescLastHeartbeat.UpdateDefault.(func() time.Time) // machineDescScenarios is the schema descriptor for scenarios field. machineDescScenarios := machineFields[7].Descriptor() // machine.ScenariosValidator is a validator for the "scenarios" field. It is called by the builders before save. @@ -172,7 +151,7 @@ func init() { // machine.DefaultIsValidated holds the default value on creation for the isValidated field. machine.DefaultIsValidated = machineDescIsValidated.Default.(bool) // machineDescAuthType is the schema descriptor for auth_type field. - machineDescAuthType := machineFields[11].Descriptor() + machineDescAuthType := machineFields[10].Descriptor() // machine.DefaultAuthType holds the default value on creation for the auth_type field. machine.DefaultAuthType = machineDescAuthType.Default.(string) metaFields := schema.Meta{}.Fields() @@ -181,8 +160,6 @@ func init() { metaDescCreatedAt := metaFields[0].Descriptor() // meta.DefaultCreatedAt holds the default value on creation for the created_at field. meta.DefaultCreatedAt = metaDescCreatedAt.Default.(func() time.Time) - // meta.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - meta.UpdateDefaultCreatedAt = metaDescCreatedAt.UpdateDefault.(func() time.Time) // metaDescUpdatedAt is the schema descriptor for updated_at field. metaDescUpdatedAt := metaFields[1].Descriptor() // meta.DefaultUpdatedAt holds the default value on creation for the updated_at field. 
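
Reviewer note on the regenerated runtime above: the ent upgrade drops the `UpdateDefault` hook for every `created_at` column (those fields become `Immutable()` in the schema changes below), registers a creation default for the new `Lock` entity, and shifts the `auth_type` descriptor indices where fields were removed. The new `Metric` entity seen in `mutation.go` gets the usual generated builders. A minimal usage sketch follows, assuming the standard ent client API; the function, the machine name, and the `GeneratedTypeLP` enum constant are illustrative assumptions, not part of this diff:

```go
package example

import (
	"context"
	"fmt"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/metric"
)

// pushMetric is illustrative only: it exercises the builders that the
// regenerated MetricMutation backs (SetGeneratedType, SetGeneratedBy,
// SetReceivedAt, SetPayload, SetPushedAt). The *ent.Client is assumed
// to be obtained elsewhere (e.g. from the LAPI database setup).
func pushMetric(ctx context.Context, client *ent.Client) error {
	m, err := client.Metric.Create().
		SetGeneratedType(metric.GeneratedTypeLP). // assumed enum constant
		SetGeneratedBy("machine-1").              // illustrative value
		SetReceivedAt(time.Now().UTC()).
		SetPayload(`{"metrics":[]}`).
		Save(ctx)
	if err != nil {
		return fmt.Errorf("creating metric: %w", err)
	}

	// pushed_at is the only nullable Metric field (see ClearedFields in the
	// mutation above); it is set separately once the payload is forwarded.
	return client.Metric.UpdateOne(m).SetPushedAt(time.Now().UTC()).Exec(ctx)
}
```
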
diff --git a/pkg/database/ent/runtime/runtime.go b/pkg/database/ent/runtime/runtime.go index e64f7bd7554..9cb9d96258a 100644 --- a/pkg/database/ent/runtime/runtime.go +++ b/pkg/database/ent/runtime/runtime.go @@ -5,6 +5,6 @@ package runtime // The schema-stitching logic is generated in github.com/crowdsecurity/crowdsec/pkg/database/ent/runtime.go const ( - Version = "v0.11.3" // Version of ent codegen. - Sum = "h1:F5FBGAWiDCGder7YT+lqMnyzXl6d0xU3xMBM/SO3CMc=" // Sum of ent codegen. + Version = "v0.13.1" // Version of ent codegen. + Sum = "h1:uD8QwN1h6SNphdCCzmkMN3feSUzNnVvV/WIkHKMbzOE=" // Sum of ent codegen. ) diff --git a/pkg/database/ent/schema/alert.go b/pkg/database/ent/schema/alert.go index f2df9d7f09c..87ace24aa84 100644 --- a/pkg/database/ent/schema/alert.go +++ b/pkg/database/ent/schema/alert.go @@ -6,6 +6,7 @@ import ( "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -19,38 +20,39 @@ func (Alert) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Immutable(), field.Time("updated_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), - field.String("scenario"), - field.String("bucketId").Default("").Optional(), - field.String("message").Default("").Optional(), - field.Int32("eventsCount").Default(0).Optional(), - field.Time("startedAt").Default(types.UtcNow).Optional(), - field.Time("stoppedAt").Default(types.UtcNow).Optional(), + UpdateDefault(types.UtcNow), + field.String("scenario").Immutable(), + field.String("bucketId").Default("").Optional().Immutable(), + field.String("message").Default("").Optional().Immutable(), + field.Int32("eventsCount").Default(0).Optional().Immutable(), + field.Time("startedAt").Default(types.UtcNow).Optional().Immutable(), + field.Time("stoppedAt").Default(types.UtcNow).Optional().Immutable(), field.String("sourceIp"). - Optional(), + Optional().Immutable(), field.String("sourceRange"). - Optional(), + Optional().Immutable(), field.String("sourceAsNumber"). - Optional(), + Optional().Immutable(), field.String("sourceAsName"). - Optional(), + Optional().Immutable(), field.String("sourceCountry"). - Optional(), + Optional().Immutable(), field.Float32("sourceLatitude"). - Optional(), + Optional().Immutable(), field.Float32("sourceLongitude"). 
- Optional(), - field.String("sourceScope").Optional(), - field.String("sourceValue").Optional(), - field.Int32("capacity").Optional(), - field.String("leakSpeed").Optional(), - field.String("scenarioVersion").Optional(), - field.String("scenarioHash").Optional(), - field.Bool("simulated").Default(false), - field.String("uuid").Optional(), //this uuid is mostly here to ensure that CAPI/PAPI has a unique id for each alert + Optional().Immutable(), + field.String("sourceScope").Optional().Immutable(), + field.String("sourceValue").Optional().Immutable(), + field.Int32("capacity").Optional().Immutable(), + field.String("leakSpeed").Optional().Immutable(), + field.String("scenarioVersion").Optional().Immutable(), + field.String("scenarioHash").Optional().Immutable(), + field.Bool("simulated").Default(false).Immutable(), + field.String("uuid").Optional().Immutable(), // this uuid is mostly here to ensure that CAPI/PAPI has a unique id for each alert + field.Bool("remediation").Optional().Immutable(), } } diff --git a/pkg/database/ent/schema/bouncer.go b/pkg/database/ent/schema/bouncer.go index c3081291254..599c4c404fc 100644 --- a/pkg/database/ent/schema/bouncer.go +++ b/pkg/database/ent/schema/bouncer.go @@ -3,6 +3,7 @@ package schema import ( "entgo.io/ent" "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -16,20 +17,22 @@ func (Bouncer) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional().StructTag(`json:"created_at"`), + StructTag(`json:"created_at"`). + Immutable(), field.Time("updated_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional().StructTag(`json:"updated_at"`), - field.String("name").Unique().StructTag(`json:"name"`), - field.String("api_key").StructTag(`json:"api_key"`), // hash of api_key + UpdateDefault(types.UtcNow).StructTag(`json:"updated_at"`), + field.String("name").Unique().StructTag(`json:"name"`).Immutable(), + field.String("api_key").Sensitive(), // hash of api_key field.Bool("revoked").StructTag(`json:"revoked"`), field.String("ip_address").Default("").Optional().StructTag(`json:"ip_address"`), field.String("type").Optional().StructTag(`json:"type"`), field.String("version").Optional().StructTag(`json:"version"`), - field.Time("until").Default(types.UtcNow).Optional().StructTag(`json:"until"`), - field.Time("last_pull"). - Default(types.UtcNow).StructTag(`json:"last_pull"`), + field.Time("last_pull").Nillable().Optional().StructTag(`json:"last_pull"`), field.String("auth_type").StructTag(`json:"auth_type"`).Default(types.ApiKeyAuthType), + field.String("osname").Optional(), + field.String("osversion").Optional(), + field.String("featureflags").Optional(), } } diff --git a/pkg/database/ent/schema/config.go b/pkg/database/ent/schema/config.go index f3320a9cce6..d526db25a8d 100644 --- a/pkg/database/ent/schema/config.go +++ b/pkg/database/ent/schema/config.go @@ -3,6 +3,7 @@ package schema import ( "entgo.io/ent" "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -11,21 +12,20 @@ type ConfigItem struct { ent.Schema } -// Fields of the Bouncer. func (ConfigItem) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional().StructTag(`json:"created_at"`), + Immutable(). + StructTag(`json:"created_at"`), field.Time("updated_at"). Default(types.UtcNow). 
- UpdateDefault(types.UtcNow).Nillable().Optional().StructTag(`json:"updated_at"`), - field.String("name").Unique().StructTag(`json:"name"`), + UpdateDefault(types.UtcNow).StructTag(`json:"updated_at"`), + field.String("name").Unique().StructTag(`json:"name"`).Immutable(), field.String("value").StructTag(`json:"value"`), // a json object } } -// Edges of the Bouncer. func (ConfigItem) Edges() []ent.Edge { return nil } diff --git a/pkg/database/ent/schema/decision.go b/pkg/database/ent/schema/decision.go index b7a99fb7a70..4089be38096 100644 --- a/pkg/database/ent/schema/decision.go +++ b/pkg/database/ent/schema/decision.go @@ -6,6 +6,7 @@ import ( "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -19,25 +20,25 @@ func (Decision) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Immutable(), field.Time("updated_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + UpdateDefault(types.UtcNow), field.Time("until").Nillable().Optional().SchemaType(map[string]string{ dialect.MySQL: "datetime", }), - field.String("scenario"), - field.String("type"), - field.Int64("start_ip").Optional(), - field.Int64("end_ip").Optional(), - field.Int64("start_suffix").Optional(), - field.Int64("end_suffix").Optional(), - field.Int64("ip_size").Optional(), - field.String("scope"), - field.String("value"), - field.String("origin"), - field.Bool("simulated").Default(false), - field.String("uuid").Optional(), //this uuid is mostly here to ensure that CAPI/PAPI has a unique id for each decision + field.String("scenario").Immutable(), + field.String("type").Immutable(), + field.Int64("start_ip").Optional().Immutable(), + field.Int64("end_ip").Optional().Immutable(), + field.Int64("start_suffix").Optional().Immutable(), + field.Int64("end_suffix").Optional().Immutable(), + field.Int64("ip_size").Optional().Immutable(), + field.String("scope").Immutable(), + field.String("value").Immutable(), + field.String("origin").Immutable(), + field.Bool("simulated").Default(false).Immutable(), + field.String("uuid").Optional().Immutable(), // this uuid is mostly here to ensure that CAPI/PAPI has a unique id for each decision field.Int("alert_decisions").Optional(), } } diff --git a/pkg/database/ent/schema/event.go b/pkg/database/ent/schema/event.go index 6b6d2733ff7..107f68e5274 100644 --- a/pkg/database/ent/schema/event.go +++ b/pkg/database/ent/schema/event.go @@ -5,6 +5,7 @@ import ( "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -18,12 +19,12 @@ func (Event) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Immutable(), field.Time("updated_at"). Default(types.UtcNow). 
- UpdateDefault(types.UtcNow).Nillable().Optional(), - field.Time("time"), - field.String("serialized").MaxLen(8191), + UpdateDefault(types.UtcNow), + field.Time("time").Immutable(), + field.String("serialized").MaxLen(8191).Immutable(), field.Int("alert_events").Optional(), } } diff --git a/pkg/database/ent/schema/lock.go b/pkg/database/ent/schema/lock.go new file mode 100644 index 00000000000..a287e2b59ad --- /dev/null +++ b/pkg/database/ent/schema/lock.go @@ -0,0 +1,23 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/field" + + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type Lock struct { + ent.Schema +} + +func (Lock) Fields() []ent.Field { + return []ent.Field{ + field.String("name").Unique().Immutable().StructTag(`json:"name"`), + field.Time("created_at").Default(types.UtcNow).StructTag(`json:"created_at"`).Immutable(), + } +} + +func (Lock) Edges() []ent.Edge { + return nil +} diff --git a/pkg/database/ent/schema/machine.go b/pkg/database/ent/schema/machine.go index e155c936071..5b68f61b1a0 100644 --- a/pkg/database/ent/schema/machine.go +++ b/pkg/database/ent/schema/machine.go @@ -4,9 +4,17 @@ import ( "entgo.io/ent" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/types" ) +// ItemState is defined here instead of using pkg/models/HubItem to avoid introducing a dependency +type ItemState struct { + Name string `json:"name,omitempty"` + Status string `json:"status,omitempty"` + Version string `json:"version,omitempty"` +} + // Machine holds the schema definition for the Machine entity. type Machine struct { ent.Schema @@ -17,25 +25,30 @@ func (Machine) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Immutable(), field.Time("updated_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + UpdateDefault(types.UtcNow), field.Time("last_push"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Nillable().Optional(), field.Time("last_heartbeat"). - Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), - field.String("machineId").Unique(), + Nillable().Optional(), + field.String("machineId"). + Unique(). + Immutable(), field.String("password").Sensitive(), field.String("ipAddress"), field.String("scenarios").MaxLen(100000).Optional(), field.String("version").Optional(), field.Bool("isValidated"). Default(false), - field.String("status").Optional(), field.String("auth_type").Default(types.PasswordAuthType).StructTag(`json:"auth_type"`), + field.String("osname").Optional(), + field.String("osversion").Optional(), + field.String("featureflags").Optional(), + field.JSON("hubstate", map[string][]ItemState{}).Optional(), + field.JSON("datasources", map[string]int64{}).Optional(), } } diff --git a/pkg/database/ent/schema/meta.go b/pkg/database/ent/schema/meta.go index 1a84bb1b667..a87010cd8a3 100644 --- a/pkg/database/ent/schema/meta.go +++ b/pkg/database/ent/schema/meta.go @@ -5,6 +5,7 @@ import ( "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -17,13 +18,12 @@ type Meta struct { func (Meta) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). - Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Default(types.UtcNow).Immutable(), field.Time("updated_at"). Default(types.UtcNow). 
- UpdateDefault(types.UtcNow).Nillable().Optional(), - field.String("key"), - field.String("value").MaxLen(4095), + UpdateDefault(types.UtcNow), + field.String("key").Immutable(), + field.String("value").MaxLen(4095).Immutable(), field.Int("alert_metas").Optional(), } } diff --git a/pkg/database/ent/schema/metric.go b/pkg/database/ent/schema/metric.go new file mode 100644 index 00000000000..319c67b7aa7 --- /dev/null +++ b/pkg/database/ent/schema/metric.go @@ -0,0 +1,34 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/field" +) + +// Metric is actually a set of metrics collected by a device +// (logprocessor, bouncer, etc) at a given time. +type Metric struct { + ent.Schema +} + +func (Metric) Fields() []ent.Field { + return []ent.Field{ + field.Enum("generated_type"). + Values("LP", "RC"). + Immutable(). + Comment("Type of the metrics source: LP=logprocessor, RC=remediation"), + field.String("generated_by"). + Immutable(). + Comment("Source of the metrics: machine id, bouncer name...\nIt must come from the auth middleware."), + field.Time("received_at"). + Immutable(). + Comment("When the metrics are received by LAPI"), + field.Time("pushed_at"). + Nillable(). + Optional(). + Comment("When the metrics are sent to the console"), + field.Text("payload"). + Immutable(). + Comment("The actual metrics (item0)"), + } +} diff --git a/pkg/database/ent/tx.go b/pkg/database/ent/tx.go index 2a1efd152a0..bf8221ce4a5 100644 --- a/pkg/database/ent/tx.go +++ b/pkg/database/ent/tx.go @@ -22,20 +22,18 @@ type Tx struct { Decision *DecisionClient // Event is the client for interacting with the Event builders. Event *EventClient + // Lock is the client for interacting with the Lock builders. + Lock *LockClient // Machine is the client for interacting with the Machine builders. Machine *MachineClient // Meta is the client for interacting with the Meta builders. Meta *MetaClient + // Metric is the client for interacting with the Metric builders. + Metric *MetricClient // lazily loaded. client *Client clientOnce sync.Once - - // completion callbacks. - mu sync.Mutex - onCommit []CommitHook - onRollback []RollbackHook - // ctx lives for the life of the transaction. It is // the same context used by the underlying connection. ctx context.Context @@ -80,9 +78,9 @@ func (tx *Tx) Commit() error { var fn Committer = CommitFunc(func(context.Context, *Tx) error { return txDriver.tx.Commit() }) - tx.mu.Lock() - hooks := append([]CommitHook(nil), tx.onCommit...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]CommitHook(nil), txDriver.onCommit...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -91,9 +89,10 @@ func (tx *Tx) Commit() error { // OnCommit adds a hook to call on commit. func (tx *Tx) OnCommit(f CommitHook) { - tx.mu.Lock() - defer tx.mu.Unlock() - tx.onCommit = append(tx.onCommit, f) + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onCommit = append(txDriver.onCommit, f) + txDriver.mu.Unlock() } type ( @@ -135,9 +134,9 @@ func (tx *Tx) Rollback() error { var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error { return txDriver.tx.Rollback() }) - tx.mu.Lock() - hooks := append([]RollbackHook(nil), tx.onRollback...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]RollbackHook(nil), txDriver.onRollback...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -146,9 +145,10 @@ func (tx *Tx) Rollback() error { // OnRollback adds a hook to call on rollback. 
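// Hooks now live on the shared txDriver rather than on the Tx value, so a
// hook registered through any Tx handle bound to the same underlying
// transaction is seen by Rollback. A minimal caller-side sketch (assuming a
// *Tx obtained from client.Tx(ctx); the hook wraps the next Rollbacker in
// the chain):
//
//	tx.OnRollback(func(next ent.Rollbacker) ent.Rollbacker {
//		return ent.RollbackFunc(func(ctx context.Context, tx *ent.Tx) error {
//			log.Println("transaction rolled back")
//			return next.Rollback(ctx, tx)
//		})
//	})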
func (tx *Tx) OnRollback(f RollbackHook) {
-	tx.mu.Lock()
-	defer tx.mu.Unlock()
-	tx.onRollback = append(tx.onRollback, f)
+	txDriver := tx.config.driver.(*txDriver)
+	txDriver.mu.Lock()
+	txDriver.onRollback = append(txDriver.onRollback, f)
+	txDriver.mu.Unlock()
 }

 // Client returns a Client that binds to current transaction.
@@ -166,8 +166,10 @@ func (tx *Tx) init() {
 	tx.ConfigItem = NewConfigItemClient(tx.config)
 	tx.Decision = NewDecisionClient(tx.config)
 	tx.Event = NewEventClient(tx.config)
+	tx.Lock = NewLockClient(tx.config)
 	tx.Machine = NewMachineClient(tx.config)
 	tx.Meta = NewMetaClient(tx.config)
+	tx.Metric = NewMetricClient(tx.config)
 }

 // txDriver wraps the given dialect.Tx with a nop dialect.Driver implementation.
@@ -186,6 +188,10 @@ type txDriver struct {
 	drv dialect.Driver
 	// tx is the underlying transaction.
 	tx dialect.Tx
+	// completion hooks.
+	mu         sync.Mutex
+	onCommit   []CommitHook
+	onRollback []RollbackHook
 }

 // newTx creates a new transactional driver.
diff --git a/pkg/database/errors.go b/pkg/database/errors.go
index 8e96f52d7ce..77f92707e51 100644
--- a/pkg/database/errors.go
+++ b/pkg/database/errors.go
@@ -13,8 +13,8 @@ var (
 	ItemNotFound      = errors.New("object not found")
 	ParseTimeFail     = errors.New("unable to parse time")
 	ParseDurationFail = errors.New("unable to parse duration")
-	MarshalFail       = errors.New("unable to marshal")
-	UnmarshalFail     = errors.New("unable to unmarshal")
+	MarshalFail       = errors.New("unable to serialize")
+	UnmarshalFail     = errors.New("unable to parse")
 	BulkError         = errors.New("unable to insert bulk")
 	ParseType         = errors.New("unable to parse type")
 	InvalidIPOrRange  = errors.New("invalid ip address / range")
diff --git a/pkg/database/flush.go b/pkg/database/flush.go
new file mode 100644
index 00000000000..8f646ddc961
--- /dev/null
+++ b/pkg/database/flush.go
@@ -0,0 +1,311 @@
+package database
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"time"
+
+	"github.com/go-co-op/gocron"
+	log "github.com/sirupsen/logrus"
+
+	"github.com/crowdsecurity/go-cs-lib/ptr"
+
+	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
+	"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
+	"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
+	"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
+	"github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
+	"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
+	"github.com/crowdsecurity/crowdsec/pkg/database/ent/metric"
+	"github.com/crowdsecurity/crowdsec/pkg/types"
+)
+
+const (
+	// how long to keep metrics in the local database
+	defaultMetricsMaxAge = 7 * 24 * time.Hour
+	flushInterval        = 1 * time.Minute
+)
+
+func (c *Client) StartFlushScheduler(ctx context.Context, config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) {
+	maxItems := 0
+	maxAge := ""
+
+	if config.MaxItems != nil && *config.MaxItems <= 0 {
+		return nil, errors.New("max_items can't be zero or negative")
+	}
+
+	if config.MaxItems != nil {
+		maxItems = *config.MaxItems
+	}
+
+	if config.MaxAge != nil && *config.MaxAge != "" {
+		maxAge = *config.MaxAge
+	}
+
+	// Init & Start cronjob every minute for alerts
+	scheduler := gocron.NewScheduler(time.UTC)
+
+	job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, ctx, maxAge, maxItems)
+	if err != nil {
+		return nil, fmt.Errorf("while starting FlushAlerts scheduler: %w", err)
+	}
+
+	job.SingletonMode()
+	// Init & Start cronjob every minute for bouncers/agents
+	if config.AgentsGC != nil {
+		if config.AgentsGC.Cert != nil {
+			duration, err :=
ParseDuration(*config.AgentsGC.Cert) + if err != nil { + return nil, fmt.Errorf("while parsing agents cert auto-delete duration: %w", err) + } + + config.AgentsGC.CertDuration = &duration + } + + if config.AgentsGC.LoginPassword != nil { + duration, err := ParseDuration(*config.AgentsGC.LoginPassword) + if err != nil { + return nil, fmt.Errorf("while parsing agents login/password auto-delete duration: %w", err) + } + + config.AgentsGC.LoginPasswordDuration = &duration + } + + if config.AgentsGC.Api != nil { + log.Warning("agents auto-delete for API auth is not supported (use cert or login_password)") + } + } + + if config.BouncersGC != nil { + if config.BouncersGC.Cert != nil { + duration, err := ParseDuration(*config.BouncersGC.Cert) + if err != nil { + return nil, fmt.Errorf("while parsing bouncers cert auto-delete duration: %w", err) + } + + config.BouncersGC.CertDuration = &duration + } + + if config.BouncersGC.Api != nil { + duration, err := ParseDuration(*config.BouncersGC.Api) + if err != nil { + return nil, fmt.Errorf("while parsing bouncers api auto-delete duration: %w", err) + } + + config.BouncersGC.ApiDuration = &duration + } + + if config.BouncersGC.LoginPassword != nil { + log.Warning("bouncers auto-delete for login/password auth is not supported (use cert or api)") + } + } + + baJob, err := scheduler.Every(flushInterval).Do(c.FlushAgentsAndBouncers, ctx, config.AgentsGC, config.BouncersGC) + if err != nil { + return nil, fmt.Errorf("while starting FlushAgentsAndBouncers scheduler: %w", err) + } + + baJob.SingletonMode() + + metricsJob, err := scheduler.Every(flushInterval).Do(c.flushMetrics, ctx, config.MetricsMaxAge) + if err != nil { + return nil, fmt.Errorf("while starting flushMetrics scheduler: %w", err) + } + + metricsJob.SingletonMode() + + scheduler.StartAsync() + + return scheduler, nil +} + +// flushMetrics deletes metrics older than maxAge, regardless if they have been pushed to CAPI or not +func (c *Client) flushMetrics(ctx context.Context, maxAge *time.Duration) { + if maxAge == nil { + maxAge = ptr.Of(defaultMetricsMaxAge) + } + + c.Log.Debugf("flushing metrics older than %s", maxAge) + + deleted, err := c.Ent.Metric.Delete().Where( + metric.ReceivedAtLTE(time.Now().UTC().Add(-*maxAge)), + ).Exec(ctx) + if err != nil { + c.Log.Errorf("while flushing metrics: %s", err) + return + } + + if deleted > 0 { + c.Log.Debugf("flushed %d metrics snapshots", deleted) + } +} + +func (c *Client) FlushOrphans(ctx context.Context) { + /* While it has only been linked to some very corner-case bug : https://github.com/crowdsecurity/crowdsec/issues/778 */ + /* We want to take care of orphaned events for which the parent alert/decision has been deleted */ + eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(ctx) + if err != nil { + c.Log.Warningf("error while deleting orphan events: %s", err) + return + } + + if eventsCount > 0 { + c.Log.Infof("%d deleted orphan events", eventsCount) + } + + eventsCount, err = c.Ent.Decision.Delete().Where( + decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(ctx) + if err != nil { + c.Log.Warningf("error while deleting orphan decisions: %s", err) + return + } + + if eventsCount > 0 { + c.Log.Infof("%d deleted orphan decisions", eventsCount) + } +} + +func (c *Client) flushBouncers(ctx context.Context, authType string, duration *time.Duration) { + if duration == nil { + return + } + + count, err := c.Ent.Bouncer.Delete().Where( + bouncer.LastPullLTE(time.Now().UTC().Add(-*duration)), + 
).Where( + bouncer.AuthTypeEQ(authType), + ).Exec(ctx) + if err != nil { + c.Log.Errorf("while auto-deleting expired bouncers (%s): %s", authType, err) + return + } + + if count > 0 { + c.Log.Infof("deleted %d expired bouncers (%s)", count, authType) + } +} + +func (c *Client) flushAgents(ctx context.Context, authType string, duration *time.Duration) { + if duration == nil { + return + } + + count, err := c.Ent.Machine.Delete().Where( + machine.LastHeartbeatLTE(time.Now().UTC().Add(-*duration)), + machine.Not(machine.HasAlerts()), + machine.AuthTypeEQ(authType), + ).Exec(ctx) + if err != nil { + c.Log.Errorf("while auto-deleting expired machines (%s): %s", authType, err) + return + } + + if count > 0 { + c.Log.Infof("deleted %d expired machines (%s auth)", count, authType) + } +} + +func (c *Client) FlushAgentsAndBouncers(ctx context.Context, agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { + log.Debug("starting FlushAgentsAndBouncers") + + if agentsCfg != nil { + c.flushAgents(ctx, types.TlsAuthType, agentsCfg.CertDuration) + c.flushAgents(ctx, types.PasswordAuthType, agentsCfg.LoginPasswordDuration) + } + + if bouncersCfg != nil { + c.flushBouncers(ctx, types.TlsAuthType, bouncersCfg.CertDuration) + c.flushBouncers(ctx, types.ApiKeyAuthType, bouncersCfg.ApiDuration) + } + + return nil +} + +func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) error { + var ( + deletedByAge int + deletedByNbItem int + totalAlerts int + err error + ) + + if !c.CanFlush { + c.Log.Debug("a list is being imported, flushing later") + return nil + } + + c.Log.Debug("Flushing orphan alerts") + c.FlushOrphans(ctx) + c.Log.Debug("Done flushing orphan alerts") + + totalAlerts, err = c.TotalAlerts(ctx) + if err != nil { + c.Log.Warningf("FlushAlerts (max items count): %s", err) + return fmt.Errorf("unable to get alerts count: %w", err) + } + + c.Log.Debugf("FlushAlerts (Total alerts): %d", totalAlerts) + + if MaxAge != "" { + filter := map[string][]string{ + "created_before": {MaxAge}, + } + + nbDeleted, err := c.DeleteAlertWithFilter(ctx, filter) + if err != nil { + c.Log.Warningf("FlushAlerts (max age): %s", err) + return fmt.Errorf("unable to flush alerts with filter until=%s: %w", MaxAge, err) + } + + c.Log.Debugf("FlushAlerts (deleted max age alerts): %d", nbDeleted) + deletedByAge = nbDeleted + } + + if MaxItems > 0 { + // We get the highest id for the alerts + // We subtract MaxItems to avoid deleting alerts that are not old enough + // This gives us the oldest alert that we want to keep + // We then delete all the alerts with an id lower than this one + // We can do this because the id is auto-increment, and the database won't reuse the same id twice + lastAlert, err := c.QueryAlertWithFilter(ctx, map[string][]string{ + "sort": {"DESC"}, + "limit": {"1"}, + // we do not care about fetching the edges, we just want the id + "with_decisions": {"false"}, + }) + c.Log.Debugf("FlushAlerts (last alert): %+v", lastAlert) + + if err != nil { + c.Log.Errorf("FlushAlerts: could not get last alert: %s", err) + return fmt.Errorf("could not get last alert: %w", err) + } + + if len(lastAlert) != 0 { + maxid := lastAlert[0].ID - MaxItems + + c.Log.Debugf("FlushAlerts (max id): %d", maxid) + + if maxid > 0 { + // This may lead to orphan alerts (at least on MySQL), but the next time the flush job will run, they will be deleted + deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(ctx) + if err != nil { + c.Log.Errorf("FlushAlerts: Could not delete alerts: 
%s", err) + return fmt.Errorf("could not delete alerts: %w", err) + } + } + } + } + + if deletedByNbItem > 0 { + c.Log.Infof("flushed %d/%d alerts because the max number of alerts has been reached (%d max)", + deletedByNbItem, totalAlerts, MaxItems) + } + + if deletedByAge > 0 { + c.Log.Infof("flushed %d/%d alerts because they were created %s ago or more", + deletedByAge, totalAlerts, MaxAge) + } + + return nil +} diff --git a/pkg/database/lock.go b/pkg/database/lock.go new file mode 100644 index 00000000000..474228a069c --- /dev/null +++ b/pkg/database/lock.go @@ -0,0 +1,87 @@ +package database + +import ( + "context" + "time" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +const ( + CAPIPullLockTimeout = 10 + CapiPullLockName = "pullCAPI" +) + +func (c *Client) AcquireLock(ctx context.Context, name string) error { + log.Debugf("acquiring lock %s", name) + _, err := c.Ent.Lock.Create(). + SetName(name). + SetCreatedAt(types.UtcNow()). + Save(ctx) + + if ent.IsConstraintError(err) { + return err + } + + if err != nil { + return errors.Wrapf(InsertFail, "insert lock: %s", err) + } + + return nil +} + +func (c *Client) ReleaseLock(ctx context.Context, name string) error { + log.Debugf("releasing lock %s", name) + _, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(ctx) + if err != nil { + return errors.Wrapf(DeleteFail, "delete lock: %s", err) + } + + return nil +} + +func (c *Client) ReleaseLockWithTimeout(ctx context.Context, name string, timeout int) error { + log.Debugf("releasing lock %s with timeout of %d minutes", name, timeout) + + _, err := c.Ent.Lock.Delete().Where( + lock.NameEQ(name), + lock.CreatedAtLT(time.Now().UTC().Add(-time.Duration(timeout)*time.Minute)), + ).Exec(ctx) + if err != nil { + return errors.Wrapf(DeleteFail, "delete lock: %s", err) + } + + return nil +} + +func (c *Client) IsLocked(err error) bool { + return ent.IsConstraintError(err) +} + +func (c *Client) AcquirePullCAPILock(ctx context.Context) error { + // delete orphan "old" lock if present + err := c.ReleaseLockWithTimeout(ctx, CapiPullLockName, CAPIPullLockTimeout) + if err != nil { + log.Errorf("unable to release pullCAPI lock: %s", err) + } + + return c.AcquireLock(ctx, CapiPullLockName) +} + +func (c *Client) ReleasePullCAPILock(ctx context.Context) error { + log.Debugf("deleting lock %s", CapiPullLockName) + + _, err := c.Ent.Lock.Delete().Where( + lock.NameEQ(CapiPullLockName), + ).Exec(ctx) + if err != nil { + return errors.Wrapf(DeleteFail, "delete lock: %s", err) + } + + return nil +} diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 7a010fbfbc4..d8c02825312 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -1,7 +1,9 @@ package database import ( + "context" "fmt" + "strings" "time" "github.com/go-openapi/strfmt" @@ -10,39 +12,97 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" + "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) -const CapiMachineID = types.CAPIOrigin -const CapiListsMachineID = types.ListOrigin +const ( + CapiMachineID = types.CAPIOrigin + CapiListsMachineID = types.ListOrigin +) + +type MachineNotFoundError struct { + MachineID string +} + +func (e 
*MachineNotFoundError) Error() string { + return fmt.Sprintf("'%s' does not exist", e.MachineID) +} + +func (c *Client) MachineUpdateBaseMetrics(ctx context.Context, machineID string, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { + os := baseMetrics.Os + features := strings.Join(baseMetrics.FeatureFlags, ",") -func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { + var heartbeat time.Time + + if len(baseMetrics.Metrics) == 0 { + heartbeat = time.Now().UTC() + } else { + heartbeat = time.Unix(*baseMetrics.Metrics[0].Meta.UtcNowTimestamp, 0) + } + + hubState := map[string][]schema.ItemState{} + for itemType, items := range hubItems { + hubState[itemType] = []schema.ItemState{} + for _, item := range items { + hubState[itemType] = append(hubState[itemType], schema.ItemState{ + Name: item.Name, + Status: item.Status, + Version: item.Version, + }) + } + } + + _, err := c.Ent.Machine. + Update(). + Where(machine.MachineIdEQ(machineID)). + SetNillableVersion(baseMetrics.Version). + SetOsname(*os.Name). + SetOsversion(*os.Version). + SetFeatureflags(features). + SetLastHeartbeat(heartbeat). + SetHubstate(hubState). + SetDatasources(datasources). + Save(ctx) + if err != nil { + return fmt.Errorf("unable to update base machine metrics in database: %w", err) + } + + return nil +} + +func (c *Client) CreateMachine(ctx context.Context, machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) if err != nil { - c.Log.Warningf("CreateMachine : %s", err) - return nil, errors.Wrap(HashError, "") + c.Log.Warningf("CreateMachine: %s", err) + return nil, HashError } machineExist, err := c.Ent.Machine. Query(). Where(machine.MachineIdEQ(*machineID)). - Select(machine.FieldMachineId).Strings(c.CTX) + Select(machine.FieldMachineId).Strings(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } + if len(machineExist) > 0 { if force { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX) + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID) } - machine, err := c.QueryMachineByID(*machineID) + + machine, err := c.QueryMachineByID(ctx, *machineID) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } + return machine, nil } + return nil, errors.Wrapf(UserExists, "user '%s'", *machineID) } @@ -53,8 +113,7 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA SetIpAddress(ipAddress). SetIsValidated(isValidated). SetAuthType(authType). - Save(c.CTX) - + Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID) @@ -63,124 +122,146 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA return machine, nil } -func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { +func (c *Client) QueryMachineByID(ctx context.Context, machineID string) (*ent.Machine, error) { machine, err := c.Ent.Machine. Query(). 
Where(machine.MachineIdEQ(machineID)). - Only(c.CTX) + Only(ctx) if err != nil { c.Log.Warningf("QueryMachineByID : %s", err) return &ent.Machine{}, errors.Wrapf(UserNotExists, "user '%s'", machineID) } + return machine, nil } -func (c *Client) ListMachines() ([]*ent.Machine, error) { - machines, err := c.Ent.Machine.Query().All(c.CTX) +func (c *Client) ListMachines(ctx context.Context) ([]*ent.Machine, error) { + machines, err := c.Ent.Machine.Query().All(ctx) if err != nil { - return []*ent.Machine{}, errors.Wrapf(QueryFail, "listing machines: %s", err) + return nil, errors.Wrapf(QueryFail, "listing machines: %s", err) } + return machines, nil } -func (c *Client) ValidateMachine(machineID string) error { - rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(c.CTX) +func (c *Client) ValidateMachine(ctx context.Context, machineID string) error { + rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "validating machine: %s", err) } + if rets == 0 { - return fmt.Errorf("machine not found") + return errors.New("machine not found") } + return nil } -func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) { - var machines []*ent.Machine - var err error - - machines, err = c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX) +func (c *Client) QueryPendingMachine(ctx context.Context) ([]*ent.Machine, error) { + machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(ctx) if err != nil { c.Log.Warningf("QueryPendingMachine : %s", err) - return []*ent.Machine{}, errors.Wrapf(QueryFail, "querying pending machines: %s", err) + return nil, errors.Wrapf(QueryFail, "querying pending machines: %s", err) } + return machines, nil } -func (c *Client) DeleteWatcher(name string) error { +func (c *Client) DeleteWatcher(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Machine. Delete(). Where(machine.MachineIdEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } if nbDeleted == 0 { - return fmt.Errorf("machine doesn't exist") + return &MachineNotFoundError{MachineID: name} } return nil } -func (c *Client) UpdateMachineLastPush(machineID string) error { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastPush(time.Now().UTC()).Save(c.CTX) +func (c *Client) BulkDeleteWatchers(ctx context.Context, machines []*ent.Machine) (int, error) { + ids := make([]int, len(machines)) + for i, b := range machines { + ids[i] = b.ID + } + + nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(ctx) if err != nil { - return errors.Wrapf(UpdateFail, "updating machine last_push: %s", err) + return nbDeleted, err } - return nil + + return nbDeleted, nil } -func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(c.CTX) +func (c *Client) UpdateMachineLastHeartBeat(ctx context.Context, machineID string) error { + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "updating machine last_heartbeat: %s", err) } + return nil } -func (c *Client) UpdateMachineScenarios(scenarios string, ID int) error { - _, err := c.Ent.Machine.UpdateOneID(ID). 
+func (c *Client) UpdateMachineScenarios(ctx context.Context, scenarios string, id int) error { + _, err := c.Ent.Machine.UpdateOneID(id). SetUpdatedAt(time.Now().UTC()). SetScenarios(scenarios). - Save(c.CTX) + Save(ctx) if err != nil { - return fmt.Errorf("unable to update machine in database: %s", err) + return fmt.Errorf("unable to update machine in database: %w", err) } + return nil } -func (c *Client) UpdateMachineIP(ipAddr string, ID int) error { - _, err := c.Ent.Machine.UpdateOneID(ID). +func (c *Client) UpdateMachineIP(ctx context.Context, ipAddr string, id int) error { + _, err := c.Ent.Machine.UpdateOneID(id). SetIpAddress(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { - return fmt.Errorf("unable to update machine IP in database: %s", err) + return fmt.Errorf("unable to update machine IP in database: %w", err) } + return nil } -func (c *Client) UpdateMachineVersion(ipAddr string, ID int) error { - _, err := c.Ent.Machine.UpdateOneID(ID). +func (c *Client) UpdateMachineVersion(ctx context.Context, ipAddr string, id int) error { + _, err := c.Ent.Machine.UpdateOneID(id). SetVersion(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { - return fmt.Errorf("unable to update machine version in database: %s", err) + return fmt.Errorf("unable to update machine version in database: %w", err) } + return nil } -func (c *Client) IsMachineRegistered(machineID string) (bool, error) { - exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(c.CTX) +func (c *Client) IsMachineRegistered(ctx context.Context, machineID string) (bool, error) { + exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(ctx) if err != nil { return false, err } + if len(exist) == 1 { return true, nil } + if len(exist) > 1 { - return false, fmt.Errorf("More than one item with the same machineID in database") + return false, errors.New("more than one item with the same machineID in database") } return false, nil +} +func (c *Client) QueryMachinesInactiveSince(ctx context.Context, t time.Time) ([]*ent.Machine, error) { + return c.Ent.Machine.Query().Where( + machine.Or( + machine.And(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)), + machine.And(machine.LastHeartbeatIsNil(), machine.CreatedAtLT(t)), + ), + ).All(ctx) } diff --git a/pkg/database/metrics.go b/pkg/database/metrics.go new file mode 100644 index 00000000000..eb4c472821e --- /dev/null +++ b/pkg/database/metrics.go @@ -0,0 +1,71 @@ +package database + +import ( + "context" + "fmt" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" +) + +func (c *Client) CreateMetric(ctx context.Context, generatedType metric.GeneratedType, generatedBy string, receivedAt time.Time, payload string) (*ent.Metric, error) { + metric, err := c.Ent.Metric. + Create(). + SetGeneratedType(generatedType). + SetGeneratedBy(generatedBy). + SetReceivedAt(receivedAt). + SetPayload(payload). + Save(ctx) + if err != nil { + c.Log.Warningf("CreateMetric: %s", err) + return nil, fmt.Errorf("storing metrics snapshot for '%s' at %s: %w", generatedBy, receivedAt, InsertFail) + } + + return metric, nil +} + +func (c *Client) GetLPUsageMetricsByMachineID(ctx context.Context, machineId string) ([]*ent.Metric, error) { + metrics, err := c.Ent.Metric.Query(). + Where( + metric.GeneratedTypeEQ(metric.GeneratedTypeLP), + metric.GeneratedByEQ(machineId), + metric.PushedAtIsNil(), + ). 
+ All(ctx) + if err != nil { + c.Log.Warningf("GetLPUsageMetricsByOrigin: %s", err) + return nil, fmt.Errorf("getting LP usage metrics by origin %s: %w", machineId, err) + } + + return metrics, nil +} + +func (c *Client) GetBouncerUsageMetricsByName(ctx context.Context, bouncerName string) ([]*ent.Metric, error) { + metrics, err := c.Ent.Metric.Query(). + Where( + metric.GeneratedTypeEQ(metric.GeneratedTypeRC), + metric.GeneratedByEQ(bouncerName), + metric.PushedAtIsNil(), + ). + All(ctx) + if err != nil { + c.Log.Warningf("GetBouncerUsageMetricsByName: %s", err) + return nil, fmt.Errorf("getting bouncer usage metrics by name %s: %w", bouncerName, err) + } + + return metrics, nil +} + +func (c *Client) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { + _, err := c.Ent.Metric.Update(). + Where(metric.IDIn(ids...)). + SetPushedAt(time.Now().UTC()). + Save(ctx) + if err != nil { + c.Log.Warningf("MarkUsageMetricsAsSent: %s", err) + return fmt.Errorf("marking usage metrics as sent: %w", err) + } + + return nil +} diff --git a/pkg/database/utils.go b/pkg/database/utils.go index 5d6d4a44264..8148df56f24 100644 --- a/pkg/database/utils.go +++ b/pkg/database/utils.go @@ -4,18 +4,23 @@ import ( "encoding/binary" "fmt" "net" + "strconv" + "strings" + "time" ) func IP2Int(ip net.IP) uint32 { if len(ip) == 16 { return binary.BigEndian.Uint32(ip[12:16]) } + return binary.BigEndian.Uint32(ip) } func Int2ip(nn uint32) net.IP { ip := make(net.IP, 4) binary.BigEndian.PutUint32(ip, nn) + return ip } @@ -23,20 +28,22 @@ func IsIpv4(host string) bool { return net.ParseIP(host) != nil } -//Stolen from : https://github.com/llimllib/ipaddress/ +// Stolen from : https://github.com/llimllib/ipaddress/ // Return the final address of a net range. Convert to IPv4 if possible, // otherwise return an ipv6 func LastAddress(n *net.IPNet) net.IP { ip := n.IP.To4() if ip == nil { ip = n.IP + return net.IP{ ip[0] | ^n.Mask[0], ip[1] | ^n.Mask[1], ip[2] | ^n.Mask[2], ip[3] | ^n.Mask[3], ip[4] | ^n.Mask[4], ip[5] | ^n.Mask[5], ip[6] | ^n.Mask[6], ip[7] | ^n.Mask[7], ip[8] | ^n.Mask[8], ip[9] | ^n.Mask[9], ip[10] | ^n.Mask[10], ip[11] | ^n.Mask[11], ip[12] | ^n.Mask[12], ip[13] | ^n.Mask[13], ip[14] | ^n.Mask[14], - ip[15] | ^n.Mask[15]} + ip[15] | ^n.Mask[15], + } } return net.IPv4( @@ -46,20 +53,44 @@ func LastAddress(n *net.IPNet) net.IP { ip[3]|^n.Mask[3]) } +// GetIpsFromIpRange takes a CIDR range and returns the start and end IP func GetIpsFromIpRange(host string) (int64, int64, error) { - var ipStart int64 - var ipEnd int64 - var err error - var parsedRange *net.IPNet - - if _, parsedRange, err = net.ParseCIDR(host); err != nil { - return ipStart, ipEnd, fmt.Errorf("'%s' is not a valid CIDR", host) + _, parsedRange, err := net.ParseCIDR(host) + if err != nil { + return 0, 0, fmt.Errorf("'%s' is not a valid CIDR", host) } + if parsedRange == nil { - return ipStart, ipEnd, fmt.Errorf("unable to parse network : %s", err) + return 0, 0, fmt.Errorf("unable to parse network: %w", err) } - ipStart = int64(IP2Int(parsedRange.IP)) - ipEnd = int64(IP2Int(LastAddress(parsedRange))) + + ipStart := int64(IP2Int(parsedRange.IP)) + ipEnd := int64(IP2Int(LastAddress(parsedRange))) return ipStart, ipEnd, nil } + +func ParseDuration(d string) (time.Duration, error) { + durationStr := d + + if strings.HasSuffix(d, "d") { + days := strings.Split(d, "d")[0] + if days == "" { + return 0, fmt.Errorf("'%s' can't be parsed as duration", d) + } + + daysInt, err := strconv.Atoi(days) + if err != nil { + return 0, err + } + + 
durationStr = strconv.Itoa(daysInt*24) + "h" + } + + duration, err := time.ParseDuration(durationStr) + if err != nil { + return 0, err + } + + return duration, nil +} diff --git a/pkg/dumps/bucket_dump.go b/pkg/dumps/bucket_dump.go new file mode 100644 index 00000000000..328c581928b --- /dev/null +++ b/pkg/dumps/bucket_dump.go @@ -0,0 +1,33 @@ +package dumps + +import ( + "io" + "os" + + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type BucketPourInfo map[string][]types.Event + +func LoadBucketPourDump(filepath string) (*BucketPourInfo, error) { + dumpData, err := os.Open(filepath) + if err != nil { + return nil, err + } + defer dumpData.Close() + + results, err := io.ReadAll(dumpData) + if err != nil { + return nil, err + } + + var bucketDump BucketPourInfo + + if err := yaml.Unmarshal(results, &bucketDump); err != nil { + return nil, err + } + + return &bucketDump, nil +} diff --git a/pkg/dumps/parser_dump.go b/pkg/dumps/parser_dump.go new file mode 100644 index 00000000000..bc8f78dc203 --- /dev/null +++ b/pkg/dumps/parser_dump.go @@ -0,0 +1,346 @@ +package dumps + +import ( + "errors" + "fmt" + "io" + "os" + "sort" + "strings" + "time" + + "github.com/fatih/color" + diff "github.com/r3labs/diff/v2" + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/pkg/emoji" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type ParserResult struct { + Idx int + Evt types.Event + Success bool +} + +type ParserResults map[string]map[string][]ParserResult + +type DumpOpts struct { + Details bool + SkipOk bool + ShowNotOkParsers bool +} + +func LoadParserDump(filepath string) (*ParserResults, error) { + dumpData, err := os.Open(filepath) + if err != nil { + return nil, err + } + defer dumpData.Close() + + results, err := io.ReadAll(dumpData) + if err != nil { + return nil, err + } + + pdump := ParserResults{} + + if err := yaml.Unmarshal(results, &pdump); err != nil { + return nil, err + } + + /* we know that some variables should always be set, + let's check if they're present in last parser output of last stage */ + + stages := maptools.SortedKeys(pdump) + + var lastStage string + + // Loop over stages to find last successful one with at least one parser + for i := len(stages) - 2; i >= 0; i-- { + if len(pdump[stages[i]]) != 0 { + lastStage = stages[i] + break + } + } + + parsers := make([]string, 0, len(pdump[lastStage])) + + for k := range pdump[lastStage] { + parsers = append(parsers, k) + } + + sort.Strings(parsers) + + if len(parsers) == 0 { + return nil, errors.New("no parser found. Please install the appropriate parser and retry") + } + + lastParser := parsers[len(parsers)-1] + + for idx, result := range pdump[lastStage][lastParser] { + if result.Evt.StrTime == "" { + log.Warningf("Line %d/%d is missing evt.StrTime. 
It is most likely a mistake as it will prevent your logs from being processed in time-machine/forensic mode.", idx, len(pdump[lastStage][lastParser]))
+		} else {
+			log.Debugf("Line %d/%d has evt.StrTime set to '%s'", idx, len(pdump[lastStage][lastParser]), result.Evt.StrTime)
+		}
+	}
+
+	return &pdump, nil
+}
+
+type tree struct {
+	// note : we can use line -> time as the unique identifier (of acquisition)
+	state       map[time.Time]map[string]map[string]ParserResult
+	assoc       map[time.Time]string
+	parserOrder map[string][]string
+}
+
+func newTree() *tree {
+	return &tree{
+		state:       make(map[time.Time]map[string]map[string]ParserResult),
+		assoc:       make(map[time.Time]string),
+		parserOrder: make(map[string][]string),
+	}
+}
+
+func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpOpts) {
+	t := newTree()
+	t.processEvents(parserResults)
+	t.processBuckets(bucketPour)
+	t.displayResults(opts)
+}
+
+func (t *tree) processEvents(parserResults ParserResults) {
+	for stage, parsers := range parserResults {
+		// let's process parsers in the order according to idx
+		t.parserOrder[stage] = make([]string, len(parsers))
+
+		for pname, parser := range parsers {
+			if len(parser) > 0 {
+				t.parserOrder[stage][parser[0].Idx-1] = pname
+			}
+		}
+
+		for _, parser := range t.parserOrder[stage] {
+			results := parsers[parser]
+			for _, parserRes := range results {
+				evt := parserRes.Evt
+				if _, ok := t.state[evt.Line.Time]; !ok {
+					t.state[evt.Line.Time] = make(map[string]map[string]ParserResult)
+					t.assoc[evt.Line.Time] = evt.Line.Raw
+				}
+
+				if _, ok := t.state[evt.Line.Time][stage]; !ok {
+					t.state[evt.Line.Time][stage] = make(map[string]ParserResult)
+				}
+
+				t.state[evt.Line.Time][stage][parser] = ParserResult{Evt: evt, Success: parserRes.Success}
+			}
+		}
+	}
+}
+
+func (t *tree) processBuckets(bucketPour BucketPourInfo) {
+	for bname, evtlist := range bucketPour {
+		for _, evt := range evtlist {
+			if evt.Line.Raw == "" {
+				continue
+			}
+
+			// it might be bucket overflow being reprocessed, skip this
+			if _, ok := t.state[evt.Line.Time]; !ok {
+				t.state[evt.Line.Time] = make(map[string]map[string]ParserResult)
+				t.assoc[evt.Line.Time] = evt.Line.Raw
+			}
+
+			// there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase
+			// we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered
+			if _, ok := t.state[evt.Line.Time]["buckets"]; !ok {
+				t.state[evt.Line.Time]["buckets"] = make(map[string]ParserResult)
+			}
+
+			t.state[evt.Line.Time]["buckets"][bname] = ParserResult{Success: true}
+		}
+	}
+}
+
+func (t *tree) displayResults(opts DumpOpts) {
+	yellow := color.New(color.FgYellow).SprintFunc()
+	red := color.New(color.FgRed).SprintFunc()
+	green := color.New(color.FgGreen).SprintFunc()
+	whitelistReason := ""
+
+	// get each line
+	for tstamp, rawstr := range t.assoc {
+		if opts.SkipOk {
+			if _, ok := t.state[tstamp]["buckets"]["OK"]; ok {
+				continue
+			}
+		}
+
+		fmt.Printf("line: %s\n", rawstr)
+
+		skeys := make([]string, 0, len(t.state[tstamp]))
+
+		for k := range t.state[tstamp] {
+			// there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase
+			// we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered
+			if k == "buckets" {
+				continue
+			}
+
+			skeys = append(skeys, k)
+		}
+
+		sort.Strings(skeys)
+
+		// iterate stage
+		var prevItem types.Event
+
+		for _, stage := range skeys {
+			parsers := t.state[tstamp][stage]
+
+			sep := "├"
+			presep := "|"
+
+
fmt.Printf("\t%s %s\n", sep, stage) + + for idx, parser := range t.parserOrder[stage] { + res := parsers[parser].Success + sep := "├" + + if idx == len(t.parserOrder[stage])-1 { + sep = "└" + } + + created := 0 + updated := 0 + deleted := 0 + whitelisted := false + changeStr := "" + detailsDisplay := "" + + if res { + changelog, _ := diff.Diff(prevItem, parsers[parser].Evt) + for _, change := range changelog { + switch change.Type { + case "create": + created++ + + detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), green(change.To)) + case "update": + detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s -> %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), change.From, yellow(change.To)) + + if change.Path[0] == "Whitelisted" && change.To == true { //nolint:revive + whitelisted = true + + if whitelistReason == "" { + whitelistReason = parsers[parser].Evt.WhitelistReason + } + } + + updated++ + case "delete": + deleted++ + + detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s\n", presep, sep, change.Type, red(strings.Join(change.Path, "."))) + } + } + + prevItem = parsers[parser].Evt + } + + if created > 0 { + changeStr += green(fmt.Sprintf("+%d", created)) + } + + if updated > 0 { + if changeStr != "" { + changeStr += " " + } + + changeStr += yellow(fmt.Sprintf("~%d", updated)) + } + + if deleted > 0 { + if changeStr != "" { + changeStr += " " + } + + changeStr += red(fmt.Sprintf("-%d", deleted)) + } + + if whitelisted { + if changeStr != "" { + changeStr += " " + } + + changeStr += red("[whitelisted]") + } + + if changeStr == "" { + changeStr = yellow("unchanged") + } + + if res { + fmt.Printf("\t%s\t%s %s %s (%s)\n", presep, sep, emoji.GreenCircle, parser, changeStr) + + if opts.Details { + fmt.Print(detailsDisplay) + } + } else if opts.ShowNotOkParsers { + fmt.Printf("\t%s\t%s %s %s\n", presep, sep, emoji.RedCircle, parser) + } + } + } + + sep := "└" + + if len(t.state[tstamp]["buckets"]) > 0 { + sep = "├" + } + + // did the event enter the bucket pour phase ? 
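+		// ("OK" is the fake parser recorded under the fake "buckets" stage: when an
+		// event makes it out of the parser chain and reaches pour(), the dump carries
+		// an "OK" entry, which processBuckets above stores as
+		// t.state[evt.Line.Time]["buckets"]["OK"] = ParserResult{Success: true},
+		// so its presence alone is what marks parser success here)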
+ if _, ok := t.state[tstamp]["buckets"]["OK"]; ok { + fmt.Printf("\t%s-------- parser success %s\n", sep, emoji.GreenCircle) + } else if whitelistReason != "" { + fmt.Printf("\t%s-------- parser success, ignored by whitelist (%s) %s\n", sep, whitelistReason, emoji.GreenCircle) + } else { + fmt.Printf("\t%s-------- parser failure %s\n", sep, emoji.RedCircle) + } + + // now print bucket info + if len(t.state[tstamp]["buckets"]) > 0 { + fmt.Printf("\t├ Scenarios\n") + } + + bnames := make([]string, 0, len(t.state[tstamp]["buckets"])) + + for k := range t.state[tstamp]["buckets"] { + // there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase + // we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered + if k == "OK" { + continue + } + + bnames = append(bnames, k) + } + + sort.Strings(bnames) + + for idx, bname := range bnames { + sep := "├" + if idx == len(bnames)-1 { + sep = "└" + } + + fmt.Printf("\t\t%s %s %s\n", sep, emoji.GreenCircle, bname) + } + + fmt.Println() + } +} diff --git a/pkg/emoji/emoji.go b/pkg/emoji/emoji.go new file mode 100644 index 00000000000..51295a85411 --- /dev/null +++ b/pkg/emoji/emoji.go @@ -0,0 +1,14 @@ +package emoji + +const ( + CheckMarkButton = "\u2705" // ✅ + CheckMark = "\u2714\ufe0f" // ✔ī¸ + CrossMark = "\u274c" // ❌ + GreenCircle = "\U0001f7e2" // đŸŸĸ + House = "\U0001f3e0" // 🏠 + Package = "\U0001f4e6" // đŸ“Ļ + Prohibited = "\U0001f6ab" // đŸšĢ + QuestionMark = "\u2753" // ❓ + RedCircle = "\U0001f534" // 🔴 + Warning = "\u26a0\ufe0f" // ⚠ī¸ +) diff --git a/pkg/exprhelpers/crowdsec_cti.go b/pkg/exprhelpers/crowdsec_cti.go index 6440295c856..ccd67b27a49 100644 --- a/pkg/exprhelpers/crowdsec_cti.go +++ b/pkg/exprhelpers/crowdsec_cti.go @@ -1,14 +1,15 @@ package exprhelpers import ( + "errors" "fmt" "time" "github.com/bluele/gcache" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/cticlient" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" ) var CTIUrl = "https://cti.api.crowdsec.net" @@ -16,21 +17,20 @@ var CTIUrlSuffix = "/v2/smoke/" var CTIApiKey = "" // this is set for non-recoverable errors, such as 403 when querying API or empty API key -var CTIApiEnabled = true +var CTIApiEnabled = false // when hitting quotas or auth errors, we temporarily disable the API var CTIBackOffUntil time.Time -var CTIBackOffDuration time.Duration = 5 * time.Minute +var CTIBackOffDuration = 5 * time.Minute var ctiClient *cticlient.CrowdsecCTIClient func InitCrowdsecCTI(Key *string, TTL *time.Duration, Size *int, LogLevel *log.Level) error { - if Key != nil { - CTIApiKey = *Key - } else { - CTIApiEnabled = false - return fmt.Errorf("CTI API key not set, CTI will not be available") + if Key == nil || *Key == "" { + log.Warningf("CTI API key not set or empty, CTI will not be available") + return cticlient.ErrDisabled } + CTIApiKey = *Key if Size == nil { Size = new(int) *Size = 1000 @@ -39,20 +39,17 @@ func InitCrowdsecCTI(Key *string, TTL *time.Duration, Size *int, LogLevel *log.L TTL = new(time.Duration) *TTL = 5 * time.Minute } - //dedicated logger clog := log.New() if err := types.ConfigureLogger(clog); err != nil { - return errors.Wrap(err, "while configuring datasource logger") + return fmt.Errorf("while configuring datasource logger: %w", err) } if LogLevel != nil { clog.SetLevel(*LogLevel) } - customLog := log.Fields{ - "type": "crowdsec-cti", - } - subLogger := clog.WithFields(customLog) + subLogger := 
clog.WithField("type", "crowdsec-cti") CrowdsecCTIInitCache(*Size, *TTL) ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey(CTIApiKey), cticlient.WithLogger(subLogger)) + CTIApiEnabled = true return nil } @@ -61,7 +58,7 @@ func ShutdownCrowdsecCTI() { CTICache.Purge() } CTIApiKey = "" - CTIApiEnabled = true + CTIApiEnabled = false } // Cache for responses @@ -75,31 +72,23 @@ func CrowdsecCTIInitCache(size int, ttl time.Duration) { // func CrowdsecCTI(ip string) (*cticlient.SmokeItem, error) { func CrowdsecCTI(params ...any) (any, error) { - ip := params[0].(string) + var ip string if !CTIApiEnabled { - ctiClient.Logger.Warningf("Crowdsec CTI API is disabled, please check your configuration") return &cticlient.SmokeItem{}, cticlient.ErrDisabled } - - if CTIApiKey == "" { - ctiClient.Logger.Warningf("CrowdsecCTI : no key provided, skipping") - return &cticlient.SmokeItem{}, cticlient.ErrDisabled - } - - if ctiClient == nil { - ctiClient.Logger.Warningf("CrowdsecCTI: no client, skipping") - return &cticlient.SmokeItem{}, cticlient.ErrDisabled + var ok bool + if ip, ok = params[0].(string); !ok { + return &cticlient.SmokeItem{}, fmt.Errorf("invalid type for ip : %T", params[0]) } if val, err := CTICache.Get(ip); err == nil && val != nil { ctiClient.Logger.Debugf("cti cache fetch for %s", ip) ret, ok := val.(*cticlient.SmokeItem) - if !ok { - ctiClient.Logger.Warningf("CrowdsecCTI: invalid type in cache, removing") - CTICache.Remove(ip) - } else { + if ok { return ret, nil } + ctiClient.Logger.Warningf("CrowdsecCTI: invalid type in cache, removing") + CTICache.Remove(ip) } if !CTIBackOffUntil.IsZero() && time.Now().Before(CTIBackOffUntil) { @@ -112,17 +101,18 @@ func CrowdsecCTI(params ...any) (any, error) { ctiResp, err := ctiClient.GetIPInfo(ip) ctiClient.Logger.Debugf("request for %s took %v", ip, time.Since(before)) if err != nil { - if err == cticlient.ErrUnauthorized { + switch { + case errors.Is(err, cticlient.ErrUnauthorized): CTIApiEnabled = false ctiClient.Logger.Errorf("Invalid API key provided, disabling CTI API") return &cticlient.SmokeItem{}, cticlient.ErrUnauthorized - } else if err == cticlient.ErrLimit { + case errors.Is(err, cticlient.ErrLimit): CTIBackOffUntil = time.Now().Add(CTIBackOffDuration) ctiClient.Logger.Errorf("CTI API is throttled, will try again in %s", CTIBackOffDuration) return &cticlient.SmokeItem{}, cticlient.ErrLimit - } else { + default: ctiClient.Logger.Warnf("CTI API error : %s", err) - return &cticlient.SmokeItem{}, fmt.Errorf("unexpected error : %v", err) + return &cticlient.SmokeItem{}, fmt.Errorf("unexpected error: %w", err) } } diff --git a/pkg/exprhelpers/crowdsec_cti_test.go b/pkg/exprhelpers/crowdsec_cti_test.go index 51ab5f8a3c4..9f78b932d6d 100644 --- a/pkg/exprhelpers/crowdsec_cti_test.go +++ b/pkg/exprhelpers/crowdsec_cti_test.go @@ -3,6 +3,7 @@ package exprhelpers import ( "bytes" "encoding/json" + "errors" "io" "net/http" "strings" @@ -10,8 +11,9 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/crowdsec/pkg/cticlient" ) @@ -67,7 +69,7 @@ func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { } func smokeHandler(req *http.Request) *http.Response { - apiKey := req.Header.Get("x-api-key") + apiKey := req.Header.Get("X-Api-Key") if apiKey != validApiKey { return &http.Response{ StatusCode: http.StatusForbidden, @@ -77,6 +79,7 @@ func smokeHandler(req 
*http.Request) *http.Response { } requestedIP := strings.Split(req.URL.Path, "/")[3] + sample, ok := sampledata[requestedIP] if !ok { return &http.Response{ @@ -106,8 +109,21 @@ func smokeHandler(req *http.Request) *http.Response { } } +func TestNilClient(t *testing.T) { + defer ShutdownCrowdsecCTI() + + if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); !errors.Is(err, cticlient.ErrDisabled) { + t.Fatalf("failed to init CTI : %s", err) + } + + item, err := CrowdsecCTI("1.2.3.4") + assert.Equal(t, err, cticlient.ErrDisabled) + assert.Equal(t, &cticlient.SmokeItem{}, item) +} + func TestInvalidAuth(t *testing.T) { defer ShutdownCrowdsecCTI() + if err := InitCrowdsecCTI(ptr.Of("asdasd"), nil, nil, nil); err != nil { t.Fatalf("failed to init CTI : %s", err) } @@ -117,8 +133,8 @@ func TestInvalidAuth(t *testing.T) { })) item, err := CrowdsecCTI("1.2.3.4") - assert.Equal(t, item, &cticlient.SmokeItem{}) - assert.Equal(t, CTIApiEnabled, false) + assert.Equal(t, &cticlient.SmokeItem{}, item) + assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrUnauthorized) //CTI is now disabled, all requests should return empty @@ -127,28 +143,30 @@ func TestInvalidAuth(t *testing.T) { })) item, err = CrowdsecCTI("1.2.3.4") - assert.Equal(t, item, &cticlient.SmokeItem{}) - assert.Equal(t, CTIApiEnabled, false) + assert.Equal(t, &cticlient.SmokeItem{}, item) + assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrDisabled) } func TestNoKey(t *testing.T) { defer ShutdownCrowdsecCTI() + err := InitCrowdsecCTI(nil, nil, nil, nil) - assert.ErrorContains(t, err, "CTI API key not set") + require.ErrorIs(t, err, cticlient.ErrDisabled) //Replace the client created by InitCrowdsecCTI with one that uses a custom transport ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey("asdasd"), cticlient.WithHTTPClient(&http.Client{ Transport: RoundTripFunc(smokeHandler), })) item, err := CrowdsecCTI("1.2.3.4") - assert.Equal(t, item, &cticlient.SmokeItem{}) - assert.Equal(t, CTIApiEnabled, false) + assert.Equal(t, &cticlient.SmokeItem{}, item) + assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrDisabled) } func TestCache(t *testing.T) { defer ShutdownCrowdsecCTI() + cacheDuration := 1 * time.Second if err := InitCrowdsecCTI(ptr.Of(validApiKey), &cacheDuration, nil, nil); err != nil { t.Fatalf("failed to init CTI : %s", err) @@ -161,28 +179,27 @@ func TestCache(t *testing.T) { item, err := CrowdsecCTI("1.2.3.4") ctiResp := item.(*cticlient.SmokeItem) assert.Equal(t, "1.2.3.4", ctiResp.Ip) - assert.Equal(t, CTIApiEnabled, true) - assert.Equal(t, CTICache.Len(true), 1) - assert.Equal(t, err, nil) + assert.True(t, CTIApiEnabled) + assert.Equal(t, 1, CTICache.Len(true)) + require.NoError(t, err) item, err = CrowdsecCTI("1.2.3.4") ctiResp = item.(*cticlient.SmokeItem) assert.Equal(t, "1.2.3.4", ctiResp.Ip) - assert.Equal(t, CTIApiEnabled, true) - assert.Equal(t, CTICache.Len(true), 1) - assert.Equal(t, err, nil) + assert.True(t, CTIApiEnabled) + assert.Equal(t, 1, CTICache.Len(true)) + require.NoError(t, err) time.Sleep(2 * time.Second) - assert.Equal(t, CTICache.Len(true), 0) + assert.Equal(t, 0, CTICache.Len(true)) item, err = CrowdsecCTI("1.2.3.4") ctiResp = item.(*cticlient.SmokeItem) assert.Equal(t, "1.2.3.4", ctiResp.Ip) - assert.Equal(t, CTIApiEnabled, true) - assert.Equal(t, CTICache.Len(true), 1) - assert.Equal(t, err, nil) - + assert.True(t, CTIApiEnabled) + assert.Equal(t, 1, CTICache.Len(true)) + require.NoError(t, err) } diff --git a/pkg/exprhelpers/debugger.go 
b/pkg/exprhelpers/debugger.go new file mode 100644 index 00000000000..2e47af6d1de --- /dev/null +++ b/pkg/exprhelpers/debugger.go @@ -0,0 +1,416 @@ +package exprhelpers + +import ( + "fmt" + "strconv" + "strings" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/file" + "github.com/expr-lang/expr/vm" + log "github.com/sirupsen/logrus" +) + +type ExprRuntimeDebug struct { + Logger *log.Entry + Lines []string + Outputs []OpOutput +} + +var IndentStep = 4 + +// we use this struct to store the output of the expr runtime +type OpOutput struct { + Code string //relevant code part + + CodeDepth int //level of nesting + BlockStart bool + BlockEnd bool + + Func bool //true if it's a function call + FuncName string + Args []string + FuncResults []string + // + Comparison bool //true if it's a comparison + Negated bool + Left string + Right string + // + JumpIf bool //true if it's a conditional jump + IfTrue bool + IfFalse bool + // + Condition bool //true if it's a condition + ConditionIn bool + ConditionContains bool + //used for comparisons, conditional jumps and conditions + StrConditionResult string + ConditionResult *bool //should always be present for conditions + + // + Finalized bool //set once a node is finalized, i.e. we already fetched the result from the next OP +} + +func (o *OpOutput) String() string { + ret := fmt.Sprintf("%*c", o.CodeDepth, ' ') + if o.Code != "" { + ret += fmt.Sprintf("[%s]", o.Code) + } + ret += " " + + switch { + case o.BlockStart: + ret = fmt.Sprintf("%*cBLOCK_START [%s]", o.CodeDepth-IndentStep, ' ', o.Code) + return ret + case o.BlockEnd: + indent := o.CodeDepth - (IndentStep * 2) + if indent < 0 { + indent = 0 + } + ret = fmt.Sprintf("%*cBLOCK_END [%s]", indent, ' ', o.Code) + if o.StrConditionResult != "" { + ret += fmt.Sprintf(" -> %s", o.StrConditionResult) + } + return ret + //a block end can carry a value, for example if it's a count, any, all etc. + case o.Func: + return ret + fmt.Sprintf("%s(%s) = %s", o.FuncName, strings.Join(o.Args, ", "), strings.Join(o.FuncResults, ", ")) + case o.Comparison: + if o.Negated { + ret += "NOT " + } + ret += fmt.Sprintf("%s == %s -> %s", o.Left, o.Right, o.StrConditionResult) + return ret + case o.ConditionIn: + return ret + fmt.Sprintf("%s in %s -> %s", o.Args[0], o.Args[1], o.StrConditionResult) + case o.ConditionContains: + return ret + fmt.Sprintf("%s contains %s -> %s", o.Args[0], o.Args[1], o.StrConditionResult) + case o.JumpIf && o.IfTrue: + if o.ConditionResult != nil { + if *o.ConditionResult { + return ret + "OR -> false" + } + return ret + "OR -> true" + } + return ret + "OR(?)" + case o.JumpIf && o.IfFalse: + if o.ConditionResult != nil { + if *o.ConditionResult { + return ret + "AND -> true" + } + return ret + "AND -> false" + } + return ret + "AND(?)" + } + return ret +} + +func (erp ExprRuntimeDebug) extractCode(ip int, program *vm.Program) string { + locations := program.Locations() + src := string(program.Source()) + + currentInstruction := locations[ip] + + var closest *file.Location + + for i := ip + 1; i < len(locations); i++ { + if locations[i].From > currentInstruction.From { + if closest == nil || locations[i].From < closest.From { + closest = &locations[i] + } + } + } + + var end int + if closest == nil { + end = len(src) + } else { + end = closest.From + } + + return cleanTextForDebug(src[locations[ip].From:end]) +} + +func autoQuote(v any) string { + switch x := v.(type) { + case string: + //let's avoid printing long strings. it can happen, e.g.
when we are debugging expr with `File()` or similar helpers + if len(x) > 40 { + return fmt.Sprintf("%q", x[:40]+"...") + } else { + return fmt.Sprintf("%q", x) + } + default: + return fmt.Sprintf("%v", x) + } +} + +func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, parts []string, outputs []OpOutput) ([]OpOutput, error) { + + IdxOut := len(outputs) + prevIdxOut := 0 + currentDepth := 0 + + //when there is a function call or comparison, we need to wait for the next instruction to get the result and "finalize" the previous one + if IdxOut > 0 { + prevIdxOut = IdxOut - 1 + currentDepth = outputs[prevIdxOut].CodeDepth + if outputs[prevIdxOut].Func && !outputs[prevIdxOut].Finalized { + stack := vm.Stack + num_items := 1 + for i := len(stack) - 1; i >= 0 && num_items > 0; i-- { + outputs[prevIdxOut].FuncResults = append(outputs[prevIdxOut].FuncResults, autoQuote(stack[i])) + num_items-- + } + outputs[prevIdxOut].Finalized = true + } else if (outputs[prevIdxOut].Comparison || outputs[prevIdxOut].Condition) && !outputs[prevIdxOut].Finalized { + stack := vm.Stack + outputs[prevIdxOut].StrConditionResult = fmt.Sprintf("%+v", stack) + if val, ok := stack[0].(bool); ok { + outputs[prevIdxOut].ConditionResult = new(bool) + *outputs[prevIdxOut].ConditionResult = val + } + outputs[prevIdxOut].Finalized = true + } + } + + erp.Logger.Tracef("[STEP %d:%s] (stack:%+v) (parts:%+v) {depth:%d}", ip, parts[1], vm.Stack, parts, currentDepth) + out := OpOutput{} + out.CodeDepth = currentDepth + out.Code = erp.extractCode(ip, program) + + switch parts[1] { + case "OpBegin": + out.CodeDepth += IndentStep + out.BlockStart = true + outputs = append(outputs, out) + case "OpEnd": + out.CodeDepth -= IndentStep + out.BlockEnd = true + //OpEnd can carry value, if it's any/all/count etc. + if len(vm.Stack) > 0 { + out.StrConditionResult = fmt.Sprintf("%v", vm.Stack) + } + outputs = append(outputs, out) + case "OpNot": + //negate the previous condition + outputs[prevIdxOut].Negated = true + case "OpTrue": //generated when possible ? (1 == 1) + out.Condition = true + out.ConditionResult = new(bool) + *out.ConditionResult = true + out.StrConditionResult = "true" + outputs = append(outputs, out) + case "OpFalse": //generated when possible ? 
(1 != 1) + out.Condition = true + out.ConditionResult = new(bool) + *out.ConditionResult = false + out.StrConditionResult = "false" + outputs = append(outputs, out) + case "OpJumpIfTrue": //OR + stack := vm.Stack + out.JumpIf = true + out.IfTrue = true + out.StrConditionResult = fmt.Sprintf("%v", stack[0]) + + if val, ok := stack[0].(bool); ok { + out.ConditionResult = new(bool) + *out.ConditionResult = val + } + outputs = append(outputs, out) + case "OpJumpIfFalse": //AND + stack := vm.Stack + out.JumpIf = true + out.IfFalse = true + out.StrConditionResult = fmt.Sprintf("%v", stack[0]) + if val, ok := stack[0].(bool); ok { + out.ConditionResult = new(bool) + *out.ConditionResult = val + } + outputs = append(outputs, out) + case "OpCall1": //Op for function calls + out.Func = true + out.FuncName = parts[3] + stack := vm.Stack + num_items := 1 + for i := len(stack) - 1; i >= 0 && num_items > 0; i-- { + out.Args = append(out.Args, autoQuote(stack[i])) + num_items-- + } + outputs = append(outputs, out) + case "OpCall2": //Op for function calls + out.Func = true + out.FuncName = parts[3] + stack := vm.Stack + num_items := 2 + for i := len(stack) - 1; i >= 0 && num_items > 0; i-- { + out.Args = append(out.Args, autoQuote(stack[i])) + num_items-- + } + outputs = append(outputs, out) + case "OpCall3": //Op for function calls + out.Func = true + out.FuncName = parts[3] + stack := vm.Stack + num_items := 3 + for i := len(stack) - 1; i >= 0 && num_items > 0; i-- { + out.Args = append(out.Args, autoQuote(stack[i])) + num_items-- + } + outputs = append(outputs, out) + //double check OpCallFast and OpCallTyped + case "OpCallFast", "OpCallTyped": + //not instrumented yet + case "OpCallN": //Op for function calls with more than 3 args + out.Func = true + out.FuncName = parts[1] + stack := vm.Stack + + //for OpCallN, the opcode argument carries the number of args (guard the index so we can't go out of range) + if len(program.Arguments) > ip { + nb_args := program.Arguments[ip] + if nb_args > 0 { + //we need to skip the top item on stack + for i := len(stack) - 2; i >= 0 && nb_args > 0; i-- { + out.Args = append(out.Args, autoQuote(stack[i])) + nb_args-- + } + } + } else { //let's blindly take the items on stack + for _, val := range vm.Stack { + out.Args = append(out.Args, autoQuote(val)) + } + } + outputs = append(outputs, out) + case "OpEqualString", "OpEqual", "OpEqualInt": //comparisons + stack := vm.Stack + out.Comparison = true + out.Left = autoQuote(stack[0]) + out.Right = autoQuote(stack[1]) + outputs = append(outputs, out) + case "OpIn": //in operator + stack := vm.Stack + out.Condition = true + out.ConditionIn = true + //seems that we tend to receive stack[1] as a map. + //it is tempting to use reflect to extract keys, but we end up with an array that doesn't match the initial order + //(because of the random order of the map) + out.Args = append(out.Args, autoQuote(stack[0])) + out.Args = append(out.Args, autoQuote(stack[1])) + outputs = append(outputs, out) + case "OpContains": //like OpIn, but reversed + stack := vm.Stack + out.Condition = true + out.ConditionContains = true + //seems that we tend to receive stack[1] as a map.
+it is tempting to use reflect to extract keys, but we end up with an array that doesn't match the initial order +(because of the random order of the map) + out.Args = append(out.Args, autoQuote(stack[0])) + out.Args = append(out.Args, autoQuote(stack[1])) + outputs = append(outputs, out) + } + return outputs, nil +} + +func (erp ExprRuntimeDebug) ipSeek(ip int) []string { + for i := range len(erp.Lines) { + parts := strings.Fields(erp.Lines[i]) + if len(parts) == 0 { + continue + } + if parts[0] == strconv.Itoa(ip) { + return parts + } + } + return nil +} + +func Run(program *vm.Program, env interface{}, logger *log.Entry, debug bool) (any, error) { + if debug { + dbgInfo, ret, err := RunWithDebug(program, env, logger) + DisplayExprDebug(program, dbgInfo, logger, ret) + return ret, err + } + return expr.Run(program, env) +} + +func cleanTextForDebug(text string) string { + text = strings.Join(strings.Fields(text), " ") + text = strings.Trim(text, " \t\n") + return text +} + +func DisplayExprDebug(program *vm.Program, outputs []OpOutput, logger *log.Entry, ret any) { + logger.Debugf("dbg(result=%v): %s", ret, cleanTextForDebug(string(program.Source()))) + for _, output := range outputs { + logger.Debugf("%s", output.String()) + } +} + +// TBD: Based on the level of the logger (i.e. trace vs debug) we could decide to add more low level instructions (pop, push, etc.) +func RunWithDebug(program *vm.Program, env interface{}, logger *log.Entry) ([]OpOutput, any, error) { + outputs := []OpOutput{} + erp := ExprRuntimeDebug{ + Logger: logger, + } + vm := vm.Debug() + opcodes := program.Disassemble() + lines := strings.Split(opcodes, "\n") + erp.Lines = lines + + go func() { + //We must never return until the execution of the program is done + var err error + erp.Logger.Tracef("[START] ip 0") + ops := erp.ipSeek(0) + if ops == nil { + //ipDebug indexes into parts, so it must not be called with nil ops + log.Warningf("error while debugging expr: failed getting ops for ip 0") + } else if outputs, err = erp.ipDebug(0, vm, program, ops, outputs); err != nil { + log.Warningf("error while debugging expr: error while debugging at ip 0") + } + vm.Step() + for ip := range vm.Position() { + ops := erp.ipSeek(ip) + if ops == nil { + erp.Logger.Tracef("[DONE] ip %d", ip) + break + } + if outputs, err = erp.ipDebug(ip, vm, program, ops, outputs); err != nil { + log.Warningf("error while debugging expr: error while debugging at ip %d", ip) + } + vm.Step() + } + }() + + var return_error error + ret, err := vm.Run(program, env) + //if the expr runtime failed, we don't need to wait for the debug to finish + if err != nil { + return_error = err + } + //the overall result of the expression is the result of the last op
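+ //so we propagate the value returned by vm.Run into the last recorded output and mark it finalized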
+ if len(outputs) > 0 { + lastOutIdx := len(outputs) + if lastOutIdx > 0 { + lastOutIdx -= 1 + } + switch val := ret.(type) { + case bool: + log.Tracef("completing with bool %t", ret) + outputs[lastOutIdx].StrConditionResult = fmt.Sprintf("%v", ret) + outputs[lastOutIdx].ConditionResult = new(bool) + *outputs[lastOutIdx].ConditionResult = val + outputs[lastOutIdx].Finalized = true + default: + log.Tracef("completing with type %T -> %v", ret, ret) + outputs[lastOutIdx].StrConditionResult = fmt.Sprintf("%v", ret) + outputs[lastOutIdx].Finalized = true + } + } else { + log.Tracef("no output from expr runtime") + } + return outputs, ret, return_error +} diff --git a/pkg/exprhelpers/debugger_test.go b/pkg/exprhelpers/debugger_test.go new file mode 100644 index 00000000000..32144454084 --- /dev/null +++ b/pkg/exprhelpers/debugger_test.go @@ -0,0 +1,357 @@ +package exprhelpers + +import ( + "reflect" + "strings" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/expr-lang/expr" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type ExprDbgTest struct { + Name string + Expr string + ExpectedOutputs []OpOutput + ExpectedFailedCompile bool + ExpectedFailRuntime bool + Env map[string]interface{} + LogLevel log.Level +} + +// For the sake of testing functions with 2, 3 and N args +func UpperTwo(params ...any) (any, error) { + s := params[0].(string) + v := params[1].(string) + + return strings.ToUpper(s) + strings.ToUpper(v), nil +} + +func UpperThree(params ...any) (any, error) { + s := params[0].(string) + v := params[1].(string) + x := params[2].(string) + + return strings.ToUpper(s) + strings.ToUpper(v) + strings.ToUpper(x), nil +} + +func UpperN(params ...any) (any, error) { + s := params[0].(string) + v := params[1].(string) + x := params[2].(string) + y := params[3].(string) + + return strings.ToUpper(s) + strings.ToUpper(v) + strings.ToUpper(x) + strings.ToUpper(y), nil +} + +func boolPtr(b bool) *bool { + return &b +} + +type teststruct struct { + Foo string +} + +// You need to add the build tag expr_debug when running these tests +func TestBaseDbg(t *testing.T) { + defaultEnv := map[string]interface{}{ + "queue": &types.Queue{}, + "evt": &types.Event{}, + "sample_array": []string{"a", "b", "c", "ZZ"}, + "base_string": "hello world", + "base_int": 42, + "base_float": 42.42, + "nilvar": &teststruct{}, + "base_struct": struct { + Foo string + Bar int + Myarr []string + }{ + Foo: "bar", + Bar: 42, + Myarr: []string{"a", "b", "c"}, + }, + } + // tips for the tests: + // use '%#v' to dump in golang syntax + // use regexp to clear empty/default fields: + // [a-z]+: (false|\[\]string\(nil\)|""), + // ConditionResult:(*bool) + + // Missing: tests for multi-parameter functions + tests := []ExprDbgTest{ + { + Name: "nil deref", + Expr: "Upper('1') == '1' && nilvar.Foo == '42'", + Env: defaultEnv, + ExpectedFailRuntime: true, + ExpectedOutputs: []OpOutput{ + {Code: "Upper('1')", CodeDepth: 0, Func: true, FuncName: "Upper", Args: []string{"\"1\""}, FuncResults: []string{"\"1\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "== '1'", CodeDepth: 0, Comparison: true, Left: "\"1\"", Right: "\"1\"", StrConditionResult: "[true]", ConditionResult: boolPtr(true), Finalized: true}, + {Code: "&&", CodeDepth: 0, JumpIf: true, IfFalse: true, StrConditionResult: "", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "OpCall2", + Expr: "UpperTwo('hello', 'world') == 'HELLOWORLD'", + Env: defaultEnv, +
ExpectedOutputs: []OpOutput{ + {Code: "UpperTwo('hello', 'world')", CodeDepth: 0, Func: true, FuncName: "UpperTwo", Args: []string{"\"world\"", "\"hello\""}, FuncResults: []string{"\"HELLOWORLD\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "== 'HELLOWORLD'", CodeDepth: 0, Comparison: true, Left: "\"HELLOWORLD\"", Right: "\"HELLOWORLD\"", StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "OpCall3", + Expr: "UpperThree('hello', 'world', 'foo') == 'HELLOWORLDFOO'", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "UpperThree('hello', 'world', 'foo')", CodeDepth: 0, Func: true, FuncName: "UpperThree", Args: []string{"\"foo\"", "\"world\"", "\"hello\""}, FuncResults: []string{"\"HELLOWORLDFOO\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "== 'HELLOWORLDFOO'", CodeDepth: 0, Comparison: true, Left: "\"HELLOWORLDFOO\"", Right: "\"HELLOWORLDFOO\"", StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "OpCallN", + Expr: "UpperN('hello', 'world', 'foo', 'lol') == UpperN('hello', 'world', 'foo', 'lol')", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "UpperN('hello', 'world', 'foo', 'lol')", CodeDepth: 0, Func: true, FuncName: "OpCallN", Args: []string{"\"lol\"", "\"foo\"", "\"world\"", "\"hello\""}, FuncResults: []string{"\"HELLOWORLDFOOLOL\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "UpperN('hello', 'world', 'foo', 'lol')", CodeDepth: 0, Func: true, FuncName: "OpCallN", Args: []string{"\"lol\"", "\"foo\"", "\"world\"", "\"hello\""}, FuncResults: []string{"\"HELLOWORLDFOOLOL\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "== UpperN('hello', 'world', 'foo', 'lol')", CodeDepth: 0, Comparison: true, Left: "\"HELLOWORLDFOOLOL\"", Right: "\"HELLOWORLDFOOLOL\"", StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "base string cmp", + Expr: "base_string == 'hello world'", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "== 'hello world'", CodeDepth: 0, Comparison: true, Left: "\"hello world\"", Right: "\"hello world\"", StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "loop with func call", + Expr: "count(base_struct.Myarr, {Upper(#) == 'C'}) == 1", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "count(base_struct.Myarr, {", CodeDepth: 4, BlockStart: true, ConditionResult: (*bool)(nil), Finalized: false}, + {Code: "Upper(#)", CodeDepth: 4, Func: true, FuncName: "Upper", Args: []string{"\"a\""}, FuncResults: []string{"\"A\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "== 'C'})", CodeDepth: 4, Comparison: true, Left: "\"A\"", Right: "\"C\"", StrConditionResult: "[false]", ConditionResult: boolPtr(false), Finalized: true}, + {Code: "count(base_struct.Myarr, {Upper(#) == 'C'})", CodeDepth: 4, JumpIf: true, IfFalse: true, StrConditionResult: "false", ConditionResult: boolPtr(false), Finalized: false}, + {Code: "Upper(#)", CodeDepth: 4, Func: true, FuncName: "Upper", Args: []string{"\"b\""}, FuncResults: []string{"\"B\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "== 'C'})", CodeDepth: 4, Comparison: true, Left: "\"B\"", Right: "\"C\"", StrConditionResult: "[false]", ConditionResult: boolPtr(false), Finalized: true}, + {Code: "count(base_struct.Myarr, {Upper(#) == 'C'})", CodeDepth: 4, JumpIf: true, IfFalse: true, StrConditionResult: "false", ConditionResult: boolPtr(false), 
Finalized: false}, + {Code: "Upper(#)", CodeDepth: 4, Func: true, FuncName: "Upper", Args: []string{"\"c\""}, FuncResults: []string{"\"C\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "== 'C'})", CodeDepth: 4, Comparison: true, Left: "\"C\"", Right: "\"C\"", StrConditionResult: "[true]", ConditionResult: boolPtr(true), Finalized: true}, + {Code: "count(base_struct.Myarr, {Upper(#) == 'C'})", CodeDepth: 4, JumpIf: true, IfFalse: true, StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: false}, + {Code: "count(base_struct.Myarr, {Upper(#) == 'C'})", CodeDepth: 0, BlockEnd: true, StrConditionResult: "[1]", ConditionResult: (*bool)(nil), Finalized: false}, + {Code: "== 1", CodeDepth: 0, Comparison: true, Left: "1", Right: "1", StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "loop with func call and extra check", + Expr: "count(base_struct.Myarr, {Upper(#) == 'C'}) == 1 && Upper(base_struct.Foo) == 'BAR'", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "count(base_struct.Myarr, {", CodeDepth: 4, BlockStart: true, ConditionResult: (*bool)(nil), Finalized: false}, + {Code: "Upper(#)", CodeDepth: 4, Func: true, FuncName: "Upper", Args: []string{"\"a\""}, FuncResults: []string{"\"A\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "== 'C'})", CodeDepth: 4, Comparison: true, Left: "\"A\"", Right: "\"C\"", StrConditionResult: "[false]", ConditionResult: boolPtr(false), Finalized: true}, + {Code: "count(base_struct.Myarr, {Upper(#) == 'C'})", CodeDepth: 4, JumpIf: true, IfFalse: true, StrConditionResult: "false", ConditionResult: boolPtr(false), Finalized: false}, + {Code: "Upper(#)", CodeDepth: 4, Func: true, FuncName: "Upper", Args: []string{"\"b\""}, FuncResults: []string{"\"B\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "== 'C'})", CodeDepth: 4, Comparison: true, Left: "\"B\"", Right: "\"C\"", StrConditionResult: "[false]", ConditionResult: boolPtr(false), Finalized: true}, + {Code: "count(base_struct.Myarr, {Upper(#) == 'C'})", CodeDepth: 4, JumpIf: true, IfFalse: true, StrConditionResult: "false", ConditionResult: boolPtr(false), Finalized: false}, + {Code: "Upper(#)", CodeDepth: 4, Func: true, FuncName: "Upper", Args: []string{"\"c\""}, FuncResults: []string{"\"C\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "== 'C'})", CodeDepth: 4, Comparison: true, Left: "\"C\"", Right: "\"C\"", StrConditionResult: "[true]", ConditionResult: boolPtr(true), Finalized: true}, + {Code: "count(base_struct.Myarr, {Upper(#) == 'C'})", CodeDepth: 4, JumpIf: true, IfFalse: true, StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: false}, + {Code: "count(base_struct.Myarr, {Upper(#) == 'C'})", CodeDepth: 0, BlockEnd: true, StrConditionResult: "[1]", ConditionResult: (*bool)(nil), Finalized: false}, + {Code: "== 1", CodeDepth: 0, Comparison: true, Left: "1", Right: "1", StrConditionResult: "[true]", ConditionResult: boolPtr(true), Finalized: true}, + {Code: "&&", CodeDepth: 0, JumpIf: true, IfFalse: true, StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: false}, + {Code: "Upper(base_struct.Foo)", CodeDepth: 0, Func: true, FuncName: "Upper", Args: []string{"\"bar\""}, FuncResults: []string{"\"BAR\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "== 'BAR'", CodeDepth: 0, Comparison: true, Left: "\"BAR\"", Right: "\"BAR\"", StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: 
"base 'in' test", + Expr: "base_int in [1,2,3,4,42]", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "in [1,2,3,4,42]", CodeDepth: 0, Args: []string{"42", "map[1:{} 2:{} 3:{} 4:{} 42:{}]"}, Condition: true, ConditionIn: true, StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "base string cmp", + Expr: "base_string == 'hello world'", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "== 'hello world'", CodeDepth: 0, Comparison: true, Left: "\"hello world\"", Right: "\"hello world\"", StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "base int cmp", + Expr: "base_int == 42", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "== 42", CodeDepth: 0, Comparison: true, Left: "42", Right: "42", StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "negative check", + Expr: "base_int != 43", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "!= 43", CodeDepth: 0, Negated: true, Comparison: true, Left: "42", Right: "43", StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "testing ORs", + Expr: "base_int == 43 || base_int == 42", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "== 43", CodeDepth: 0, Comparison: true, Left: "42", Right: "43", StrConditionResult: "[false]", ConditionResult: boolPtr(false), Finalized: true}, + {Code: "||", CodeDepth: 0, JumpIf: true, IfTrue: true, StrConditionResult: "false", ConditionResult: boolPtr(false), Finalized: false}, + {Code: "== 42", CodeDepth: 0, Comparison: true, Left: "42", Right: "42", StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "testing basic true", + Expr: "true", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "true", CodeDepth: 0, Condition: true, StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "testing basic false", + Expr: "false", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "false", CodeDepth: 0, Condition: true, StrConditionResult: "false", ConditionResult: boolPtr(false), Finalized: true}, + }, + }, + { + Name: "testing multi lines", + Expr: `base_int == 42 && + base_string == 'hello world' && + (base_struct.Bar == 41 || base_struct.Bar == 42)`, + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "== 42", CodeDepth: 0, Comparison: true, Left: "42", Right: "42", StrConditionResult: "[true]", ConditionResult: boolPtr(true), Finalized: true}, + {Code: "&&", CodeDepth: 0, JumpIf: true, IfFalse: true, StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: false}, + {Code: "== 'hello world'", CodeDepth: 0, Comparison: true, Left: "\"hello world\"", Right: "\"hello world\"", StrConditionResult: "[true]", ConditionResult: boolPtr(true), Finalized: true}, + {Code: "&& (", CodeDepth: 0, JumpIf: true, IfFalse: true, StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: false}, + {Code: "== 41", CodeDepth: 0, Comparison: true, Left: "42", Right: "41", StrConditionResult: "[false]", ConditionResult: boolPtr(false), Finalized: true}, + {Code: "||", CodeDepth: 0, JumpIf: true, IfTrue: true, StrConditionResult: "false", ConditionResult: boolPtr(false), Finalized: false}, + {Code: "== 42)", CodeDepth: 0, Comparison: true, Left: "42", Right: "42", StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "upper 
+ in", + Expr: "Upper(base_string) contains Upper('wOrlD')", + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "Upper(base_string)", CodeDepth: 0, Func: true, FuncName: "Upper", Args: []string{"\"hello world\""}, FuncResults: []string{"\"HELLO WORLD\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "Upper('wOrlD')", CodeDepth: 0, Func: true, FuncName: "Upper", Args: []string{"\"wOrlD\""}, FuncResults: []string{"\"WORLD\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "contains Upper('wOrlD')", CodeDepth: 0, Args: []string{"\"HELLO WORLD\"", "\"WORLD\""}, Condition: true, ConditionContains: true, StrConditionResult: "true", ConditionResult: boolPtr(true), Finalized: true}, + }, + }, + { + Name: "upper + complex", + Expr: `( Upper(base_string) contains Upper('/someurl?x=1') || + Upper(base_string) contains Upper('/someotherurl?account-name=admin&account-status=1&ow=cmd') ) + and base_string startsWith ('40') and Upper(base_string) == 'POST'`, + Env: defaultEnv, + ExpectedOutputs: []OpOutput{ + {Code: "Upper(base_string)", CodeDepth: 0, Func: true, FuncName: "Upper", Args: []string{"\"hello world\""}, FuncResults: []string{"\"HELLO WORLD\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "Upper('/someurl?x=1')", CodeDepth: 0, Func: true, FuncName: "Upper", Args: []string{"\"/someurl?x=1\""}, FuncResults: []string{"\"/SOMEURL?X=1\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "contains Upper('/someurl?x=1')", CodeDepth: 0, Args: []string{"\"HELLO WORLD\"", "\"/SOMEURL?X=1\""}, Condition: true, ConditionContains: true, StrConditionResult: "[false]", ConditionResult: boolPtr(false), Finalized: true}, + {Code: "||", CodeDepth: 0, JumpIf: true, IfTrue: true, StrConditionResult: "false", ConditionResult: boolPtr(false), Finalized: false}, + {Code: "Upper(base_string)", CodeDepth: 0, Func: true, FuncName: "Upper", Args: []string{"\"hello world\""}, FuncResults: []string{"\"HELLO WORLD\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "Upper('/someotherurl?account-name=admin&account-status=1&ow=cmd') )", CodeDepth: 0, Func: true, FuncName: "Upper", Args: []string{"\"/someotherurl?account-name=admin&account...\""}, FuncResults: []string{"\"/SOMEOTHERURL?ACCOUNT-NAME=ADMIN&ACCOUNT...\""}, ConditionResult: (*bool)(nil), Finalized: true}, + {Code: "contains Upper('/someotherurl?account-name=admin&account-status=1&ow=cmd') )", CodeDepth: 0, Args: []string{"\"HELLO WORLD\"", "\"/SOMEOTHERURL?ACCOUNT-NAME=ADMIN&ACCOUNT...\""}, Condition: true, ConditionContains: true, StrConditionResult: "[false]", ConditionResult: boolPtr(false), Finalized: true}, + {Code: "and", CodeDepth: 0, JumpIf: true, IfFalse: true, StrConditionResult: "false", ConditionResult: boolPtr(false), Finalized: true}, + }, + }, + } + + logger := log.WithField("test", "exprhelpers") + + for _, test := range tests { + if test.LogLevel != 0 { + log.SetLevel(test.LogLevel) + } else { + log.SetLevel(log.DebugLevel) + } + + extraFuncs := []expr.Option{} + extraFuncs = append(extraFuncs, + expr.Function("UpperTwo", + UpperTwo, + []interface{}{new(func(string, string) string)}..., + )) + extraFuncs = append(extraFuncs, + expr.Function("UpperThree", + UpperThree, + []interface{}{new(func(string, string, string) string)}..., + )) + extraFuncs = append(extraFuncs, + expr.Function("UpperN", + UpperN, + []interface{}{new(func(string, string, string, string) string)}..., + )) + supaEnv := GetExprOptions(test.Env) + supaEnv = append(supaEnv, extraFuncs...) 
+ + prog, err := expr.Compile(test.Expr, supaEnv...) + if test.ExpectedFailedCompile { + if err == nil { + t.Fatalf("test %s : expected compile error", test.Name) + } + } else { + if err != nil { + t.Fatalf("test %s : unexpected compile error : %s", test.Name, err) + } + } + + if test.Name == "nil deref" { + test.Env["nilvar"] = nil + } + + outdbg, ret, err := RunWithDebug(prog, test.Env, logger) + + if test.ExpectedFailRuntime { + if err == nil { + t.Fatalf("test %s : expected runtime error", test.Name) + } + } else { + if err != nil { + t.Fatalf("test %s : unexpected runtime error : %s", test.Name, err) + } + } + + log.SetLevel(log.DebugLevel) + DisplayExprDebug(prog, outdbg, logger, ret) + + if len(outdbg) != len(test.ExpectedOutputs) { + t.Errorf("failed test %s", test.Name) + t.Errorf("%#v", outdbg) + // out, _ := yaml.Marshal(outdbg) + // fmt.Printf("%s", string(out)) + t.Fatalf("test %s : expected %d outputs, got %d", test.Name, len(test.ExpectedOutputs), len(outdbg)) + } + + for i, out := range outdbg { + if reflect.DeepEqual(out, test.ExpectedOutputs[i]) { + // DisplayExprDebug(prog, outdbg, logger, ret) + continue + } + + spew.Config.DisableMethods = true + + t.Errorf("failed test %s", test.Name) + t.Errorf("expected : %#v", test.ExpectedOutputs[i]) + t.Errorf("got : %#v", out) + t.Fatalf("%d/%d : mismatch", i, len(outdbg)) + } + } +} diff --git a/pkg/exprhelpers/expr_lib.go b/pkg/exprhelpers/expr_lib.go index f4e1f4722fd..b90c1986153 100644 --- a/pkg/exprhelpers/expr_lib.go +++ b/pkg/exprhelpers/expr_lib.go @@ -1,8 +1,11 @@ package exprhelpers import ( + "net" "time" + "github.com/oschwald/geoip2-golang" + "github.com/crowdsecurity/crowdsec/pkg/cticlient" ) @@ -20,6 +23,21 @@ var exprFuncs = []exprCustomFunc{ new(func(string) (*cticlient.SmokeItem, error)), }, }, + { + name: "Flatten", + function: Flatten, + signature: []interface{}{}, + }, + { + name: "Distinct", + function: Distinct, + signature: []interface{}{}, + }, + { + name: "FlattenDistinct", + function: FlattenDistinct, + signature: []interface{}{}, + }, { name: "Distance", function: Distance, @@ -216,6 +234,20 @@ var exprFuncs = []exprCustomFunc{ new(func(string) int), }, }, + { + name: "GetActiveDecisionsCount", + function: GetActiveDecisionsCount, + signature: []interface{}{ + new(func(string) int), + }, + }, + { + name: "GetActiveDecisionsTimeLeft", + function: GetActiveDecisionsTimeLeft, + signature: []interface{}{ + new(func(string) time.Duration), + }, + }, { name: "GetDecisionsSinceCount", function: GetDecisionsSinceCount, @@ -419,6 +451,48 @@ var exprFuncs = []exprCustomFunc{ new(func() (string, error)), }, }, + { + name: "FloatApproxEqual", + function: FloatApproxEqual, + signature: []interface{}{ + new(func(float64, float64) bool), + }, + }, + { + name: "LibInjectionIsSQLI", + function: LibInjectionIsSQLI, + signature: []interface{}{ + new(func(string) bool), + }, + }, + { + name: "LibInjectionIsXSS", + function: LibInjectionIsXSS, + signature: []interface{}{ + new(func(string) bool), + }, + }, + { + name: "GeoIPEnrich", + function: GeoIPEnrich, + signature: []interface{}{ + new(func(string) *geoip2.City), + }, + }, + { + name: "GeoIPASNEnrich", + function: GeoIPASNEnrich, + signature: []interface{}{ + new(func(string) *geoip2.ASN), + }, + }, + { + name: "GeoIPRangeEnrich", + function: GeoIPRangeEnrich, + signature: []interface{}{ + new(func(string) *net.IPNet), + }, + }, } //go 1.20 "CutPrefix": strings.CutPrefix, diff --git a/pkg/exprhelpers/exprlib_test.go b/pkg/exprhelpers/exprlib_test.go index 
53f7d7d15cc..f2eb208ebfa 100644 --- a/pkg/exprhelpers/exprlib_test.go +++ b/pkg/exprhelpers/exprlib_test.go @@ -2,19 +2,18 @@ package exprhelpers import ( "context" - "fmt" + "errors" "os" "testing" "time" - "github.com/antonmedv/expr" - "github.com/pkg/errors" + "github.com/expr-lang/expr" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" @@ -22,23 +21,24 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var ( - TestFolder = "tests" -) +const TestFolder = "tests" func getDBClient(t *testing.T) *database.Client { t.Helper() + dbPath, err := os.CreateTemp("", "*sqlite") require.NoError(t, err) - testDbClient, err := database.NewClient(&csconfig.DatabaseCfg{ + ctx := context.Background() + + testDBClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{ Type: "sqlite", DbName: "crowdsec", DbPath: dbPath.Name(), }) require.NoError(t, err) - return testDbClient + return testDBClient } func TestVisitor(t *testing.T) { @@ -77,59 +77,50 @@ func TestVisitor(t *testing.T) { name: "debug : can't compile", filter: "static_one.foo.toto == 'lol'", result: false, - err: fmt.Errorf("bad syntax"), + err: errors.New("bad syntax"), env: map[string]interface{}{"static_one": map[string]string{"foo": "bar"}}, }, { name: "debug : can't compile #2", filter: "static_one.f!oo.to/to == 'lol'", result: false, - err: fmt.Errorf("bad syntax"), + err: errors.New("bad syntax"), env: map[string]interface{}{"static_one": map[string]string{"foo": "bar"}}, }, { name: "debug : can't compile #3", filter: "", result: false, - err: fmt.Errorf("bad syntax"), + err: errors.New("bad syntax"), env: map[string]interface{}{"static_one": map[string]string{"foo": "bar"}}, }, } log.SetLevel(log.DebugLevel) - clog := log.WithFields(log.Fields{ - "type": "test", - }) for _, test := range tests { compiledFilter, err := expr.Compile(test.filter, GetExprOptions(test.env)...) if err != nil && test.err == nil { - log.Fatalf("compile: %s", err) - } - debugFilter, err := NewDebugger(test.filter, GetExprOptions(test.env)...) - if err != nil && test.err == nil { - log.Fatalf("debug: %s", err) + t.Fatalf("compile: %s", err) } if compiledFilter != nil { result, err := expr.Run(compiledFilter, test.env) if err != nil && test.err == nil { - log.Fatalf("run : %s", err) + t.Fatalf("run: %s", err) } + if isOk := assert.Equal(t, test.result, result); !isOk { t.Fatalf("test '%s' : NOK", test.filter) } } - - if debugFilter != nil { - debugFilter.Run(clog, test.result, test.env) - } } } func TestMatch(t *testing.T) { err := Init(nil) require.NoError(t, err) + tests := []struct { glob string val string @@ -159,12 +150,15 @@ func TestMatch(t *testing.T) { "pattern": test.glob, "name": test.val, } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) if err != nil { t.Fatalf("pattern:%s val:%s NOK %s", test.glob, test.val, err) } + ret, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) + if isOk := assert.Equal(t, test.ret, ret); !isOk { t.Fatalf("pattern:%s val:%s NOK %t != %t", test.glob, test.val, ret, test.ret) } @@ -198,16 +192,18 @@ func TestDistanceHelper(t *testing.T) { "lat2": test.lat2, "lon2": test.lon2, } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) 
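+ //Distance() returns a float64, so the result is checked with InDelta rather than strict equality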
if err != nil { t.Fatalf("pattern:%s val:%s NOK %s", test.lat1, test.lon1, err) } + ret, err := expr.Run(vm, env) if test.valid { - assert.NoError(t, err) - assert.Equal(t, test.dist, ret) + require.NoError(t, err) + assert.InDelta(t, test.dist, ret, 0.000001) } else { - assert.NotNil(t, err) + require.Error(t, err) } }) } @@ -221,7 +217,7 @@ func TestRegexpCacheBehavior(t *testing.T) { err = FileInit(TestFolder, filename, "regex") require.NoError(t, err) - //cache with no TTL + // cache with no TTL err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(1)}) require.NoError(t, err) @@ -233,7 +229,7 @@ func TestRegexpCacheBehavior(t *testing.T) { assert.True(t, ret.(bool)) assert.Equal(t, 1, dataFileRegexCache[filename].Len(false)) - //cache with TTL + // cache with TTL ttl := 500 * time.Millisecond err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(2), TTL: &ttl}) require.NoError(t, err) @@ -248,12 +244,12 @@ func TestRegexpCacheBehavior(t *testing.T) { func TestRegexpInFile(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data_re.txt", "regex") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -291,21 +287,23 @@ func TestRegexpInFile(t *testing.T) { for _, test := range tests { compiledFilter, err := expr.Compile(test.filter, GetExprOptions(map[string]interface{}{})...) if err != nil { - log.Fatal(err) + t.Fatal(err) } + result, err := expr.Run(compiledFilter, map[string]interface{}{}) if err != nil { - log.Fatal(err) + t.Fatal(err) } + if isOk := assert.Equal(t, test.result, result); !isOk { - t.Fatalf("test '%s' : NOK", test.name) + t.Fatalf("test '%s': NOK", test.name) } } } func TestFileInit(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -343,42 +341,48 @@ func TestFileInit(t *testing.T) { for _, test := range tests { err := FileInit(TestFolder, test.filename, test.types) if err != nil { - log.Fatal(err) + t.Fatal(err) } - if test.types == "string" { + + switch test.types { + case "string": if _, ok := dataFile[test.filename]; !ok { t.Fatalf("test '%s' : NOK", test.name) } - if isOk := assert.Equal(t, test.result, len(dataFile[test.filename])); !isOk { + + if isOk := assert.Len(t, dataFile[test.filename], test.result); !isOk { t.Fatalf("test '%s' : NOK", test.name) } - } else if test.types == "regex" { + case "regex": if _, ok := dataFileRegex[test.filename]; !ok { t.Fatalf("test '%s' : NOK", test.name) } - if isOk := assert.Equal(t, test.result, len(dataFileRegex[test.filename])); !isOk { + + if isOk := assert.Len(t, dataFileRegex[test.filename], test.result); !isOk { t.Fatalf("test '%s' : NOK", test.name) } - } else { + default: if _, ok := dataFileRegex[test.filename]; ok { t.Fatalf("test '%s' : NOK", test.name) } + if _, ok := dataFile[test.filename]; ok { t.Fatalf("test '%s' : NOK", test.name) } } + log.Printf("test '%s' : OK", test.name) } } func TestFile(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data.txt", "string") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -416,23 +420,25 @@ func TestFile(t *testing.T) { for _, test := range tests { compiledFilter, err := expr.Compile(test.filter, GetExprOptions(map[string]interface{}{})...) 
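+ //the filters only reference RegexpInFile, whose patterns were loaded by FileInit above, so the env stays empty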
if err != nil { - log.Fatal(err) + t.Fatal(err) } + result, err := expr.Run(compiledFilter, map[string]interface{}{}) if err != nil { - log.Fatal(err) + t.Fatal(err) } + if isOk := assert.Equal(t, test.result, result); !isOk { t.Fatalf("test '%s' : NOK", test.name) } - log.Printf("test '%s' : OK", test.name) + log.Printf("test '%s' : OK", test.name) } } func TestIpInRange(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string env map[string]interface{} @@ -480,12 +486,11 @@ func TestIpInRange(t *testing.T) { require.Equal(t, test.result, output) log.Printf("test '%s' : OK", test.name) } - } func TestIpToRange(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string env map[string]interface{} @@ -553,13 +558,11 @@ func TestIpToRange(t *testing.T) { require.Equal(t, test.result, output) log.Printf("test '%s' : OK", test.name) } - } func TestAtof(t *testing.T) { - err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -590,7 +593,7 @@ func TestAtof(t *testing.T) { require.NoError(t, err) output, err := expr.Run(program, test.env) require.NoError(t, err) - require.Equal(t, test.result, output) + require.InDelta(t, test.result, output, 0.000001) } } @@ -603,13 +606,14 @@ func TestUpper(t *testing.T) { } err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) vm, err := expr.Compile("Upper(testStr)", GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) + v, ok := out.(string) if !ok { t.Fatalf("Upper() should return a string") @@ -622,6 +626,7 @@ func TestUpper(t *testing.T) { func TestTimeNow(t *testing.T) { now, _ := TimeNow() + ti, err := time.Parse(time.RFC3339, now.(string)) if err != nil { t.Fatalf("Error parsing the return value of TimeNow: %s", err) @@ -630,6 +635,7 @@ func TestTimeNow(t *testing.T) { if -1*time.Until(ti) > time.Second { t.Fatalf("TimeNow func should return time.Now().UTC()") } + log.Printf("test 'TimeNow()' : OK") } @@ -904,15 +910,14 @@ func TestLower(t *testing.T) { } func TestGetDecisionsCount(t *testing.T) { - var err error - var start_ip, start_sfx, end_ip, end_sfx int64 - var ip_sz int existingIP := "1.2.3.4" unknownIP := "1.2.3.5" - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP) + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP) if err != nil { t.Errorf("unable to convert '%s' to int: %s", existingIP, err) } + // Add sample data to DB dbClient = getDBClient(t) @@ -931,11 +936,11 @@ func TestGetDecisionsCount(t *testing.T) { SaveX(context.Background()) if decision == nil { - assert.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.New("Failed to create sample decision")) } err = Init(dbClient) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -991,13 +996,12 @@ func TestGetDecisionsCount(t *testing.T) { log.Printf("test '%s' : OK", test.name) } } + func TestGetDecisionsSinceCount(t *testing.T) { - var err error - var start_ip, start_sfx, end_ip, end_sfx int64 - var ip_sz int existingIP := "1.2.3.4" unknownIP := "1.2.3.5" - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP) + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP) if err != nil { t.Errorf("unable to convert '%s' to int: %s", 
existingIP, err) } @@ -1018,8 +1022,9 @@ func TestGetDecisionsSinceCount(t *testing.T) { SetOrigin("CAPI"). SaveX(context.Background()) if decision == nil { - assert.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.New("Failed to create sample decision")) } + decision2 := dbClient.Ent.Decision.Create(). SetCreatedAt(time.Now().AddDate(0, 0, -1)). SetUntil(time.Now().AddDate(0, 0, -1)). @@ -1034,12 +1039,13 @@ func TestGetDecisionsSinceCount(t *testing.T) { SetValue(existingIP). SetOrigin("CAPI"). SaveX(context.Background()) + if decision2 == nil { - assert.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.New("Failed to create sample decision")) } err = Init(dbClient) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -1114,6 +1120,268 @@ func TestGetDecisionsSinceCount(t *testing.T) { } } +func TestGetActiveDecisionsCount(t *testing.T) { + existingIP := "1.2.3.4" + unknownIP := "1.2.3.5" + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP) + if err != nil { + t.Errorf("unable to convert '%s' to int: %s", existingIP, err) + } + + // Add sample data to DB + dbClient = getDBClient(t) + + decision := dbClient.Ent.Decision.Create(). + SetUntil(time.Now().UTC().Add(time.Hour)). + SetScenario("crowdsec/test"). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(ip_sz)). + SetType("ban"). + SetScope("IP"). + SetValue(existingIP). + SetOrigin("CAPI"). + SaveX(context.Background()) + + if decision == nil { + require.Error(t, errors.New("Failed to create sample decision")) + } + + expiredDecision := dbClient.Ent.Decision.Create(). + SetUntil(time.Now().UTC().Add(-time.Hour)). + SetScenario("crowdsec/test"). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(ip_sz)). + SetType("ban"). + SetScope("IP"). + SetValue(existingIP). + SetOrigin("CAPI"). + SaveX(context.Background()) + + if expiredDecision == nil { + require.Error(t, errors.New("Failed to create sample decision")) + } + + err = Init(dbClient) + require.NoError(t, err) + + tests := []struct { + name string + env map[string]interface{} + code string + result string + err string + }{ + { + name: "GetActiveDecisionsCount() test: existing IP count", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &existingIP, + }, + Decisions: []*models.Decision{ + { + Value: &existingIP, + }, + }, + }, + }, + code: "Sprintf('%d', GetActiveDecisionsCount(Alert.GetValue()))", + result: "1", + err: "", + }, + { + name: "GetActiveDecisionsCount() test: unknown IP count", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &unknownIP, + }, + Decisions: []*models.Decision{ + { + Value: &unknownIP, + }, + }, + }, + }, + code: "Sprintf('%d', GetActiveDecisionsCount(Alert.GetValue()))", + result: "0", + err: "", + }, + } + + for _, test := range tests { + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) 
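+ //GetActiveDecisionsCount queries the decisions seeded above; the expired one must not be counted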
+ require.NoError(t, err) + output, err := expr.Run(program, test.env) + require.NoError(t, err) + require.Equal(t, test.result, output) + log.Printf("test '%s' : OK", test.name) + } +} + +func TestGetActiveDecisionsTimeLeft(t *testing.T) { + existingIP := "1.2.3.4" + unknownIP := "1.2.3.5" + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP) + if err != nil { + t.Errorf("unable to convert '%s' to int: %s", existingIP, err) + } + + // Add sample data to DB + dbClient = getDBClient(t) + + decision := dbClient.Ent.Decision.Create(). + SetUntil(time.Now().UTC().Add(time.Hour)). + SetScenario("crowdsec/test"). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(ip_sz)). + SetType("ban"). + SetScope("IP"). + SetValue(existingIP). + SetOrigin("CAPI"). + SaveX(context.Background()) + + if decision == nil { + require.Error(t, errors.New("Failed to create sample decision")) + } + + longerDecision := dbClient.Ent.Decision.Create(). + SetUntil(time.Now().UTC().Add(2 * time.Hour)). + SetScenario("crowdsec/test"). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(ip_sz)). + SetType("ban"). + SetScope("IP"). + SetValue(existingIP). + SetOrigin("CAPI"). + SaveX(context.Background()) + + if longerDecision == nil { + require.Error(t, errors.New("Failed to create sample decision")) + } + + err = Init(dbClient) + require.NoError(t, err) + + tests := []struct { + name string + env map[string]interface{} + code string + min float64 + max float64 + err string + }{ + { + name: "GetActiveDecisionsTimeLeft() test: existing IP time left", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &existingIP, + }, + Decisions: []*models.Decision{ + { + Value: &existingIP, + }, + }, + }, + }, + code: "GetActiveDecisionsTimeLeft(Alert.GetValue())", + min: 7195, // 5 seconds margin to make sure the test doesn't fail randomly in the CI + max: 7200, + err: "", + }, + { + name: "GetActiveDecisionsTimeLeft() test: unknown IP time left", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &unknownIP, + }, + Decisions: []*models.Decision{ + { + Value: &unknownIP, + }, + }, + }, + }, + code: "GetActiveDecisionsTimeLeft(Alert.GetValue())", + min: 0, + max: 0, + err: "", + }, + { + name: "GetActiveDecisionsTimeLeft() test: existing IP and call time.Duration method", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &existingIP, + }, + Decisions: []*models.Decision{ + { + Value: &existingIP, + }, + }, + }, + }, + code: "GetActiveDecisionsTimeLeft(Alert.GetValue()).Hours()", + min: 2, + max: 2, + }, + { + name: "GetActiveDecisionsTimeLeft() test: unknown IP and call time.Duration method", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &unknownIP, + }, + Decisions: []*models.Decision{ + { + Value: &unknownIP, + }, + }, + }, + }, + code: "GetActiveDecisionsTimeLeft(Alert.GetValue()).Hours()", + min: 0, + max: 0, + }, + } + + delta := 0.001 + + for _, test := range tests { + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) 
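+ //with two active decisions (1h and 2h), the helper returns the longest remaining duration, as a time.Duration or as a float64 once .Hours() is called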
+ require.NoError(t, err) + output, err := expr.Run(program, test.env) + require.NoError(t, err) + + switch o := output.(type) { + case time.Duration: + require.LessOrEqual(t, int(o.Seconds()), int(test.max)) + require.GreaterOrEqual(t, int(o.Seconds()), int(test.min)) + case float64: + require.LessOrEqual(t, o, test.max+delta) + require.GreaterOrEqual(t, o, test.min-delta) + default: + t.Fatalf("GetActiveDecisionsTimeLeft() should return a time.Duration or a float64") + } + } +} + func TestParseUnixTime(t *testing.T) { tests := []struct { name string @@ -1124,12 +1392,12 @@ func TestParseUnixTime(t *testing.T) { { name: "ParseUnix() test: valid value with milli", value: "1672239773.3590894", - expected: time.Date(2022, 12, 28, 15, 02, 53, 0, time.UTC), + expected: time.Date(2022, 12, 28, 15, 2, 53, 0, time.UTC), }, { name: "ParseUnix() test: valid value without milli", value: "1672239773", - expected: time.Date(2022, 12, 28, 15, 02, 53, 0, time.UTC), + expected: time.Date(2022, 12, 28, 15, 2, 53, 0, time.UTC), }, { name: "ParseUnix() test: invalid input", @@ -1146,13 +1414,14 @@ func TestParseUnixTime(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { output, err := ParseUnixTime(tc.value) cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { return } + require.WithinDuration(t, tc.expected, output.(time.Time), time.Second) }) } @@ -1160,8 +1429,9 @@ func TestParseUnixTime(t *testing.T) { func TestIsIp(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } + tests := []struct { name string expr string @@ -1245,17 +1515,18 @@ func TestIsIp(t *testing.T) { expectedBuildErr: true, }, } + for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) if tc.expectedBuildErr { - assert.Error(t, err) + require.Error(t, err) return } - assert.NoError(t, err) + + require.NoError(t, err) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) - assert.NoError(t, err) + require.NoError(t, err) assert.IsType(t, tc.expected, output) assert.Equal(t, tc.expected, output.(bool)) }) @@ -1265,6 +1536,7 @@ func TestIsIp(t *testing.T) { func TestToString(t *testing.T) { err := Init(nil) require.NoError(t, err) + tests := []struct { name string value interface{} @@ -1297,12 +1569,11 @@ func TestToString(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) - assert.NoError(t, err) + require.NoError(t, err) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) - assert.NoError(t, err) + require.NoError(t, err) require.Equal(t, tc.expected, output) }) } @@ -1344,20 +1615,22 @@ func TestB64Decode(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) 
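+ //compile errors and runtime errors (e.g. invalid base64 input) are asserted separately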
if tc.expectedBuildErr { - assert.Error(t, err) + require.Error(t, err) return } - assert.NoError(t, err) + + require.NoError(t, err) + output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) if tc.expectedRuntimeErr { - assert.Error(t, err) + require.Error(t, err) return } - assert.NoError(t, err) + + require.NoError(t, err) require.Equal(t, tc.expected, output) }) } @@ -1414,7 +1687,6 @@ func TestParseKv(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { outMap := make(map[string]interface{}) env := map[string]interface{}{ @@ -1422,9 +1694,9 @@ func TestParseKv(t *testing.T) { "out": outMap, } vm, err := expr.Compile(tc.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) _, err = expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, tc.expected, outMap["a"]) }) } diff --git a/pkg/exprhelpers/geoip.go b/pkg/exprhelpers/geoip.go new file mode 100644 index 00000000000..fb0c344d884 --- /dev/null +++ b/pkg/exprhelpers/geoip.go @@ -0,0 +1,63 @@ +package exprhelpers + +import ( + "net" +) + +func GeoIPEnrich(params ...any) (any, error) { + if geoIPCityReader == nil { + return nil, nil + } + + ip := params[0].(string) + + parsedIP := net.ParseIP(ip) + + city, err := geoIPCityReader.City(parsedIP) + + if err != nil { + return nil, err + } + + return city, nil +} + +func GeoIPASNEnrich(params ...any) (any, error) { + if geoIPASNReader == nil { + return nil, nil + } + + ip := params[0].(string) + + parsedIP := net.ParseIP(ip) + asn, err := geoIPASNReader.ASN(parsedIP) + + if err != nil { + return nil, err + } + + return asn, nil +} + +func GeoIPRangeEnrich(params ...any) (any, error) { + if geoIPRangeReader == nil { + return nil, nil + } + + ip := params[0].(string) + + var dummy interface{} + + parsedIP := net.ParseIP(ip) + rangeIP, ok, err := geoIPRangeReader.LookupNetwork(parsedIP, &dummy) + + if err != nil { + return nil, err + } + + if !ok { + return nil, nil + } + + return rangeIP, nil +} diff --git a/pkg/exprhelpers/helpers.go b/pkg/exprhelpers/helpers.go index a5f45c4b076..9bc991a8f2d 100644 --- a/pkg/exprhelpers/helpers.go +++ b/pkg/exprhelpers/helpers.go @@ -2,28 +2,34 @@ package exprhelpers import ( "bufio" + "context" "encoding/base64" + "errors" "fmt" + "math" "net" "net/url" "os" - "path" + "path/filepath" + "reflect" "regexp" "strconv" "strings" "time" - "github.com/antonmedv/expr" "github.com/bluele/gcache" "github.com/c-robinson/iplib" "github.com/cespare/xxhash/v2" "github.com/davecgh/go-spew/spew" + "github.com/expr-lang/expr" + "github.com/oschwald/geoip2-golang" + "github.com/oschwald/maxminddb-golang" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "github.com/umahmood/haversine" "github.com/wasilibs/go-re2" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/crowdsec/pkg/cache" "github.com/crowdsecurity/crowdsec/pkg/database" @@ -31,9 +37,11 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var dataFile map[string][]string -var dataFileRegex map[string][]*regexp.Regexp -var dataFileRe2 map[string][]*re2.Regexp +var ( + dataFile map[string][]string + dataFileRegex map[string][]*regexp.Regexp + dataFileRe2 map[string][]*re2.Regexp +) // This is used to (optionally) cache regexp results for RegexpInFile operations var dataFileRegexCache map[string]gcache.Cache = make(map[string]gcache.Cache) @@ -53,42 +61,88 @@ var exprFunctionOptions []expr.Option var 
keyValuePattern = regexp.MustCompile(`(?P<key>[^=\s]+)=(?:"(?P<quoted_value>[^"\\]*(?:\\.[^"\\]*)*)"|(?P<value>[^=\s]+)|\s*)`) +var ( + geoIPCityReader *geoip2.Reader + geoIPASNReader *geoip2.Reader + geoIPRangeReader *maxminddb.Reader +) + func GetExprOptions(ctx map[string]interface{}) []expr.Option { + if len(exprFunctionOptions) == 0 { + exprFunctionOptions = []expr.Option{} + for _, function := range exprFuncs { + exprFunctionOptions = append(exprFunctionOptions, + expr.Function(function.name, + function.function, + function.signature..., + )) + } + } + ret := []expr.Option{} ret = append(ret, exprFunctionOptions...) ret = append(ret, expr.Env(ctx)) + return ret } +func GeoIPInit(datadir string) error { + var err error + + geoIPCityReader, err = geoip2.Open(filepath.Join(datadir, "GeoLite2-City.mmdb")) + if err != nil { + log.Errorf("unable to open GeoLite2-City.mmdb : %s", err) + return err + } + + geoIPASNReader, err = geoip2.Open(filepath.Join(datadir, "GeoLite2-ASN.mmdb")) + if err != nil { + log.Errorf("unable to open GeoLite2-ASN.mmdb : %s", err) + return err + } + + geoIPRangeReader, err = maxminddb.Open(filepath.Join(datadir, "GeoLite2-ASN.mmdb")) + if err != nil { + log.Errorf("unable to open GeoLite2-ASN.mmdb : %s", err) + return err + } + + return nil +} + +func GeoIPClose() { + if geoIPCityReader != nil { + geoIPCityReader.Close() + } + + if geoIPASNReader != nil { + geoIPASNReader.Close() + } + + if geoIPRangeReader != nil { + geoIPRangeReader.Close() + } +} + func Init(databaseClient *database.Client) error { dataFile = make(map[string][]string) dataFileRegex = make(map[string][]*regexp.Regexp) dataFileRe2 = make(map[string][]*re2.Regexp) dbClient = databaseClient - - exprFunctionOptions = []expr.Option{} - for _, function := range exprFuncs { - exprFunctionOptions = append(exprFunctionOptions, - expr.Function(function.name, - function.function, - function.signature..., - )) - } - + XMLCacheInit() return nil } func RegexpCacheInit(filename string, CacheCfg types.DataSource) error { - - //cache is explicitly disabled + // cache is explicitly disabled if CacheCfg.Cache != nil && !*CacheCfg.Cache { return nil } - //cache is implicitly disabled if no cache config is provided + // cache is implicitly disabled if no cache config is provided if CacheCfg.Strategy == nil && CacheCfg.TTL == nil && CacheCfg.Size == nil { return nil } - //cache is enabled + // cache is enabled if CacheCfg.Size == nil { CacheCfg.Size = ptr.Of(50) @@ -99,6 +153,7 @@ func RegexpCacheInit(filename string, CacheCfg types.DataSource) error { if CacheCfg.Strategy == nil { CacheCfg.Strategy = ptr.Of("LRU") } + switch *CacheCfg.Strategy { case "LRU": gc = gc.LRU() @@ -113,14 +168,17 @@ func RegexpCacheInit(filename string, CacheCfg types.DataSource) error { if CacheCfg.TTL != nil { gc.Expiration(*CacheCfg.TTL) } + cache := gc.Build() dataFileRegexCache[filename] = cache + return nil } // UpdateCacheMetrics is called directly by the prom handler func UpdateRegexpCacheMetrics() { RegexpCacheMetrics.Reset() + for name := range dataFileRegexCache { RegexpCacheMetrics.With(prometheus.Labels{"name": name}).Set(float64(dataFileRegexCache[name].Len(true))) } @@ -128,46 +186,117 @@ func UpdateRegexpCacheMetrics() { func FileInit(fileFolder string, filename string, fileType string) error { log.Debugf("init (folder:%s) (file:%s) (type:%s)", fileFolder, filename, fileType) - filepath := path.Join(fileFolder, filename) - file, err := os.Open(filepath) - if err != nil { - return err - } - defer file.Close() if fileType == "" { log.Debugf("ignored 
file %s%s because no type specified", fileFolder, filename) return nil } - if _, ok := dataFile[filename]; !ok { - dataFile[filename] = []string{} + + ok, err := existsInFileMaps(filename, fileType) + if ok { + log.Debugf("ignored file %s%s because already loaded", fileFolder, filename) + return nil + } + if err != nil { + return err + } + + filepath := filepath.Join(fileFolder, filename) + + file, err := os.Open(filepath) + if err != nil { + return err } + defer file.Close() + scanner := bufio.NewScanner(file) for scanner.Scan() { if strings.HasPrefix(scanner.Text(), "#") { // allow comments continue } - if len(scanner.Text()) == 0 { //skip empty lines + if scanner.Text() == "" { //skip empty lines continue } + switch fileType { case "regex", "regexp": if fflag.Re2RegexpInfileSupport.IsEnabled() { dataFileRe2[filename] = append(dataFileRe2[filename], re2.MustCompile(scanner.Text())) - } else { - dataFileRegex[filename] = append(dataFileRegex[filename], regexp.MustCompile(scanner.Text())) + continue } + + dataFileRegex[filename] = append(dataFileRegex[filename], regexp.MustCompile(scanner.Text())) case "string": dataFile[filename] = append(dataFile[filename], scanner.Text()) - default: - return fmt.Errorf("unknown data type '%s' for : '%s'", fileType, filename) } } - if err := scanner.Err(); err != nil { - return err + return scanner.Err() +} + +// Expr helpers + +func Distinct(params ...any) (any, error) { + if rt := reflect.TypeOf(params[0]).Kind(); rt != reflect.Slice && rt != reflect.Array { + return nil, nil } - return nil + array := params[0].([]interface{}) + if array == nil { + return []interface{}{}, nil + } + + exists := make(map[any]bool) + ret := make([]interface{}, 0) + + for _, val := range array { + if _, ok := exists[val]; !ok { + exists[val] = true + ret = append(ret, val) + } + } + return ret, nil +} + +func FlattenDistinct(params ...any) (any, error) { + return Distinct(flatten(nil, reflect.ValueOf(params))) //nolint:asasalint +} + +func Flatten(params ...any) (any, error) { + return flatten(nil, reflect.ValueOf(params)), nil +} + +func flatten(args []interface{}, v reflect.Value) []interface{} { + if v.Kind() == reflect.Interface { + v = v.Elem() + } + + if v.Kind() == reflect.Array || v.Kind() == reflect.Slice { + for i := range v.Len() { + args = flatten(args, v.Index(i)) + } + } else { + args = append(args, v.Interface()) + } + + return args +} + +func existsInFileMaps(filename string, ftype string) (bool, error) { + ok := false + var err error + switch ftype { + case "regex", "regexp": + if fflag.Re2RegexpInfileSupport.IsEnabled() { + _, ok = dataFileRe2[filename] + } else { + _, ok = dataFileRegex[filename] + } + case "string": + _, ok = dataFile[filename] + default: + err = fmt.Errorf("unknown data type '%s' for : '%s'", ftype, filename) + } + return ok, err } //Expr helpers @@ -464,7 +593,10 @@ func GetDecisionsCount(params ...any) (any, error) { return 0, nil } - count, err := dbClient.CountDecisionsByValue(value) + + ctx := context.TODO() + + count, err := dbClient.CountDecisionsByValue(ctx, value) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -477,7 +609,7 @@ func GetDecisionsSinceCount(params ...any) (any, error) { value := params[0].(string) since := params[1].(string) if dbClient == nil { - log.Error("No database config to call GetDecisionsCount()") + log.Error("No 
database config to call GetDecisionsSinceCount()") return 0, nil } sinceDuration, err := time.ParseDuration(since) @@ -485,8 +617,11 @@ func GetDecisionsSinceCount(params ...any) (any, error) { log.Errorf("Failed to parse since parameter '%s' : %s", since, err) return 0, nil } + + ctx := context.TODO() sinceTime := time.Now().UTC().Add(-sinceDuration) - count, err := dbClient.CountDecisionsSinceByValue(value, sinceTime) + + count, err := dbClient.CountDecisionsSinceByValue(ctx, value, sinceTime) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -494,6 +629,36 @@ func GetDecisionsSinceCount(params ...any) (any, error) { return count, nil } +func GetActiveDecisionsCount(params ...any) (any, error) { + value := params[0].(string) + if dbClient == nil { + log.Error("No database config to call GetActiveDecisionsCount()") + return 0, nil + } + ctx := context.TODO() + count, err := dbClient.CountActiveDecisionsByValue(ctx, value) + if err != nil { + log.Errorf("Failed to get active decisions count from value '%s'", value) + return 0, err + } + return count, nil +} + +func GetActiveDecisionsTimeLeft(params ...any) (any, error) { + value := params[0].(string) + if dbClient == nil { + log.Error("No database config to call GetActiveDecisionsTimeLeft()") + return 0, nil + } + ctx := context.TODO() + timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(ctx, value) + if err != nil { + log.Errorf("Failed to get active decisions time left from value '%s'", value) + return 0, err + } + return timeLeft, nil +} + // func LookupHost(value string) []string { func LookupHost(params ...any) (any, error) { value := params[0].(string) @@ -589,6 +754,16 @@ func Match(params ...any) (any, error) { return matched, nil } +func FloatApproxEqual(params ...any) (any, error) { + float1 := params[0].(float64) + float2 := params[1].(float64) + + if math.Abs(float1-float2) < 1e-6 { + return true, nil + } + return false, nil +} + func B64Decode(params ...any) (any, error) { encoded := params[0].(string) decoded, err := base64.StdEncoding.DecodeString(encoded) @@ -599,7 +774,6 @@ func B64Decode(params ...any) (any, error) { } func ParseKV(params ...any) (any, error) { - blob := params[0].(string) target := params[1].(map[string]interface{}) prefix := params[2].(string) @@ -607,7 +781,7 @@ func ParseKV(params ...any) (any, error) { matches := keyValuePattern.FindAllStringSubmatch(blob, -1) if matches == nil { log.Errorf("could not find any key/value pair in line") - return nil, fmt.Errorf("invalid input format") + return nil, errors.New("invalid input format") } if _, ok := target[prefix]; !ok { target[prefix] = make(map[string]string) @@ -615,7 +789,7 @@ func ParseKV(params ...any) (any, error) { _, ok := target[prefix].(map[string]string) if !ok { log.Errorf("ParseKV: target is not a map[string]string") - return nil, fmt.Errorf("target is not a map[string]string") + return nil, errors.New("target is not a map[string]string") } } for _, match := range matches { diff --git a/pkg/exprhelpers/jsonextract.go b/pkg/exprhelpers/jsonextract.go index a616588a76b..64ed97873d6 100644 --- a/pkg/exprhelpers/jsonextract.go +++ b/pkg/exprhelpers/jsonextract.go @@ -7,7 +7,6 @@ import ( "strings" "github.com/buger/jsonparser" - log "github.com/sirupsen/logrus" ) @@ -15,11 +14,11 @@ import ( func JsonExtractLib(params ...any) (any, error) { jsblob := 
params[0].(string) target := params[1].([]string) + value, dataType, _, err := jsonparser.Get( jsonparser.StringToBytes(jsblob), target..., ) - if err != nil { if errors.Is(err, jsonparser.KeyPathNotFoundError) { log.Debugf("%+v doesn't exist", target) @@ -93,7 +92,6 @@ func jsonExtractType(jsblob string, target string, t jsonparser.ValueType) ([]by jsonparser.StringToBytes(jsblob), fullpath..., ) - if err != nil { if errors.Is(err, jsonparser.KeyPathNotFoundError) { log.Debugf("Key %+v doesn't exist", target) @@ -115,8 +113,8 @@ func jsonExtractType(jsblob string, target string, t jsonparser.ValueType) ([]by func JsonExtractSlice(params ...any) (any, error) { jsblob := params[0].(string) target := params[1].(string) - value, err := jsonExtractType(jsblob, target, jsonparser.Array) + value, err := jsonExtractType(jsblob, target, jsonparser.Array) if err != nil { log.Errorf("JsonExtractSlice : %s", err) return []interface{}(nil), nil @@ -136,8 +134,8 @@ func JsonExtractSlice(params ...any) (any, error) { func JsonExtractObject(params ...any) (any, error) { jsblob := params[0].(string) target := params[1].(string) - value, err := jsonExtractType(jsblob, target, jsonparser.Object) + value, err := jsonExtractType(jsblob, target, jsonparser.Object) if err != nil { log.Errorf("JsonExtractObject: %s", err) return map[string]interface{}(nil), nil @@ -174,7 +172,7 @@ func UnmarshalJSON(params ...any) (any, error) { err := json.Unmarshal([]byte(jsonBlob), &out) if err != nil { - log.Errorf("UnmarshalJSON : %s", err) + log.WithField("line", jsonBlob).Errorf("UnmarshalJSON : %s", err) return nil, err } target[key] = out diff --git a/pkg/exprhelpers/jsonextract_test.go b/pkg/exprhelpers/jsonextract_test.go index 481c7d723ff..5845c3ae66b 100644 --- a/pkg/exprhelpers/jsonextract_test.go +++ b/pkg/exprhelpers/jsonextract_test.go @@ -3,20 +3,19 @@ package exprhelpers import ( "testing" - log "github.com/sirupsen/logrus" - - "github.com/antonmedv/expr" + "github.com/expr-lang/expr" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestJsonExtract(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data_re.txt", "regex") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -56,22 +55,22 @@ func TestJsonExtract(t *testing.T) { "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } - } + func TestJsonExtractUnescape(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data_re.txt", "regex") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -104,9 +103,9 @@ func TestJsonExtractUnescape(t *testing.T) { "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) 
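// The JSON helpers under test sit on top of buger/jsonparser, which walks the raw
// byte slice directly instead of unmarshalling into an intermediate map. A rough
// standalone equivalent of the JsonExtractLib lookup, assuming only the jsonparser
// API imported by pkg/exprhelpers/jsonextract.go:
//
//	value, dataType, _, err := jsonparser.Get([]byte(`{"a":{"b":"c"}}`), "a", "b")
//	if err == nil && dataType == jsonparser.String {
//		fmt.Println(string(value)) // c
//	}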
- assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -114,12 +113,12 @@ func TestJsonExtractUnescape(t *testing.T) { func TestJsonExtractSlice(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data_re.txt", "regex") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -160,16 +159,15 @@ func TestJsonExtractSlice(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.name, func(t *testing.T) { env := map[string]interface{}{ "blob": test.jsonBlob, "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -177,12 +175,12 @@ func TestJsonExtractSlice(t *testing.T) { func TestJsonExtractObject(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data_re.txt", "regex") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -216,16 +214,15 @@ func TestJsonExtractObject(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.name, func(t *testing.T) { env := map[string]interface{}{ "blob": test.jsonBlob, "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -233,7 +230,8 @@ func TestJsonExtractObject(t *testing.T) { func TestToJson(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) + tests := []struct { name string obj interface{} @@ -298,9 +296,9 @@ func TestToJson(t *testing.T) { "obj": test.obj, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -308,7 +306,8 @@ func TestToJson(t *testing.T) { func TestUnmarshalJSON(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) + tests := []struct { name string json string @@ -361,11 +360,10 @@ func TestUnmarshalJSON(t *testing.T) { "out": outMap, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) 
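// UnmarshalJSON decodes a blob into a plain interface{} and stores it under a key
// of the "out" map, giving later expressions a mutable scratch space. A sketch of
// the same pattern with just encoding/json (the (blob, target, key) parameter order
// is assumed from the target[key] assignment shown in the helper above):
//
//	out := map[string]interface{}{}
//	var decoded interface{}
//	if err := json.Unmarshal([]byte(`{"a": 1}`), &decoded); err == nil {
//		out["a"] = decoded // subsequent expressions read out["a"]
//	}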
- assert.NoError(t, err) + require.NoError(t, err) _, err = expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, outMap["a"]) }) } - } diff --git a/pkg/exprhelpers/libinjection.go b/pkg/exprhelpers/libinjection.go new file mode 100644 index 00000000000..e9f33e4f459 --- /dev/null +++ b/pkg/exprhelpers/libinjection.go @@ -0,0 +1,17 @@ +package exprhelpers + +import "github.com/corazawaf/libinjection-go" + +func LibInjectionIsSQLI(params ...any) (any, error) { + str := params[0].(string) + + ret, _ := libinjection.IsSQLi(str) + return ret, nil +} + +func LibInjectionIsXSS(params ...any) (any, error) { + str := params[0].(string) + + ret := libinjection.IsXSS(str) + return ret, nil +} diff --git a/pkg/exprhelpers/libinjection_test.go b/pkg/exprhelpers/libinjection_test.go new file mode 100644 index 00000000000..7b4ab825db9 --- /dev/null +++ b/pkg/exprhelpers/libinjection_test.go @@ -0,0 +1,60 @@ +package exprhelpers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLibinjectionHelpers(t *testing.T) { + tests := []struct { + name string + function func(params ...any) (any, error) + params []any + expectResult any + }{ + { + name: "LibInjectionIsSQLI", + function: LibInjectionIsSQLI, + params: []any{"?__f__73=73&&__f__75=75&delivery=1&max=24.9&min=15.9&n=12&o=2&p=(select(0)from(select(sleep(15)))v)/*'%2B(select(0)from(select(sleep(15)))v)%2B'\x22%2B(select(0)from(select(sleep(15)))v)%2B\x22*/&rating=4"}, + expectResult: true, + }, + { + name: "LibInjectionIsSQLI - no match", + function: LibInjectionIsSQLI, + params: []any{"?bla=42&foo=bar"}, + expectResult: false, + }, + { + name: "LibInjectionIsSQLI - no match 2", + function: LibInjectionIsSQLI, + params: []any{"https://foo.com/asdkfj?bla=42&foo=bar"}, + expectResult: false, + }, + { + name: "LibInjectionIsXSS", + function: LibInjectionIsXSS, + params: []any{"<script>alert('XSS')</script>"}, + expectResult: true, + }, + { + name: "LibInjectionIsXSS - no match", + function: LibInjectionIsXSS, + params: []any{"?bla=42&foo=bar"}, + expectResult: false, + }, + { + name: "LibInjectionIsXSS - no match 2", + function: LibInjectionIsXSS, + params: []any{"https://foo.com/asdkfj?bla=42&foo[]=bar&foo"}, + expectResult: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result, _ := test.function(test.params...) 
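// Both helpers are thin wrappers: libinjection.IsSQLi returns (bool, string), the
// string being the matched fingerprint (discarded in LibInjectionIsSQLI above),
// while libinjection.IsXSS returns a plain bool. Called directly, under the same
// corazawaf/libinjection-go import:
//
//	sqli, fingerprint := libinjection.IsSQLi(`' OR 1=1 --`)
//	xss := libinjection.IsXSS(`<script>alert(1)</script>`)
//	fmt.Println(sqli, fingerprint, xss) // true, a short signature string, true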
+ assert.Equal(t, test.expectResult, result) + }) + } +} diff --git a/pkg/exprhelpers/visitor.go b/pkg/exprhelpers/visitor.go deleted file mode 100644 index 0dc1840f5e2..00000000000 --- a/pkg/exprhelpers/visitor.go +++ /dev/null @@ -1,160 +0,0 @@ -package exprhelpers - -import ( - "fmt" - "strconv" - "strings" - - "github.com/antonmedv/expr/parser" - "github.com/google/uuid" - log "github.com/sirupsen/logrus" - - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/ast" - "github.com/antonmedv/expr/vm" -) - -/* -Visitor is used to reconstruct variables with its property called in an expr filter -Thus, we can debug expr filter by displaying all variables contents present in the filter -*/ -type visitor struct { - newVar bool - currentId string - vars map[string][]string - logger *log.Entry -} - -func (v *visitor) Visit(node *ast.Node) { - switch n := (*node).(type) { - case *ast.IdentifierNode: - v.newVar = true - uid, _ := uuid.NewUUID() - v.currentId = uid.String() - v.vars[v.currentId] = []string{n.Value} - case *ast.MemberNode: - if n2, ok := n.Property.(*ast.StringNode); ok { - v.vars[v.currentId] = append(v.vars[v.currentId], n2.Value) - } - case *ast.StringNode: //Don't reset here, as any attribute of a member node is a string node (in evt.X, evt is member node, X is string node) - default: - v.newVar = false - v.currentId = "" - /*case *ast.IntegerNode: - v.logger.Infof("integer node found: %+v", n) - case *ast.FloatNode: - v.logger.Infof("float node found: %+v", n) - case *ast.BoolNode: - v.logger.Infof("boolean node found: %+v", n) - case *ast.ArrayNode: - v.logger.Infof("array node found: %+v", n) - case *ast.ConstantNode: - v.logger.Infof("constant node found: %+v", n) - case *ast.UnaryNode: - v.logger.Infof("unary node found: %+v", n) - case *ast.BinaryNode: - v.logger.Infof("binary node found: %+v", n) - case *ast.CallNode: - v.logger.Infof("call node found: %+v", n) - case *ast.BuiltinNode: - v.logger.Infof("builtin node found: %+v", n) - case *ast.ConditionalNode: - v.logger.Infof("conditional node found: %+v", n) - case *ast.ChainNode: - v.logger.Infof("chain node found: %+v", n) - case *ast.PairNode: - v.logger.Infof("pair node found: %+v", n) - case *ast.MapNode: - v.logger.Infof("map node found: %+v", n) - case *ast.SliceNode: - v.logger.Infof("slice node found: %+v", n) - case *ast.ClosureNode: - v.logger.Infof("closure node found: %+v", n) - case *ast.PointerNode: - v.logger.Infof("pointer node found: %+v", n) - default: - v.logger.Infof("unknown node found: %+v | type: %T", n, n)*/ - } -} - -/* -Build reconstruct all the variables used in a filter (to display their content later). -*/ -func (v *visitor) Build(filter string, exprEnv ...expr.Option) (*ExprDebugger, error) { - var expressions []*expression - ret := &ExprDebugger{ - filter: filter, - } - if filter == "" { - v.logger.Debugf("unable to create expr debugger with empty filter") - return &ExprDebugger{}, nil - } - v.newVar = false - v.vars = make(map[string][]string) - tree, err := parser.Parse(filter) - if err != nil { - return nil, err - } - ast.Walk(&tree.Node, v) - log.Debugf("vars: %+v", v.vars) - - for _, variable := range v.vars { - if variable[0] != "evt" { - continue - } - toBuild := strings.Join(variable, ".") - v.logger.Debugf("compiling expression '%s'", toBuild) - debugFilter, err := expr.Compile(toBuild, exprEnv...) 
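// The debugger deleted here worked in two phases: parser.Parse turns the filter
// into an AST, ast.Walk feeds every node to Visit so "evt.X"-style chains can be
// reconstructed, and each chain is compiled (just above) so its value can be
// printed when the filter fires. A minimal sketch of the parse/walk half, against
// the antonmedv/expr packages this file imported:
//
//	tree, err := parser.Parse(`evt.X && evt.Y`)
//	if err == nil {
//		v := &visitor{logger: log.WithField("component", "expr-debugger")}
//		ast.Walk(&tree.Node, v)
//	}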
- if err != nil { - return ret, fmt.Errorf("compilation of variable '%s' failed: %v", toBuild, err) - } - tmpExpression := &expression{ - toBuild, - debugFilter, - } - expressions = append(expressions, tmpExpression) - - } - ret.expression = expressions - return ret, nil -} - -// ExprDebugger contains the list of expression to be run when debugging an expression filter -type ExprDebugger struct { - filter string - expression []*expression -} - -// expression is the structure that represents the variable in string and compiled format -type expression struct { - Str string - Compiled *vm.Program -} - -/* -Run display the content of each variable of a filter by evaluating them with expr, -again the expr environment given in parameter -*/ -func (e *ExprDebugger) Run(logger *log.Entry, filterResult bool, exprEnv map[string]interface{}) { - if len(e.expression) == 0 { - logger.Tracef("no variable to eval for filter '%s'", e.filter) - return - } - logger.Debugf("eval(%s) = %s", e.filter, strings.ToUpper(strconv.FormatBool(filterResult))) - logger.Debugf("eval variables:") - for _, expression := range e.expression { - debug, err := expr.Run(expression.Compiled, exprEnv) - if err != nil { - logger.Errorf("unable to print debug expression for '%s': %s", expression.Str, err) - } - logger.Debugf(" %s = '%v'", expression.Str, debug) - } -} - -// NewDebugger is the exported function that build the debuggers expressions -func NewDebugger(filter string, exprEnv ...expr.Option) (*ExprDebugger, error) { - logger := log.WithField("component", "expr-debugger") - visitor := &visitor{logger: logger} - exprDebugger, err := visitor.Build(filter, exprEnv...) - return exprDebugger, err -} diff --git a/pkg/exprhelpers/visitor_test.go b/pkg/exprhelpers/visitor_test.go deleted file mode 100644 index 969d1982118..00000000000 --- a/pkg/exprhelpers/visitor_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package exprhelpers - -import ( - "sort" - "testing" - - "github.com/antonmedv/expr" - log "github.com/sirupsen/logrus" -) - -func TestVisitorBuild(t *testing.T) { - tests := []struct { - name string - expr string - want []string - env map[string]interface{} - }{ - { - name: "simple", - expr: "evt.X", - want: []string{"evt.X"}, - env: map[string]interface{}{ - "evt": map[string]interface{}{ - "X": 1, - }, - }, - }, - { - name: "two vars", - expr: "evt.X && evt.Y", - want: []string{"evt.X", "evt.Y"}, - env: map[string]interface{}{ - "evt": map[string]interface{}{ - "X": 1, - "Y": 2, - }, - }, - }, - { - name: "in", - expr: "evt.X in [1,2,3]", - want: []string{"evt.X"}, - env: map[string]interface{}{ - "evt": map[string]interface{}{ - "X": 1, - }, - }, - }, - { - name: "in complex", - expr: "evt.X in [1,2,3] && evt.Y in [1,2,3] || evt.Z in [1,2,3]", - want: []string{"evt.X", "evt.Y", "evt.Z"}, - env: map[string]interface{}{ - "evt": map[string]interface{}{ - "X": 1, - "Y": 2, - "Z": 3, - }, - }, - }, - { - name: "function call", - expr: "Foo(evt.X, 'ads')", - want: []string{"evt.X"}, - env: map[string]interface{}{ - "evt": map[string]interface{}{ - "X": 1, - }, - "Foo": func(x int, y string) int { - return x - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - v := &visitor{logger: log.NewEntry(log.New())} - ret, err := v.Build(tt.expr, expr.Env(tt.env)) - if err != nil { - t.Errorf("visitor.Build() error = %v", err) - return - } - if len(ret.expression) != len(tt.want) { - t.Errorf("visitor.Build() = %v, want %v", ret.expression, tt.want) - } - //Sort both slices as the order is not guaranteed 
?? - sort.Slice(tt.want, func(i, j int) bool { - return tt.want[i] < tt.want[j] - }) - sort.Slice(ret.expression, func(i, j int) bool { - return ret.expression[i].Str < ret.expression[j].Str - }) - for idx, v := range ret.expression { - if v.Str != tt.want[idx] { - t.Errorf("visitor.Build() = %v, want %v", v.Str, tt.want[idx]) - } - } - }) - } -} diff --git a/pkg/exprhelpers/xml.go b/pkg/exprhelpers/xml.go index 75758e18316..0b550bdb641 100644 --- a/pkg/exprhelpers/xml.go +++ b/pkg/exprhelpers/xml.go @@ -1,43 +1,103 @@ package exprhelpers import ( + "errors" + "sync" + "time" + "github.com/beevik/etree" + "github.com/bluele/gcache" + "github.com/cespare/xxhash/v2" log "github.com/sirupsen/logrus" ) -var pathCache = make(map[string]etree.Path) +var ( + pathCache = make(map[string]etree.Path) + rwMutex = sync.RWMutex{} + xmlDocumentCache gcache.Cache +) + +func compileOrGetPath(path string) (etree.Path, error) { + rwMutex.RLock() + compiledPath, ok := pathCache[path] + rwMutex.RUnlock() + + if !ok { + var err error + compiledPath, err = etree.CompilePath(path) + if err != nil { + return etree.Path{}, err + } + + rwMutex.Lock() + pathCache[path] = compiledPath + rwMutex.Unlock() + } + + return compiledPath, nil +} + +func getXMLDocumentFromCache(xmlString string) (*etree.Document, error) { + cacheKey := xxhash.Sum64String(xmlString) + cacheObj, err := xmlDocumentCache.Get(cacheKey) + + if err != nil && !errors.Is(err, gcache.KeyNotFoundError) { + return nil, err + } + + doc, ok := cacheObj.(*etree.Document) + if !ok || cacheObj == nil { + doc = etree.NewDocument() + if err := doc.ReadFromString(xmlString); err != nil { + return nil, err + } + if err := xmlDocumentCache.Set(cacheKey, doc); err != nil { + log.Warnf("Could not set XML document in cache: %s", err) + } + } + + return doc, nil +} + +func XMLCacheInit() { + gc := gcache.New(50) + // Short cache expiration, because each line we read is different; but the XML helpers may be called several times on a single line + gc.Expiration(5 * time.Second) + gc = gc.LRU() + + xmlDocumentCache = gc.Build() +} // func XMLGetAttributeValue(xmlString string, path string, attributeName string) string { func XMLGetAttributeValue(params ...any) (any, error) { xmlString := params[0].(string) path := params[1].(string) attributeName := params[2].(string) - if _, ok := pathCache[path]; !ok { - compiledPath, err := etree.CompilePath(path) - if err != nil { - log.Errorf("Could not compile path %s: %s", path, err) - return "", nil - } - pathCache[path] = compiledPath + + compiledPath, err := compileOrGetPath(path) + if err != nil { + log.Errorf("Could not compile path %s: %s", path, err) + return "", nil } - compiledPath := pathCache[path] - doc := etree.NewDocument() - err := doc.ReadFromString(xmlString) + doc, err := getXMLDocumentFromCache(xmlString) if err != nil { log.Tracef("Could not parse XML: %s", err) return "", nil } + elem := doc.FindElementPath(compiledPath) if elem == nil { log.Debugf("Could not find element %s", path) return "", nil } + attr := elem.SelectAttr(attributeName) if attr == nil { log.Debugf("Could not find attribute %s", attributeName) return "", nil } + return attr.Value, nil } @@ -45,26 +105,24 @@ func XMLGetAttributeValue(params ...any) (any, error) { func XMLGetNodeValue(params ...any) (any, error) { xmlString := params[0].(string) path := params[1].(string) - if _, ok := pathCache[path]; !ok { - compiledPath, err := etree.CompilePath(path) - if err != nil { - log.Errorf("Could not compile path %s: %s", path, err) - return "", nil - 
} - pathCache[path] = compiledPath + + compiledPath, err := compileOrGetPath(path) + if err != nil { + log.Errorf("Could not compile path %s: %s", path, err) + return "", nil } - compiledPath := pathCache[path] - doc := etree.NewDocument() - err := doc.ReadFromString(xmlString) + doc, err := getXMLDocumentFromCache(xmlString) if err != nil { log.Tracef("Could not parse XML: %s", err) return "", nil } + elem := doc.FindElementPath(compiledPath) if elem == nil { log.Debugf("Could not find element %s", path) return "", nil } + return elem.Text(), nil } diff --git a/pkg/exprhelpers/xml_test.go b/pkg/exprhelpers/xml_test.go index 516387f764b..42823884025 100644 --- a/pkg/exprhelpers/xml_test.go +++ b/pkg/exprhelpers/xml_test.go @@ -9,7 +9,7 @@ import ( func TestXMLGetAttributeValue(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -58,17 +58,19 @@ func TestXMLGetAttributeValue(t *testing.T) { for _, test := range tests { result, _ := XMLGetAttributeValue(test.xmlString, test.path, test.attribute) + isOk := assert.Equal(t, test.expectResult, result) if !isOk { t.Fatalf("test '%s' failed", test.name) } + log.Printf("test '%s' : OK", test.name) } - } + func TestXMLGetNodeValue(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -105,11 +107,12 @@ func TestXMLGetNodeValue(t *testing.T) { for _, test := range tests { result, _ := XMLGetNodeValue(test.xmlString, test.path) + isOk := assert.Equal(t, test.expectResult, result) if !isOk { t.Fatalf("test '%s' failed", test.name) } + log.Printf("test '%s' : OK", test.name) } - } diff --git a/pkg/fflag/crowdsec.go b/pkg/fflag/crowdsec.go index 889f62dcf2f..d42d6a05ef6 100644 --- a/pkg/fflag/crowdsec.go +++ b/pkg/fflag/crowdsec.go @@ -5,7 +5,7 @@ var Crowdsec = FeatureRegister{EnvPrefix: "CROWDSEC_FEATURE_"} var CscliSetup = &Feature{Name: "cscli_setup", Description: "Enable cscli setup command (service detection)"} var DisableHttpRetryBackoff = &Feature{Name: "disable_http_retry_backoff", Description: "Disable http retry backoff"} var ChunkedDecisionsStream = &Feature{Name: "chunked_decisions_stream", Description: "Enable chunked decisions stream"} -var PapiClient = &Feature{Name: "papi_client", Description: "Enable Polling API client"} +var PapiClient = &Feature{Name: "papi_client", Description: "Enable Polling API client", State: DeprecatedState} var Re2GrokSupport = &Feature{Name: "re2_grok_support", Description: "Enable RE2 support for GROK patterns"} var Re2RegexpInfileSupport = &Feature{Name: "re2_regexp_in_file_support", Description: "Enable RE2 support for RegexpInFile expr helper"} @@ -14,22 +14,27 @@ func RegisterAllFeatures() error { if err != nil { return err } + err = Crowdsec.RegisterFeature(DisableHttpRetryBackoff) if err != nil { return err } + err = Crowdsec.RegisterFeature(ChunkedDecisionsStream) if err != nil { return err } + err = Crowdsec.RegisterFeature(PapiClient) if err != nil { return err } + err = Crowdsec.RegisterFeature(Re2GrokSupport) if err != nil { return err } + err = Crowdsec.RegisterFeature(Re2RegexpInfileSupport) if err != nil { return err diff --git a/pkg/fflag/features.go b/pkg/fflag/features.go index b3c0a8bfe54..c8a3d7755ea 100644 --- a/pkg/fflag/features.go +++ b/pkg/fflag/features.go @@ -18,10 +18,10 @@ // in feature.yaml. Features cannot be disabled in the file. // // A feature flag can be deprecated or retired. A deprecated feature flag is -// still accepted but a warning is logged. 
A retired feature flag is ignored -and an error is logged. +// still accepted but a warning is logged (only if a deprecation message is provided). +// A retired feature flag is ignored and an error is logged. // -// A specific deprecation message is used to inform the user of the behavior +// The message is intended to inform the user of the behavior // that has been decided when the flag is/was finally retired. package fflag @@ -97,7 +97,7 @@ type FeatureRegister struct { features map[string]*Feature } -var featureNameRexp = regexp.MustCompile(`^[a-z0-9_\.]+$`) +var featureNameRexp = regexp.MustCompile(`^[a-z0-9_.]+$`) func validateFeatureName(featureName string) error { if featureName == "" { @@ -176,7 +176,9 @@ func (fr *FeatureRegister) SetFromEnv(logger *logrus.Logger) error { logger.Errorf("Ignored envvar '%s': %s. %s", varName, err, feat.DeprecationMsg) continue case errors.Is(err, ErrFeatureDeprecated): - logger.Warningf("Envvar '%s': %s. %s", varName, err, feat.DeprecationMsg) + if feat.DeprecationMsg != "" { + logger.Warningf("Envvar '%s': %s. %s", varName, err, feat.DeprecationMsg) + } case err != nil: return err } diff --git a/pkg/fflag/features_test.go b/pkg/fflag/features_test.go index be7434c2a6c..481e86573e8 100644 --- a/pkg/fflag/features_test.go +++ b/pkg/fflag/features_test.go @@ -9,7 +9,7 @@ import ( logtest "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/crowdsec/pkg/fflag" ) @@ -50,8 +50,6 @@ func TestRegisterFeature(t *testing.T) { } for _, tc := range tests { - tc := tc - t.Run("", func(t *testing.T) { fr := fflag.FeatureRegister{EnvPrefix: "FFLAG_TEST_"} err := fr.RegisterFeature(&tc.feature) @@ -112,7 +110,6 @@ func TestGetFeature(t *testing.T) { fr := setUp(t) for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { _, err := fr.GetFeature(tc.feature) cstest.RequireErrorMessage(t, err, tc.expectedErr) @@ -145,7 +142,6 @@ func TestIsEnabled(t *testing.T) { fr := setUp(t) for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { feat, err := fr.GetFeature(tc.feature) require.NoError(t, err) @@ -204,7 +200,6 @@ func TestFeatureSet(t *testing.T) { fr := setUp(t) for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { feat, err := fr.GetFeature(tc.feature) cstest.RequireErrorMessage(t, err, tc.expectedGetErr) @@ -284,7 +279,6 @@ func TestSetFromEnv(t *testing.T) { fr := setUp(t) for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { logger, hook := logtest.NewNullLogger() logger.SetLevel(logrus.DebugLevel) @@ -344,7 +338,6 @@ func TestSetFromYaml(t *testing.T) { fr := setUp(t) for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { logger, hook := logtest.NewNullLogger() logger.SetLevel(logrus.DebugLevel) @@ -364,7 +357,7 @@ func TestSetFromYamlFile(t *testing.T) { defer os.Remove(tmpfile.Name()) // write the config file - _, err = tmpfile.Write([]byte("- experimental1")) + _, err = tmpfile.WriteString("- experimental1") require.NoError(t, err) require.NoError(t, tmpfile.Close()) diff --git a/pkg/hubtest/appsecrule.go b/pkg/hubtest/appsecrule.go new file mode 100644 index 00000000000..1c4416c2e9b --- /dev/null +++ b/pkg/hubtest/appsecrule.go @@ -0,0 +1,95 @@ +package hubtest + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func (t 
*HubTestItem) installAppsecRuleItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) + if err != nil { + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) + } + + sourceFilename := filepath.Base(sourcePath) + + // runtime/hub/appsec-rules/author/appsec-rule + hubDirAppsecRuleDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) + + // runtime/appsec-rules/ + itemTypeDirDest := fmt.Sprintf("%s/appsec-rules/", t.RuntimePath) + + if err := createDirs([]string{hubDirAppsecRuleDest, itemTypeDirDest}); err != nil { + return err + } + + // runtime/hub/appsec-rules/crowdsecurity/rule.yaml + hubDirAppsecRulePath := filepath.Join(itemTypeDirDest, sourceFilename) + if err := Copy(sourcePath, hubDirAppsecRulePath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirAppsecRulePath, err) + } + + // runtime/appsec-rules/rule.yaml + appsecRulePath := filepath.Join(itemTypeDirDest, sourceFilename) + if err := os.Symlink(hubDirAppsecRulePath, appsecRulePath); err != nil { + if !os.IsExist(err) { + return fmt.Errorf("unable to symlink appsec-rule '%s' to '%s': %w", hubDirAppsecRulePath, appsecRulePath, err) + } + } + + return nil +} + +func (t *HubTestItem) installAppsecRuleCustomFrom(appsecrule string, customPath string) (bool, error) { + // we check if its a custom appsec-rule + customAppsecRulePath := filepath.Join(customPath, appsecrule) + if _, err := os.Stat(customAppsecRulePath); os.IsNotExist(err) { + return false, nil + } + + customAppsecRulePathSplit := strings.Split(customAppsecRulePath, "/") + customAppsecRuleName := customAppsecRulePathSplit[len(customAppsecRulePathSplit)-1] + + itemTypeDirDest := fmt.Sprintf("%s/appsec-rules/", t.RuntimePath) + if err := os.MkdirAll(itemTypeDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder '%s': %w", itemTypeDirDest, err) + } + + customAppsecRuleDest := fmt.Sprintf("%s/appsec-rules/%s", t.RuntimePath, customAppsecRuleName) + if err := Copy(customAppsecRulePath, customAppsecRuleDest); err != nil { + return false, fmt.Errorf("unable to copy appsec-rule from '%s' to '%s': %w", customAppsecRulePath, customAppsecRuleDest, err) + } + + return true, nil +} + +func (t *HubTestItem) installAppsecRuleCustom(appsecrule string) error { + for _, customPath := range t.CustomItemsLocation { + found, err := t.installAppsecRuleCustomFrom(appsecrule, customPath) + if err != nil { + return err + } + + if found { + return nil + } + } + + return fmt.Errorf("couldn't find custom appsec-rule '%s' in the following location: %+v", appsecrule, t.CustomItemsLocation) +} + +func (t *HubTestItem) installAppsecRule(name string) error { + log.Debugf("adding rule '%s'", name) + + if item := t.HubIndex.GetItem(cwhub.APPSEC_RULES, name); item != nil { + return t.installAppsecRuleItem(item) + } + + return t.installAppsecRuleCustom(name) +} diff --git a/pkg/hubtest/coverage.go b/pkg/hubtest/coverage.go index eeff24b57b6..e42c1e23455 100644 --- a/pkg/hubtest/coverage.go +++ b/pkg/hubtest/coverage.go @@ -2,176 +2,265 @@ package hubtest import ( "bufio" + "errors" "fmt" "os" "path/filepath" - "regexp" - "sort" "strings" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -type ParserCoverage struct { - Parser string +type 
Coverage struct { + Name string TestsCount int - PresentIn map[string]bool //poorman's set + PresentIn map[string]bool // poor man's set } -type ScenarioCoverage struct { - Scenario string - TestsCount int - PresentIn map[string]bool -} +func (h *HubTest) GetAppsecCoverage() ([]Coverage, error) { + if len(h.HubIndex.GetItemMap(cwhub.APPSEC_RULES)) == 0 { + return nil, errors.New("no appsec rules in hub index") + } + + // populate from hub, iterate in alphabetical order + pkeys := maptools.SortedKeys(h.HubIndex.GetItemMap(cwhub.APPSEC_RULES)) + coverage := make([]Coverage, len(pkeys)) -func (h *HubTest) GetParsersCoverage() ([]ParserCoverage, error) { - var coverage []ParserCoverage - if _, ok := h.HubIndex.Data[cwhub.PARSERS]; !ok { - return coverage, fmt.Errorf("no parsers in hub index") + for i, name := range pkeys { + coverage[i] = Coverage{ + Name: name, + TestsCount: 0, + PresentIn: make(map[string]bool), + } } - //populate from hub, iterate in alphabetical order - var pkeys []string - for pname := range h.HubIndex.Data[cwhub.PARSERS] { - pkeys = append(pkeys, pname) + + // parse the expressions a-la-oneagain + appsecTestConfigs, err := filepath.Glob(".appsec-tests/*/config.yaml") + if err != nil { + return nil, fmt.Errorf("while find appsec-tests config: %w", err) } - sort.Strings(pkeys) - for _, pname := range pkeys { - coverage = append(coverage, ParserCoverage{ - Parser: pname, + + for _, appsecTestConfigPath := range appsecTestConfigs { + configFileData := &HubTestItemConfig{} + + yamlFile, err := os.ReadFile(appsecTestConfigPath) + if err != nil { + log.Printf("unable to open appsec test config file '%s': %s", appsecTestConfigPath, err) + continue + } + + err = yaml.Unmarshal(yamlFile, configFileData) + if err != nil { + return nil, fmt.Errorf("parsing: %v", err) + } + + for _, appsecRulesFile := range configFileData.AppsecRules { + appsecRuleData := &appsec_rule.CustomRule{} + + yamlFile, err := os.ReadFile(appsecRulesFile) + if err != nil { + log.Printf("unable to open appsec rule '%s': %s", appsecRulesFile, err) + } + + err = yaml.Unmarshal(yamlFile, appsecRuleData) + if err != nil { + return nil, fmt.Errorf("parsing: %v", err) + } + + appsecRuleName := appsecRuleData.Name + + for idx, cov := range coverage { + if cov.Name == appsecRuleName { + coverage[idx].TestsCount++ + coverage[idx].PresentIn[appsecTestConfigPath] = true + } + } + } + } + + return coverage, nil +} + +func (h *HubTest) GetParsersCoverage() ([]Coverage, error) { + if len(h.HubIndex.GetItemMap(cwhub.PARSERS)) == 0 { + return nil, errors.New("no parsers in hub index") + } + + // populate from hub, iterate in alphabetical order + pkeys := maptools.SortedKeys(h.HubIndex.GetItemMap(cwhub.PARSERS)) + coverage := make([]Coverage, len(pkeys)) + + for i, name := range pkeys { + coverage[i] = Coverage{ + Name: name, TestsCount: 0, PresentIn: make(map[string]bool), - }) + } } - //parser the expressions a-la-oneagain + // parse the expressions a-la-oneagain passerts, err := filepath.Glob(".tests/*/parser.assert") if err != nil { - return coverage, fmt.Errorf("while find parser asserts : %s", err) + return nil, fmt.Errorf("while find parser asserts: %w", err) } + for _, assert := range passerts { file, err := os.Open(assert) if err != nil { - return coverage, fmt.Errorf("while reading %s : %s", assert, err) + return nil, fmt.Errorf("while reading %s: %w", assert, err) } + scanner := bufio.NewScanner(file) for scanner.Scan() { - assertLine := regexp.MustCompile(`^results\["[^"]+"\]\["(?P<parser>[^"]+)"\]\[[0-9]+\]\.Evt\..*`) 
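// The inline regexp above is evidently replaced by a package-level parserResultRE
// (used just below), so it is compiled once rather than on every scanned line. A
// self-contained sketch of the same named-group lookup, assuming only the standard
// regexp package:
//
//	re := regexp.MustCompile(`^results\["[^"]+"\]\["(?P<parser>[^"]+)"\]`)
//	m := re.FindStringSubmatch(`results["log1"]["crowdsecurity/syslog-logs"][0].Evt.Parsed`)
//	if m != nil {
//		fmt.Println(m[re.SubexpIndex("parser")]) // crowdsecurity/syslog-logs
//	}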
line := scanner.Text() log.Debugf("assert line : %s", line) - match := assertLine.FindStringSubmatch(line) + + match := parserResultRE.FindStringSubmatch(line) if len(match) == 0 { log.Debugf("%s doesn't match", line) continue } - sidx := assertLine.SubexpIndex("parser") + + sidx := parserResultRE.SubexpIndex("parser") capturedParser := match[sidx] + for idx, pcover := range coverage { - if pcover.Parser == capturedParser { + if pcover.Name == capturedParser { coverage[idx].TestsCount++ coverage[idx].PresentIn[assert] = true + continue } - parserNameSplit := strings.Split(pcover.Parser, "/") + + parserNameSplit := strings.Split(pcover.Name, "/") parserNameOnly := parserNameSplit[len(parserNameSplit)-1] + if parserNameOnly == capturedParser { coverage[idx].TestsCount++ coverage[idx].PresentIn[assert] = true + continue } + capturedParserSplit := strings.Split(capturedParser, "/") capturedParserName := capturedParserSplit[len(capturedParserSplit)-1] + if capturedParserName == parserNameOnly { coverage[idx].TestsCount++ coverage[idx].PresentIn[assert] = true + continue } + if capturedParserName == parserNameOnly+"-logs" { coverage[idx].TestsCount++ coverage[idx].PresentIn[assert] = true + continue } } } + file.Close() } + return coverage, nil } -func (h *HubTest) GetScenariosCoverage() ([]ScenarioCoverage, error) { - var coverage []ScenarioCoverage - if _, ok := h.HubIndex.Data[cwhub.SCENARIOS]; !ok { - return coverage, fmt.Errorf("no scenarios in hub index") +func (h *HubTest) GetScenariosCoverage() ([]Coverage, error) { + if len(h.HubIndex.GetItemMap(cwhub.SCENARIOS)) == 0 { + return nil, errors.New("no scenarios in hub index") } - //populate from hub, iterate in alphabetical order - var pkeys []string - for scenarioName := range h.HubIndex.Data[cwhub.SCENARIOS] { - pkeys = append(pkeys, scenarioName) - } - sort.Strings(pkeys) - for _, scenarioName := range pkeys { - coverage = append(coverage, ScenarioCoverage{ - Scenario: scenarioName, + + // populate from hub, iterate in alphabetical order + pkeys := maptools.SortedKeys(h.HubIndex.GetItemMap(cwhub.SCENARIOS)) + coverage := make([]Coverage, len(pkeys)) + + for i, name := range pkeys { + coverage[i] = Coverage{ + Name: name, TestsCount: 0, PresentIn: make(map[string]bool), - }) + } } - //parser the expressions a-la-oneagain + // parse the expressions a-la-oneagain passerts, err := filepath.Glob(".tests/*/scenario.assert") if err != nil { - return coverage, fmt.Errorf("while find scenario asserts : %s", err) + return nil, fmt.Errorf("while find scenario asserts: %w", err) } + for _, assert := range passerts { file, err := os.Open(assert) if err != nil { - return coverage, fmt.Errorf("while reading %s : %s", assert, err) + return nil, fmt.Errorf("while reading %s: %w", assert, err) } + scanner := bufio.NewScanner(file) for scanner.Scan() { - assertLine := regexp.MustCompile(`^results\[[0-9]+\].Overflow.Alert.GetScenario\(\) == "(?P<scenario>[^"]+)"`) line := scanner.Text() log.Debugf("assert line : %s", line) - match := assertLine.FindStringSubmatch(line) + match := scenarioResultRE.FindStringSubmatch(line) + if len(match) == 0 { log.Debugf("%s doesn't match", line) continue } - sidx := assertLine.SubexpIndex("scenario") - scanner_name := match[sidx] + + sidx := scenarioResultRE.SubexpIndex("scenario") + scannerName := match[sidx] + for idx, pcover := range coverage { - if pcover.Scenario == scanner_name { + if pcover.Name == scannerName { coverage[idx].TestsCount++ coverage[idx].PresentIn[assert] = true + continue } - scenarioNameSplit := 
strings.Split(pcover.Scenario, "/") + + scenarioNameSplit := strings.Split(pcover.Name, "/") scenarioNameOnly := scenarioNameSplit[len(scenarioNameSplit)-1] - if scenarioNameOnly == scanner_name { + + if scenarioNameOnly == scannerName { coverage[idx].TestsCount++ coverage[idx].PresentIn[assert] = true + continue } - fixedProbingWord := strings.ReplaceAll(pcover.Scenario, "probbing", "probing") - fixedProbingAssert := strings.ReplaceAll(scanner_name, "probbing", "probing") + + fixedProbingWord := strings.ReplaceAll(pcover.Name, "probbing", "probing") + fixedProbingAssert := strings.ReplaceAll(scannerName, "probbing", "probing") + if fixedProbingWord == fixedProbingAssert { coverage[idx].TestsCount++ coverage[idx].PresentIn[assert] = true + continue } - if fmt.Sprintf("%s-detection", pcover.Scenario) == scanner_name { + + if fmt.Sprintf("%s-detection", pcover.Name) == scannerName { coverage[idx].TestsCount++ coverage[idx].PresentIn[assert] = true + continue } + if fmt.Sprintf("%s-detection", fixedProbingWord) == fixedProbingAssert { coverage[idx].TestsCount++ coverage[idx].PresentIn[assert] = true + continue } } } file.Close() } + return coverage, nil } diff --git a/pkg/hubtest/hubtest.go b/pkg/hubtest/hubtest.go index c1aa4251ca1..93f5abaa879 100644 --- a/pkg/hubtest/hubtest.go +++ b/pkg/hubtest/hubtest.go @@ -6,70 +6,138 @@ import ( "os/exec" "path/filepath" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) type HubTest struct { - CrowdSecPath string - CscliPath string - HubPath string - HubTestPath string - HubIndexFile string - TemplateConfigPath string - TemplateProfilePath string - TemplateSimulationPath string - HubIndex *HubIndex - Tests []*HubTestItem + CrowdSecPath string + CscliPath string + HubPath string + HubTestPath string //generic parser/scenario tests .tests + HubAppsecTestPath string //dir specific to appsec tests .appsec-tests + HubIndexFile string + TemplateConfigPath string + TemplateProfilePath string + TemplateSimulationPath string + TemplateAcquisPath string + TemplateAppsecProfilePath string + NucleiTargetHost string + AppSecHost string + + HubIndex *cwhub.Hub + Tests []*HubTestItem } const ( - templateConfigFile = "template_config.yaml" - templateSimulationFile = "template_simulation.yaml" - templateProfileFile = "template_profiles.yaml" + templateConfigFile = "template_config.yaml" + templateSimulationFile = "template_simulation.yaml" + templateProfileFile = "template_profiles.yaml" + templateAcquisFile = "template_acquis.yaml" + templateAppsecProfilePath = "template_appsec-profile.yaml" + TemplateNucleiFile = `id: {{.TestName}} +info: + name: {{.TestName}} + author: crowdsec + severity: info + description: {{.TestName}} testing + tags: appsec-testing +http: +#this is a dummy request, edit the request(s) to match your needs + - raw: + - | + GET /test HTTP/1.1 + Host: {{"{{"}}Hostname{{"}}"}} + + cookie-reuse: true +#test will fail because we won't match http status + matchers: + - type: status + status: + - 403 +` ) -func NewHubTest(hubPath string, crowdsecPath string, cscliPath string) (HubTest, error) { - var err error - - hubPath, err = filepath.Abs(hubPath) +func NewHubTest(hubPath string, crowdsecPath string, cscliPath string, isAppsecTest bool) (HubTest, error) { + hubPath, err := filepath.Abs(hubPath) if err != nil { return HubTest{}, fmt.Errorf("can't get absolute path of hub: %+v", err) } + // we can't use hubtest without the hub - if _, err := os.Stat(hubPath); os.IsNotExist(err) { + if _, err = 
os.Stat(hubPath); os.IsNotExist(err) { return HubTest{}, fmt.Errorf("path to hub '%s' doesn't exist, can't run", hubPath) } - HubTestPath := filepath.Join(hubPath, "./.tests/") - // we can't use hubtest without crowdsec binary - if _, err := exec.LookPath(crowdsecPath); err != nil { - if _, err := os.Stat(crowdsecPath); os.IsNotExist(err) { + if _, err = exec.LookPath(crowdsecPath); err != nil { + if _, err = os.Stat(crowdsecPath); os.IsNotExist(err) { return HubTest{}, fmt.Errorf("path to crowdsec binary '%s' doesn't exist or is not in $PATH, can't run", crowdsecPath) } } // we can't use hubtest without cscli binary - if _, err := exec.LookPath(cscliPath); err != nil { - if _, err := os.Stat(cscliPath); os.IsNotExist(err) { + if _, err = exec.LookPath(cscliPath); err != nil { + if _, err = os.Stat(cscliPath); os.IsNotExist(err) { return HubTest{}, fmt.Errorf("path to cscli binary '%s' doesn't exist or is not in $PATH, can't run", cscliPath) } } + if isAppsecTest { + HubTestPath := filepath.Join(hubPath, ".appsec-tests") + hubIndexFile := filepath.Join(hubPath, ".index.json") + + local := &csconfig.LocalHubCfg{ + HubDir: hubPath, + HubIndexFile: hubIndexFile, + InstallDir: HubTestPath, + InstallDataDir: HubTestPath, + } + + hub, err := cwhub.NewHub(local, nil, nil) + if err != nil { + return HubTest{}, err + } + + if err := hub.Load(); err != nil { + return HubTest{}, err + } + + return HubTest{ + CrowdSecPath: crowdsecPath, + CscliPath: cscliPath, + HubPath: hubPath, + HubTestPath: HubTestPath, + HubIndexFile: hubIndexFile, + TemplateConfigPath: filepath.Join(HubTestPath, templateConfigFile), + TemplateProfilePath: filepath.Join(HubTestPath, templateProfileFile), + TemplateSimulationPath: filepath.Join(HubTestPath, templateSimulationFile), + TemplateAppsecProfilePath: filepath.Join(HubTestPath, templateAppsecProfilePath), + TemplateAcquisPath: filepath.Join(HubTestPath, templateAcquisFile), + NucleiTargetHost: DefaultNucleiTarget, + AppSecHost: DefaultAppsecHost, + HubIndex: hub, + }, nil + } + + HubTestPath := filepath.Join(hubPath, ".tests") + hubIndexFile := filepath.Join(hubPath, ".index.json") - bidx, err := os.ReadFile(hubIndexFile) - if err != nil { - return HubTest{}, fmt.Errorf("unable to read index file: %s", err) + + local := &csconfig.LocalHubCfg{ + HubDir: hubPath, + HubIndexFile: hubIndexFile, + InstallDir: HubTestPath, + InstallDataDir: HubTestPath, } - // load hub index - hubIndex, err := cwhub.LoadPkgIndex(bidx) + hub, err := cwhub.NewHub(local, nil, nil) if err != nil { - return HubTest{}, fmt.Errorf("unable to load hub index file: %s", err) + return HubTest{}, err } - templateConfigFilePath := filepath.Join(HubTestPath, templateConfigFile) - templateProfilePath := filepath.Join(HubTestPath, templateProfileFile) - templateSimulationPath := filepath.Join(HubTestPath, templateSimulationFile) + if err := hub.Load(); err != nil { + return HubTest{}, err + } return HubTest{ CrowdSecPath: crowdsecPath, @@ -77,19 +145,21 @@ func NewHubTest(hubPath string, crowdsecPath string, cscliPath string) (HubTest, HubPath: hubPath, HubTestPath: HubTestPath, HubIndexFile: hubIndexFile, - TemplateConfigPath: templateConfigFilePath, - TemplateProfilePath: templateProfilePath, - TemplateSimulationPath: templateSimulationPath, - HubIndex: &HubIndex{Data: hubIndex}, + TemplateConfigPath: filepath.Join(HubTestPath, templateConfigFile), + TemplateProfilePath: filepath.Join(HubTestPath, templateProfileFile), + TemplateSimulationPath: filepath.Join(HubTestPath, templateSimulationFile), + HubIndex: 
hub, }, nil } func (h *HubTest) LoadTestItem(name string) (*HubTestItem, error) { HubTestItem := &HubTestItem{} + testItem, err := NewTest(name, h) if err != nil { return HubTestItem, err } + h.Tests = append(h.Tests, testItem) return testItem, nil @@ -108,5 +178,6 @@ func (h *HubTest) LoadAllTests() error { } } } + return nil } diff --git a/pkg/hubtest/hubtest_item.go b/pkg/hubtest/hubtest_item.go index 475f42cf697..bc9c8955d0d 100644 --- a/pkg/hubtest/hubtest_item.go +++ b/pkg/hubtest/hubtest_item.go @@ -1,32 +1,35 @@ package hubtest import ( + "context" + "errors" "fmt" + "net/url" "os" "os/exec" "path/filepath" "strings" + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" "github.com/crowdsecurity/crowdsec/pkg/parser" - log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" ) type HubTestItemConfig struct { - Parsers []string `yaml:"parsers"` - Scenarios []string `yaml:"scenarios"` - PostOVerflows []string `yaml:"postoverflows"` - LogFile string `yaml:"log_file"` - LogType string `yaml:"log_type"` - Labels map[string]string `yaml:"labels"` - IgnoreParsers bool `yaml:"ignore_parsers"` // if we test a scenario, we don't want to assert on Parser - OverrideStatics []parser.ExtraField `yaml:"override_statics"` //Allow to override statics. Executed before s00 -} - -type HubIndex struct { - Data map[string]map[string]cwhub.Item + Parsers []string `yaml:"parsers,omitempty"` + Scenarios []string `yaml:"scenarios,omitempty"` + PostOverflows []string `yaml:"postoverflows,omitempty"` + AppsecRules []string `yaml:"appsec-rules,omitempty"` + NucleiTemplate string `yaml:"nuclei_template,omitempty"` + ExpectedNucleiFailure bool `yaml:"expect_failure,omitempty"` + LogFile string `yaml:"log_file,omitempty"` + LogType string `yaml:"log_type,omitempty"` + Labels map[string]string `yaml:"labels,omitempty"` + IgnoreParsers bool `yaml:"ignore_parsers,omitempty"` // if we test a scenario, we don't want to assert on Parser + OverrideStatics []parser.ExtraField `yaml:"override_statics,omitempty"` // Allow to override statics. 
Executed before s00 } type HubTestItem struct { @@ -43,20 +46,23 @@ type HubTestItem struct { RuntimeConfigFilePath string RuntimeProfileFilePath string RuntimeSimulationFilePath string - RuntimeHubConfig *csconfig.Hub + RuntimeAcquisFilePath string + RuntimeHubConfig *csconfig.LocalHubCfg ResultsPath string ParserResultFile string ScenarioResultFile string BucketPourResultFile string - HubPath string - HubTestPath string - HubIndexFile string - TemplateConfigPath string - TemplateProfilePath string - TemplateSimulationPath string - HubIndex *HubIndex + HubPath string + HubTestPath string + HubIndexFile string + TemplateConfigPath string + TemplateProfilePath string + TemplateSimulationPath string + TemplateAcquisPath string + TemplateAppsecProfilePath string + HubIndex *cwhub.Hub Config *HubTestItemConfig @@ -68,6 +74,9 @@ type HubTestItem struct { ScenarioAssert *ScenarioAssert CustomItemsLocation []string + + NucleiTargetHost string + AppSecHost string } const ( @@ -78,9 +87,12 @@ const ( ScenarioResultFileName = "bucket-dump.yaml" BucketPourResultFileName = "bucketpour-dump.yaml" -) -var crowdsecPatternsFolder = csconfig.DefaultConfigPath("patterns") + TestBouncerApiKey = "this_is_a_bad_password" + + DefaultNucleiTarget = "http://127.0.0.1:7822/" + DefaultAppsecHost = "127.0.0.1:4241" +) func NewTest(name string, hubTest *HubTest) (*HubTestItem, error) { testPath := filepath.Join(hubTest.HubTestPath, name) @@ -91,13 +103,15 @@ func NewTest(name string, hubTest *HubTest) (*HubTestItem, error) { // read test configuration file configFileData := &HubTestItemConfig{} + yamlFile, err := os.ReadFile(configFilePath) if err != nil { log.Printf("no config file found in '%s': %v", testPath, err) } + err = yaml.Unmarshal(yamlFile, configFileData) if err != nil { - return nil, fmt.Errorf("Unmarshal: %v", err) + return nil, fmt.Errorf("parsing: %w", err) } parserAssertFilePath := filepath.Join(testPath, ParserAssertFileName) @@ -105,6 +119,7 @@ func NewTest(name string, hubTest *HubTest) (*HubTestItem, error) { scenarioAssertFilePath := filepath.Join(testPath, ScenarioAssertFileName) ScenarioAssert := NewScenarioAssert(scenarioAssertFilePath) + return &HubTestItem{ Name: name, Path: testPath, @@ -117,265 +132,64 @@ func NewTest(name string, hubTest *HubTest) (*HubTestItem, error) { RuntimeConfigFilePath: filepath.Join(runtimeFolder, "config.yaml"), RuntimeProfileFilePath: filepath.Join(runtimeFolder, "profiles.yaml"), RuntimeSimulationFilePath: filepath.Join(runtimeFolder, "simulation.yaml"), + RuntimeAcquisFilePath: filepath.Join(runtimeFolder, "acquis.yaml"), ResultsPath: resultPath, ParserResultFile: filepath.Join(resultPath, ParserResultFileName), ScenarioResultFile: filepath.Join(resultPath, ScenarioResultFileName), BucketPourResultFile: filepath.Join(resultPath, BucketPourResultFileName), - RuntimeHubConfig: &csconfig.Hub{ - HubDir: runtimeHubFolder, - ConfigDir: runtimeFolder, - HubIndexFile: hubTest.HubIndexFile, - DataDir: filepath.Join(runtimeFolder, "data"), + RuntimeHubConfig: &csconfig.LocalHubCfg{ + HubDir: runtimeHubFolder, + HubIndexFile: hubTest.HubIndexFile, + InstallDir: runtimeFolder, + InstallDataDir: filepath.Join(runtimeFolder, "data"), }, - Config: configFileData, - HubPath: hubTest.HubPath, - HubTestPath: hubTest.HubTestPath, - HubIndexFile: hubTest.HubIndexFile, - TemplateConfigPath: hubTest.TemplateConfigPath, - TemplateProfilePath: hubTest.TemplateProfilePath, - TemplateSimulationPath: hubTest.TemplateSimulationPath, - HubIndex: hubTest.HubIndex, - ScenarioAssert: 
ScenarioAssert, - ParserAssert: ParserAssert, - CustomItemsLocation: []string{hubTest.HubPath, testPath}, + Config: configFileData, + HubPath: hubTest.HubPath, + HubTestPath: hubTest.HubTestPath, + HubIndexFile: hubTest.HubIndexFile, + TemplateConfigPath: hubTest.TemplateConfigPath, + TemplateProfilePath: hubTest.TemplateProfilePath, + TemplateSimulationPath: hubTest.TemplateSimulationPath, + TemplateAcquisPath: hubTest.TemplateAcquisPath, + TemplateAppsecProfilePath: hubTest.TemplateAppsecProfilePath, + HubIndex: hubTest.HubIndex, + ScenarioAssert: ScenarioAssert, + ParserAssert: ParserAssert, + CustomItemsLocation: []string{hubTest.HubPath, testPath}, + NucleiTargetHost: hubTest.NucleiTargetHost, + AppSecHost: hubTest.AppSecHost, }, nil } -func (t *HubTestItem) InstallHub() error { - // install parsers in runtime environment - for _, parser := range t.Config.Parsers { - if parser == "" { +func (t *HubTestItem) installHubItems(names []string, installFunc func(string) error) error { + for _, name := range names { + if name == "" { continue } - var parserDirDest string - if hubParser, ok := t.HubIndex.Data[cwhub.PARSERS][parser]; ok { - parserSource, err := filepath.Abs(filepath.Join(t.HubPath, hubParser.RemotePath)) - if err != nil { - return fmt.Errorf("can't get absolute path of '%s': %s", parserSource, err) - } - parserFileName := filepath.Base(parserSource) - - // runtime/hub/parsers/s00-raw/crowdsecurity/ - hubDirParserDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(hubParser.RemotePath)) - - // runtime/parsers/s00-raw/ - parserDirDest = fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, hubParser.Stage) - - if err := os.MkdirAll(hubDirParserDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", hubDirParserDest, err) - } - if err := os.MkdirAll(parserDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", parserDirDest, err) - } - - // runtime/hub/parsers/s00-raw/crowdsecurity/syslog-logs.yaml - hubDirParserPath := filepath.Join(hubDirParserDest, parserFileName) - if err := Copy(parserSource, hubDirParserPath); err != nil { - return fmt.Errorf("unable to copy '%s' to '%s': %s", parserSource, hubDirParserPath, err) - } - // runtime/parsers/s00-raw/syslog-logs.yaml - parserDirParserPath := filepath.Join(parserDirDest, parserFileName) - if err := os.Symlink(hubDirParserPath, parserDirParserPath); err != nil { - if !os.IsExist(err) { - return fmt.Errorf("unable to symlink parser '%s' to '%s': %s", hubDirParserPath, parserDirParserPath, err) - } - } - } else { - customParserExist := false - for _, customPath := range t.CustomItemsLocation { - // we check if its a custom parser - customParserPath := filepath.Join(customPath, parser) - if _, err := os.Stat(customParserPath); os.IsNotExist(err) { - continue - //return fmt.Errorf("parser '%s' doesn't exist in the hub and doesn't appear to be a custom one.", parser) - } - - customParserPathSplit, customParserName := filepath.Split(customParserPath) - // because path is parsers///parser.yaml and we wan't the stage - splittedPath := strings.Split(customParserPathSplit, string(os.PathSeparator)) - customParserStage := splittedPath[len(splittedPath)-3] - - // check if stage exist - hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("parsers/%s", customParserStage)) - - if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { - continue - //return fmt.Errorf("stage '%s' extracted from '%s' doesn't exist in the hub", customParserStage, hubStagePath) - } - - parserDirDest = 
fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, customParserStage) - if err := os.MkdirAll(parserDirDest, os.ModePerm); err != nil { - continue - //return fmt.Errorf("unable to create folder '%s': %s", parserDirDest, err) - } - - customParserDest := filepath.Join(parserDirDest, customParserName) - // if path to parser exist, copy it - if err := Copy(customParserPath, customParserDest); err != nil { - continue - //return fmt.Errorf("unable to copy custom parser '%s' to '%s': %s", customParserPath, customParserDest, err) - } - - customParserExist = true - break - } - if !customParserExist { - return fmt.Errorf("couldn't find custom parser '%s' in the following location: %+v", parser, t.CustomItemsLocation) - } + if err := installFunc(name); err != nil { + return err } } - // install scenarios in runtime environment - for _, scenario := range t.Config.Scenarios { - if scenario == "" { - continue - } - var scenarioDirDest string - if hubScenario, ok := t.HubIndex.Data[cwhub.SCENARIOS][scenario]; ok { - scenarioSource, err := filepath.Abs(filepath.Join(t.HubPath, hubScenario.RemotePath)) - if err != nil { - return fmt.Errorf("can't get absolute path to: %s", scenarioSource) - } - scenarioFileName := filepath.Base(scenarioSource) - - // runtime/hub/scenarios/crowdsecurity/ - hubDirScenarioDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(hubScenario.RemotePath)) - - // runtime/parsers/scenarios/ - scenarioDirDest = fmt.Sprintf("%s/scenarios/", t.RuntimePath) - - if err := os.MkdirAll(hubDirScenarioDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", hubDirScenarioDest, err) - } - if err := os.MkdirAll(scenarioDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", scenarioDirDest, err) - } - - // runtime/hub/scenarios/crowdsecurity/ssh-bf.yaml - hubDirScenarioPath := filepath.Join(hubDirScenarioDest, scenarioFileName) - if err := Copy(scenarioSource, hubDirScenarioPath); err != nil { - return fmt.Errorf("unable to copy '%s' to '%s': %s", scenarioSource, hubDirScenarioPath, err) - } + return nil +} - // runtime/scenarios/ssh-bf.yaml - scenarioDirParserPath := filepath.Join(scenarioDirDest, scenarioFileName) - if err := os.Symlink(hubDirScenarioPath, scenarioDirParserPath); err != nil { - if !os.IsExist(err) { - return fmt.Errorf("unable to symlink scenario '%s' to '%s': %s", hubDirScenarioPath, scenarioDirParserPath, err) - } - } - } else { - customScenarioExist := false - for _, customPath := range t.CustomItemsLocation { - // we check if its a custom scenario - customScenarioPath := filepath.Join(customPath, scenario) - if _, err := os.Stat(customScenarioPath); os.IsNotExist(err) { - continue - //return fmt.Errorf("scenarios '%s' doesn't exist in the hub and doesn't appear to be a custom one.", scenario) - } - - scenarioDirDest = fmt.Sprintf("%s/scenarios/", t.RuntimePath) - if err := os.MkdirAll(scenarioDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", scenarioDirDest, err) - } - - scenarioFileName := filepath.Base(customScenarioPath) - scenarioFileDest := filepath.Join(scenarioDirDest, scenarioFileName) - if err := Copy(customScenarioPath, scenarioFileDest); err != nil { - continue - //return fmt.Errorf("unable to copy scenario from '%s' to '%s': %s", customScenarioPath, scenarioFileDest, err) - } - customScenarioExist = true - break - } - if !customScenarioExist { - return fmt.Errorf("couldn't find custom scenario '%s' in the following location: %+v", scenario, 
t.CustomItemsLocation) - } - } +func (t *HubTestItem) InstallHub() error { + if err := t.installHubItems(t.Config.Parsers, t.installParser); err != nil { + return err } - // install postoverflows in runtime environment - for _, postoverflow := range t.Config.PostOVerflows { - if postoverflow == "" { - continue - } - var postoverflowDirDest string - if hubPostOverflow, ok := t.HubIndex.Data[cwhub.PARSERS_OVFLW][postoverflow]; ok { - postoverflowSource, err := filepath.Abs(filepath.Join(t.HubPath, hubPostOverflow.RemotePath)) - if err != nil { - return fmt.Errorf("can't get absolute path of '%s': %s", postoverflowSource, err) - } - postoverflowFileName := filepath.Base(postoverflowSource) - - // runtime/hub/postoverflows/s00-enrich/crowdsecurity/ - hubDirPostoverflowDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(hubPostOverflow.RemotePath)) - - // runtime/postoverflows/s00-enrich - postoverflowDirDest = fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, hubPostOverflow.Stage) - - if err := os.MkdirAll(hubDirPostoverflowDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", hubDirPostoverflowDest, err) - } - if err := os.MkdirAll(postoverflowDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", postoverflowDirDest, err) - } + if err := t.installHubItems(t.Config.Scenarios, t.installScenario); err != nil { + return err + } - // runtime/hub/postoverflows/s00-enrich/crowdsecurity/rdns.yaml - hubDirPostoverflowPath := filepath.Join(hubDirPostoverflowDest, postoverflowFileName) - if err := Copy(postoverflowSource, hubDirPostoverflowPath); err != nil { - return fmt.Errorf("unable to copy '%s' to '%s': %s", postoverflowSource, hubDirPostoverflowPath, err) - } + if err := t.installHubItems(t.Config.PostOverflows, t.installPostoverflow); err != nil { + return err + } - // runtime/postoverflows/s00-enrich/rdns.yaml - postoverflowDirParserPath := filepath.Join(postoverflowDirDest, postoverflowFileName) - if err := os.Symlink(hubDirPostoverflowPath, postoverflowDirParserPath); err != nil { - if !os.IsExist(err) { - return fmt.Errorf("unable to symlink postoverflow '%s' to '%s': %s", hubDirPostoverflowPath, postoverflowDirParserPath, err) - } - } - } else { - customPostoverflowExist := false - for _, customPath := range t.CustomItemsLocation { - // we check if its a custom postoverflow - customPostOverflowPath := filepath.Join(customPath, postoverflow) - if _, err := os.Stat(customPostOverflowPath); os.IsNotExist(err) { - continue - //return fmt.Errorf("postoverflow '%s' doesn't exist in the hub and doesn't appear to be a custom one.", postoverflow) - } - - customPostOverflowPathSplit := strings.Split(customPostOverflowPath, "/") - customPostoverflowName := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-1] - // because path is postoverflows///parser.yaml and we wan't the stage - customPostoverflowStage := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-3] - - // check if stage exist - hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("postoverflows/%s", customPostoverflowStage)) - - if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { - continue - //return fmt.Errorf("stage '%s' from extracted '%s' doesn't exist in the hub", customPostoverflowStage, hubStagePath) - } - - postoverflowDirDest = fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, customPostoverflowStage) - if err := os.MkdirAll(postoverflowDirDest, os.ModePerm); err != nil { - continue - //return fmt.Errorf("unable to create folder 
'%s': %s", postoverflowDirDest, err) - } - - customPostoverflowDest := filepath.Join(postoverflowDirDest, customPostoverflowName) - // if path to postoverflow exist, copy it - if err := Copy(customPostOverflowPath, customPostoverflowDest); err != nil { - continue - //return fmt.Errorf("unable to copy custom parser '%s' to '%s': %s", customPostOverflowPath, customPostoverflowDest, err) - } - customPostoverflowExist = true - break - } - if !customPostoverflowExist { - return fmt.Errorf("couldn't find custom postoverflow '%s' in the following location: %+v", postoverflow, t.CustomItemsLocation) - } - } + if err := t.installHubItems(t.Config.AppsecRules, t.installAppsecRule); err != nil { + return err } if len(t.Config.OverrideStatics) > 0 { @@ -384,53 +198,55 @@ func (t *HubTestItem) InstallHub() error { Filter: "1==1", Statics: t.Config.OverrideStatics, } + b, err := yaml.Marshal(n) if err != nil { - return fmt.Errorf("unable to marshal overrides: %s", err) + return fmt.Errorf("unable to serialize overrides: %w", err) } + tgtFilename := fmt.Sprintf("%s/parsers/s00-raw/00_overrides.yaml", t.RuntimePath) if err := os.WriteFile(tgtFilename, b, os.ModePerm); err != nil { - return fmt.Errorf("unable to write overrides to '%s': %s", tgtFilename, err) + return fmt.Errorf("unable to write overrides to '%s': %w", tgtFilename, err) } } // load installed hub - err := cwhub.GetHubIdx(t.RuntimeHubConfig) + hub, err := cwhub.NewHub(t.RuntimeHubConfig, nil, nil) if err != nil { - log.Fatalf("can't local sync the hub: %+v", err) + return err + } + + if err := hub.Load(); err != nil { + return err } + ctx := context.Background() + // install data for parsers if needed - ret := cwhub.GetItemMap(cwhub.PARSERS) - for parserName, item := range ret { - if item.Installed { - if err := cwhub.DownloadDataIfNeeded(t.RuntimeHubConfig, item, true); err != nil { - return fmt.Errorf("unable to download data for parser '%s': %+v", parserName, err) - } - log.Debugf("parser '%s' installed successfully in runtime environment", parserName) + for _, item := range hub.GetInstalledByType(cwhub.PARSERS, true) { + if err := item.DownloadDataIfNeeded(ctx, true); err != nil { + return fmt.Errorf("unable to download data for parser '%s': %+v", item.Name, err) } + + log.Debugf("parser '%s' installed successfully in runtime environment", item.Name) } // install data for scenarios if needed - ret = cwhub.GetItemMap(cwhub.SCENARIOS) - for scenarioName, item := range ret { - if item.Installed { - if err := cwhub.DownloadDataIfNeeded(t.RuntimeHubConfig, item, true); err != nil { - return fmt.Errorf("unable to download data for parser '%s': %+v", scenarioName, err) - } - log.Debugf("scenario '%s' installed successfully in runtime environment", scenarioName) + for _, item := range hub.GetInstalledByType(cwhub.SCENARIOS, true) { + if err := item.DownloadDataIfNeeded(ctx, true); err != nil { + return fmt.Errorf("unable to download data for parser '%s': %+v", item.Name, err) } + + log.Debugf("scenario '%s' installed successfully in runtime environment", item.Name) } // install data for postoverflows if needed - ret = cwhub.GetItemMap(cwhub.PARSERS_OVFLW) - for postoverflowName, item := range ret { - if item.Installed { - if err := cwhub.DownloadDataIfNeeded(t.RuntimeHubConfig, item, true); err != nil { - return fmt.Errorf("unable to download data for parser '%s': %+v", postoverflowName, err) - } - log.Debugf("postoverflow '%s' installed successfully in runtime environment", postoverflowName) + for _, item := range 
hub.GetInstalledByType(cwhub.POSTOVERFLOWS, true) { + if err := item.DownloadDataIfNeeded(ctx, true); err != nil { + return fmt.Errorf("unable to download data for postoverflow '%s': %+v", item.Name, err) } + + log.Debugf("postoverflow '%s' installed successfully in runtime environment", item.Name) } return nil @@ -440,88 +256,199 @@ func (t *HubTestItem) Clean() error { return os.RemoveAll(t.RuntimePath) } -func (t *HubTestItem) Run() error { - t.Success = false - t.ErrorsList = make([]string, 0) +func (t *HubTestItem) RunWithNucleiTemplate() error { + crowdsecLogFile := fmt.Sprintf("%s/log/crowdsec.log", t.RuntimePath) testPath := filepath.Join(t.HubTestPath, t.Name) if _, err := os.Stat(testPath); os.IsNotExist(err) { return fmt.Errorf("test '%s' doesn't exist in '%s', exiting", t.Name, t.HubTestPath) } - currentDir, err := os.Getwd() + if err := os.Chdir(testPath); err != nil { + return fmt.Errorf("can't 'cd' to '%s': %w", testPath, err) + } + + // machine add + cmdArgs := []string{"-c", t.RuntimeConfigFilePath, "machines", "add", "testMachine", "--force", "--auto"} + cscliRegisterCmd := exec.Command(t.CscliPath, cmdArgs...) + + output, err := cscliRegisterCmd.CombinedOutput() if err != nil { - return fmt.Errorf("can't get current directory: %+v", err) + if !strings.Contains(string(output), "unable to create machine: user 'testMachine': user already exist") { + fmt.Println(string(output)) + return fmt.Errorf("fail to run '%s' for test '%s': %v", cscliRegisterCmd.String(), t.Name, err) + } } - // create runtime folder - if err := os.MkdirAll(t.RuntimePath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.RuntimePath, err) + // hardcode bouncer key + cmdArgs = []string{"-c", t.RuntimeConfigFilePath, "bouncers", "add", "appsectests", "-k", TestBouncerApiKey} + cscliBouncerCmd := exec.Command(t.CscliPath, cmdArgs...) + + output, err = cscliBouncerCmd.CombinedOutput() + if err != nil { + if !strings.Contains(string(output), "unable to create bouncer: bouncer appsectests already exists") { + fmt.Println(string(output)) + return fmt.Errorf("fail to run '%s' for test '%s': %v", cscliBouncerCmd.String(), t.Name, err) + } } - // create runtime data folder - if err := os.MkdirAll(t.RuntimeDataPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.RuntimeDataPath, err) + // start crowdsec service + cmdArgs = []string{"-c", t.RuntimeConfigFilePath} + crowdsecDaemon := exec.Command(t.CrowdSecPath, cmdArgs...)
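Once the daemon is started below, the test blocks on IsAlive until the AppSec listener at t.AppSecHost answers. IsAlive is called but not defined in this diff; a minimal sketch of such a readiness probe, assuming the argument is an addr:port string and that a successful TCP connect is the only criterion, might look like:

```go
package hubtest

import (
	"fmt"
	"net"
	"time"
)

// waitForPort is a hypothetical stand-in for IsAlive, which this diff
// calls but does not define. It retries a plain TCP connect until the
// target accepts a connection or the deadline expires.
func waitForPort(hostport string, timeout time.Duration) (bool, error) {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		conn, err := net.DialTimeout("tcp", hostport, time.Second)
		if err == nil {
			conn.Close()
			return true, nil // the service answered
		}

		time.Sleep(500 * time.Millisecond)
	}

	return false, fmt.Errorf("'%s' is not reachable after %s", hostport, timeout)
}
```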
+ + crowdsecDaemon.Start() + + // wait for the appsec port to be available + if _, err := IsAlive(t.AppSecHost); err != nil { + crowdsecLog, err2 := os.ReadFile(crowdsecLogFile) + if err2 != nil { + log.Errorf("unable to read crowdsec log file '%s': %s", crowdsecLogFile, err2) + } else { + log.Errorf("crowdsec log file '%s'", crowdsecLogFile) + log.Errorf("%s\n", string(crowdsecLog)) + } + + return fmt.Errorf("appsec is down: %w", err) } - // create runtime hub folder - if err := os.MkdirAll(t.RuntimeHubPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.RuntimeHubPath, err) + // check if the target is available + nucleiTargetParsedURL, err := url.Parse(t.NucleiTargetHost) + if err != nil { + return fmt.Errorf("unable to parse target '%s': %w", t.NucleiTargetHost, err) } - if err := Copy(t.HubIndexFile, filepath.Join(t.RuntimeHubPath, ".index.json")); err != nil { - return fmt.Errorf("unable to copy .index.json file in '%s': %s", filepath.Join(t.RuntimeHubPath, ".index.json"), err) + nucleiTargetHost := nucleiTargetParsedURL.Host + if _, err := IsAlive(nucleiTargetHost); err != nil { + return fmt.Errorf("target is down: %w", err) } - // create results folder - if err := os.MkdirAll(t.ResultsPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.ResultsPath, err) + nucleiConfig := NucleiConfig{ + Path: "nuclei", + OutputDir: t.RuntimePath, + CmdLineOptions: []string{ + "-ev", // allow variables from environment + "-nc", // no colors in output + "-dresp", // dump response + "-j", // json output + }, + } + + err = nucleiConfig.RunNucleiTemplate(t.Name, t.Config.NucleiTemplate, t.NucleiTargetHost) + if t.Config.ExpectedNucleiFailure { + if err != nil && errors.Is(err, ErrNucleiTemplateFail) { + log.Infof("Appsec test %s failed as expected", t.Name) + t.Success = true + } else { + log.Errorf("Appsec test %s failed: %s", t.Name, err) + + crowdsecLog, err := os.ReadFile(crowdsecLogFile) + if err != nil { + log.Errorf("unable to read crowdsec log file '%s': %s", crowdsecLogFile, err) + } else { + log.Errorf("crowdsec log file '%s'", crowdsecLogFile) + log.Errorf("%s\n", string(crowdsecLog)) + } + } + } else { + if err == nil { + log.Infof("Appsec test %s succeeded", t.Name) + t.Success = true + } else { + log.Errorf("Appsec test %s failed: %s", t.Name, err) + + crowdsecLog, err := os.ReadFile(crowdsecLogFile) + if err != nil { + log.Errorf("unable to read crowdsec log file '%s': %s", crowdsecLogFile, err) + } else { + log.Errorf("crowdsec log file '%s'", crowdsecLogFile) + log.Errorf("%s\n", string(crowdsecLog)) + } + } + } + + crowdsecDaemon.Process.Kill() + + return nil +} + +func createDirs(dirs []string) error { + for _, dir := range dirs { + if err := os.MkdirAll(dir, os.ModePerm); err != nil { + return fmt.Errorf("unable to create directory '%s': %w", dir, err) + } + } + + return nil +} + +func (t *HubTestItem) RunWithLogFile() error { + testPath := filepath.Join(t.HubTestPath, t.Name) + if _, err := os.Stat(testPath); os.IsNotExist(err) { + return fmt.Errorf("test '%s' doesn't exist in '%s', exiting", t.Name, t.HubTestPath) + } + + currentDir, err := os.Getwd() + if err != nil { + return fmt.Errorf("can't get current directory: %+v", err) + } + + // create runtime, data, hub folders + if err = createDirs([]string{t.RuntimePath, t.RuntimeDataPath, t.RuntimeHubPath, t.ResultsPath}); err != nil { + return err + } + + if err = Copy(t.HubIndexFile, filepath.Join(t.RuntimeHubPath, ".index.json")); err != nil { +
return fmt.Errorf("unable to copy .index.json file in '%s': %w", filepath.Join(t.RuntimeHubPath, ".index.json"), err) } // copy template config file to runtime folder - if err := Copy(t.TemplateConfigPath, t.RuntimeConfigFilePath); err != nil { + if err = Copy(t.TemplateConfigPath, t.RuntimeConfigFilePath); err != nil { return fmt.Errorf("unable to copy '%s' to '%s': %v", t.TemplateConfigPath, t.RuntimeConfigFilePath, err) } // copy template profile file to runtime folder - if err := Copy(t.TemplateProfilePath, t.RuntimeProfileFilePath); err != nil { + if err = Copy(t.TemplateProfilePath, t.RuntimeProfileFilePath); err != nil { return fmt.Errorf("unable to copy '%s' to '%s': %v", t.TemplateProfilePath, t.RuntimeProfileFilePath, err) } // copy template simulation file to runtime folder - if err := Copy(t.TemplateSimulationPath, t.RuntimeSimulationFilePath); err != nil { + if err = Copy(t.TemplateSimulationPath, t.RuntimeSimulationFilePath); err != nil { return fmt.Errorf("unable to copy '%s' to '%s': %v", t.TemplateSimulationPath, t.RuntimeSimulationFilePath, err) } + crowdsecPatternsFolder := csconfig.DefaultConfigPath("patterns") + // copy template patterns folder to runtime folder - if err := CopyDir(crowdsecPatternsFolder, t.RuntimePatternsPath); err != nil { - return fmt.Errorf("unable to copy 'patterns' from '%s' to '%s': %s", crowdsecPatternsFolder, t.RuntimePatternsPath, err) + if err = CopyDir(crowdsecPatternsFolder, t.RuntimePatternsPath); err != nil { + return fmt.Errorf("unable to copy 'patterns' from '%s' to '%s': %w", crowdsecPatternsFolder, t.RuntimePatternsPath, err) } // install the hub in the runtime folder - if err := t.InstallHub(); err != nil { - return fmt.Errorf("unable to install hub in '%s': %s", t.RuntimeHubPath, err) + if err = t.InstallHub(); err != nil { + return fmt.Errorf("unable to install hub in '%s': %w", t.RuntimeHubPath, err) } logFile := t.Config.LogFile logType := t.Config.LogType dsn := fmt.Sprintf("file://%s", logFile) - if err := os.Chdir(testPath); err != nil { - return fmt.Errorf("can't 'cd' to '%s': %s", testPath, err) + if err = os.Chdir(testPath); err != nil { + return fmt.Errorf("can't 'cd' to '%s': %w", testPath, err) } logFileStat, err := os.Stat(logFile) if err != nil { - return fmt.Errorf("unable to stat log file '%s': %s", logFile, err) + return fmt.Errorf("unable to stat log file '%s': %w", logFile, err) } + if logFileStat.Size() == 0 { - return fmt.Errorf("Log file '%s' is empty, please fill it with log", logFile) + return fmt.Errorf("log file '%s' is empty, please fill it with log", logFile) } - cmdArgs := []string{"-c", t.RuntimeConfigFilePath, "machines", "add", "testMachine", "--auto"} + cmdArgs := []string{"-c", t.RuntimeConfigFilePath, "machines", "add", "testMachine", "--force", "--auto"} cscliRegisterCmd := exec.Command(t.CscliPath, cmdArgs...) log.Debugf("%s", cscliRegisterCmd.String()) + output, err := cscliRegisterCmd.CombinedOutput() if err != nil { if !strings.Contains(string(output), "unable to create machine: user 'testMachine': user already exist") { @@ -531,22 +458,26 @@ func (t *HubTestItem) Run() error { } cmdArgs = []string{"-c", t.RuntimeConfigFilePath, "-type", logType, "-dsn", dsn, "-dump-data", t.ResultsPath, "-order-event"} + for labelKey, labelValue := range t.Config.Labels { arg := fmt.Sprintf("%s:%s", labelKey, labelValue) cmdArgs = append(cmdArgs, "-label", arg) } + crowdsecCmd := exec.Command(t.CrowdSecPath, cmdArgs...) 
log.Debugf("%s", crowdsecCmd.String()) output, err = crowdsecCmd.CombinedOutput() + if log.GetLevel() >= log.DebugLevel || err != nil { fmt.Println(string(output)) } + if err != nil { return fmt.Errorf("fail to run '%s' for test '%s': %v", crowdsecCmd.String(), t.Name, err) } if err := os.Chdir(currentDir); err != nil { - return fmt.Errorf("can't 'cd' to '%s': %s", currentDir, err) + return fmt.Errorf("can't 'cd' to '%s': %w", currentDir, err) } // assert parsers @@ -555,61 +486,70 @@ func (t *HubTestItem) Run() error { if os.IsNotExist(err) { parserAssertFile, err := os.Create(t.ParserAssert.File) if err != nil { - log.Fatal(err) + return err } + parserAssertFile.Close() } + assertFileStat, err := os.Stat(t.ParserAssert.File) if err != nil { - return fmt.Errorf("error while stats '%s': %s", t.ParserAssert.File, err) + return fmt.Errorf("error while stats '%s': %w", t.ParserAssert.File, err) } if assertFileStat.Size() == 0 { assertData, err := t.ParserAssert.AutoGenFromFile(t.ParserResultFile) if err != nil { - return fmt.Errorf("couldn't generate assertion: %s", err) + return fmt.Errorf("couldn't generate assertion: %w", err) } + t.ParserAssert.AutoGenAssertData = assertData t.ParserAssert.AutoGenAssert = true } else { if err := t.ParserAssert.AssertFile(t.ParserResultFile); err != nil { - return fmt.Errorf("unable to run assertion on file '%s': %s", t.ParserResultFile, err) + return fmt.Errorf("unable to run assertion on file '%s': %w", t.ParserResultFile, err) } } } // assert scenarios nbScenario := 0 + for _, scenario := range t.Config.Scenarios { if scenario == "" { continue } - nbScenario += 1 + + nbScenario++ } + if nbScenario > 0 { _, err := os.Stat(t.ScenarioAssert.File) if os.IsNotExist(err) { scenarioAssertFile, err := os.Create(t.ScenarioAssert.File) if err != nil { - log.Fatal(err) + return err } + scenarioAssertFile.Close() } + assertFileStat, err := os.Stat(t.ScenarioAssert.File) if err != nil { - return fmt.Errorf("error while stats '%s': %s", t.ScenarioAssert.File, err) + return fmt.Errorf("error while stats '%s': %w", t.ScenarioAssert.File, err) } if assertFileStat.Size() == 0 { assertData, err := t.ScenarioAssert.AutoGenFromFile(t.ScenarioResultFile) if err != nil { - return fmt.Errorf("couldn't generate assertion: %s", err) + return fmt.Errorf("couldn't generate assertion: %w", err) } + t.ScenarioAssert.AutoGenAssertData = assertData t.ScenarioAssert.AutoGenAssert = true } else { if err := t.ScenarioAssert.AssertFile(t.ScenarioResultFile); err != nil { - return fmt.Errorf("unable to run assertion on file '%s': %s", t.ScenarioResultFile, err) + return fmt.Errorf("unable to run assertion on file '%s': %w", t.ScenarioResultFile, err) } } } @@ -624,3 +564,78 @@ func (t *HubTestItem) Run() error { return nil } + +func (t *HubTestItem) Run() error { + var err error + + t.Success = false + t.ErrorsList = make([]string, 0) + + // create runtime, data, hub, result folders + if err = createDirs([]string{t.RuntimePath, t.RuntimeDataPath, t.RuntimeHubPath, t.ResultsPath}); err != nil { + return err + } + + if err = Copy(t.HubIndexFile, filepath.Join(t.RuntimeHubPath, ".index.json")); err != nil { + return fmt.Errorf("unable to copy .index.json file in '%s': %w", filepath.Join(t.RuntimeHubPath, ".index.json"), err) + } + + // copy template config file to runtime folder + if err = Copy(t.TemplateConfigPath, t.RuntimeConfigFilePath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %v", t.TemplateConfigPath, t.RuntimeConfigFilePath, err) + } + + // copy template profile 
file to runtime folder + if err = Copy(t.TemplateProfilePath, t.RuntimeProfileFilePath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %v", t.TemplateProfilePath, t.RuntimeProfileFilePath, err) + } + + // copy template simulation file to runtime folder + if err = Copy(t.TemplateSimulationPath, t.RuntimeSimulationFilePath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %v", t.TemplateSimulationPath, t.RuntimeSimulationFilePath, err) + } + + crowdsecPatternsFolder := csconfig.DefaultConfigPath("patterns") + + // copy template patterns folder to runtime folder + if err = CopyDir(crowdsecPatternsFolder, t.RuntimePatternsPath); err != nil { + return fmt.Errorf("unable to copy 'patterns' from '%s' to '%s': %w", crowdsecPatternsFolder, t.RuntimePatternsPath, err) + } + + // create the appsec-configs dir + if err = os.MkdirAll(filepath.Join(t.RuntimePath, "appsec-configs"), os.ModePerm); err != nil { + return fmt.Errorf("unable to create folder '%s': %+v", t.RuntimePath, err) + } + + // if it's an appsec rule test, we need acquis and appsec profile + if len(t.Config.AppsecRules) > 0 { + // copy template acquis file to runtime folder + log.Debugf("copying %s to %s", t.TemplateAcquisPath, t.RuntimeAcquisFilePath) + + if err = Copy(t.TemplateAcquisPath, t.RuntimeAcquisFilePath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %v", t.TemplateAcquisPath, t.RuntimeAcquisFilePath, err) + } + + log.Debugf("copying %s to %s", t.TemplateAppsecProfilePath, filepath.Join(t.RuntimePath, "appsec-configs", "config.yaml")) + // copy template appsec-config file to runtime folder + if err = Copy(t.TemplateAppsecProfilePath, filepath.Join(t.RuntimePath, "appsec-configs", "config.yaml")); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %v", t.TemplateAppsecProfilePath, filepath.Join(t.RuntimePath, "appsec-configs", "config.yaml"), err) + } + } else { // otherwise we drop a blank acquis file + if err = os.WriteFile(t.RuntimeAcquisFilePath, []byte(""), os.ModePerm); err != nil { + return fmt.Errorf("unable to write blank acquis file '%s': %w", t.RuntimeAcquisFilePath, err) + } + } + + // install the hub in the runtime folder + if err = t.InstallHub(); err != nil { + return fmt.Errorf("unable to install hub in '%s': %w", t.RuntimeHubPath, err) + } + + if t.Config.LogFile != "" { + return t.RunWithLogFile() + } else if t.Config.NucleiTemplate != "" { + return t.RunWithNucleiTemplate() + } + return fmt.Errorf("log file or nuclei template must be set in '%s'", t.Name) +} diff --git a/pkg/hubtest/nucleirunner.go b/pkg/hubtest/nucleirunner.go new file mode 100644 index 00000000000..32c81eb64d8 --- /dev/null +++ b/pkg/hubtest/nucleirunner.go @@ -0,0 +1,67 @@ +package hubtest + +import ( + "bytes" + "errors" + "fmt" + "os" + "os/exec" + "time" + + log "github.com/sirupsen/logrus" +) + +type NucleiConfig struct { + Path string `yaml:"nuclei_path"` + OutputDir string `yaml:"output_dir"` + CmdLineOptions []string `yaml:"cmdline_options"` +} + +var ErrNucleiTemplateFail = errors.New("nuclei template failed") + +func (nc *NucleiConfig) RunNucleiTemplate(testName string, templatePath string, target string) error { + tstamp := time.Now().Unix() + + outputPrefix := fmt.Sprintf("%s/%s-%d", nc.OutputDir, testName, tstamp) + // CVE-2023-34362_CVE-2023-34362-1702562399_stderr.txt + args := []string{ + "-u", target, + "-t", templatePath, + "-o", outputPrefix + ".json", + } + args = append(args, nc.CmdLineOptions...) + cmd := exec.Command(nc.Path, args...) 
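Before the rest of RunNucleiTemplate, note how its result is meant to be consumed: a run that produces no finding is reported through the package-level sentinel ErrNucleiTemplateFail declared above, so RunWithNucleiTemplate can separate an expected miss from a hard execution error with errors.Is. A minimal sketch of that calling convention (handleNucleiResult is hypothetical, not part of this patch):

```go
package hubtest

import (
	"errors"

	log "github.com/sirupsen/logrus"
)

// handleNucleiResult mirrors the branching used by RunWithNucleiTemplate:
// when a failure is expected (HubTestItemConfig.ExpectedNucleiFailure),
// only the sentinel error counts as success. errors.Is(nil, ...) is
// false, so a clean run does not pass an expect_failure test.
func handleNucleiResult(err error, expectFailure bool) bool {
	if expectFailure {
		return errors.Is(err, ErrNucleiTemplateFail)
	}

	if err != nil {
		log.Errorf("nuclei run failed: %s", err)
		return false
	}

	return true
}
```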
+ + log.Debugf("Running Nuclei command: '%s'", cmd.String()) + + var out bytes.Buffer + var outErr bytes.Buffer + + cmd.Stdout = &out + cmd.Stderr = &outErr + + err := cmd.Run() + + if err := os.WriteFile(outputPrefix+"_stdout.txt", out.Bytes(), 0o644); err != nil { + log.Warningf("Error writing stdout: %s", err) + } + + if err := os.WriteFile(outputPrefix+"_stderr.txt", outErr.Bytes(), 0o644); err != nil { + log.Warningf("Error writing stderr: %s", err) + } + + if err != nil { + log.Warningf("Error running nuclei: %s", err) + log.Warningf("Stdout saved to %s", outputPrefix+"_stdout.txt") + log.Warningf("Stderr saved to %s", outputPrefix+"_stderr.txt") + log.Warningf("Nuclei generated output saved to %s", outputPrefix+".json") + return err + } else if out.String() == "" { + log.Warningf("Stdout saved to %s", outputPrefix+"_stdout.txt") + log.Warningf("Stderr saved to %s", outputPrefix+"_stderr.txt") + log.Warningf("Nuclei generated output saved to %s", outputPrefix+".json") + //No stdout means no finding, it means our test failed + return ErrNucleiTemplateFail + } + return nil +} diff --git a/pkg/hubtest/parser.go b/pkg/hubtest/parser.go new file mode 100644 index 00000000000..31ff459ca77 --- /dev/null +++ b/pkg/hubtest/parser.go @@ -0,0 +1,100 @@ +package hubtest + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func (t *HubTestItem) installParserItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) + if err != nil { + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) + } + + sourceFilename := filepath.Base(sourcePath) + + // runtime/hub/parsers/s00-raw/crowdsecurity/ + hubDirParserDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) + + // runtime/parsers/s00-raw/ + itemTypeDirDest := fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, item.Stage) + + if err := createDirs([]string{hubDirParserDest, itemTypeDirDest}); err != nil { + return err + } + + // runtime/hub/parsers/s00-raw/crowdsecurity/syslog-logs.yaml + hubDirParserPath := filepath.Join(hubDirParserDest, sourceFilename) + if err := Copy(sourcePath, hubDirParserPath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirParserPath, err) + } + + // runtime/parsers/s00-raw/syslog-logs.yaml + parserDirParserPath := filepath.Join(itemTypeDirDest, sourceFilename) + if err := os.Symlink(hubDirParserPath, parserDirParserPath); err != nil { + if !os.IsExist(err) { + return fmt.Errorf("unable to symlink parser '%s' to '%s': %w", hubDirParserPath, parserDirParserPath, err) + } + } + + return nil +} + +func (t *HubTestItem) installParserCustomFrom(parser string, customPath string) (bool, error) { + // we check if its a custom parser + customParserPath := filepath.Join(customPath, parser) + if _, err := os.Stat(customParserPath); os.IsNotExist(err) { + return false, nil + } + + customParserPathSplit, customParserName := filepath.Split(customParserPath) + // because path is parsers///parser.yaml and we wan't the stage + splitPath := strings.Split(customParserPathSplit, string(os.PathSeparator)) + customParserStage := splitPath[len(splitPath)-3] + + // check if stage exist + hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("parsers/%s", customParserStage)) + if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { + return false, fmt.Errorf("stage '%s' extracted from '%s' doesn't exist in the hub", customParserStage, hubStagePath) + } + + stageDirDest := 
fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, customParserStage) + if err := os.MkdirAll(stageDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder '%s': %w", stageDirDest, err) + } + + customParserDest := filepath.Join(stageDirDest, customParserName) + // if path to parser exist, copy it + if err := Copy(customParserPath, customParserDest); err != nil { + return false, fmt.Errorf("unable to copy custom parser '%s' to '%s': %w", customParserPath, customParserDest, err) + } + + return true, nil +} + +func (t *HubTestItem) installParserCustom(parser string) error { + for _, customPath := range t.CustomItemsLocation { + found, err := t.installParserCustomFrom(parser, customPath) + if err != nil { + return err + } + + if found { + return nil + } + } + + return fmt.Errorf("couldn't find custom parser '%s' in the following locations: %+v", parser, t.CustomItemsLocation) +} + +func (t *HubTestItem) installParser(name string) error { + if item := t.HubIndex.GetItem(cwhub.PARSERS, name); item != nil { + return t.installParserItem(item) + } + + return t.installParserCustom(name) +} diff --git a/pkg/hubtest/parser_assert.go b/pkg/hubtest/parser_assert.go index 95400b50d1a..be4fdbdb5e6 100644 --- a/pkg/hubtest/parser_assert.go +++ b/pkg/hubtest/parser_assert.go @@ -2,24 +2,19 @@ package hubtest import ( "bufio" + "errors" "fmt" - "io" "os" - "regexp" - "sort" "strings" - "time" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" - "github.com/enescakir/emoji" - "github.com/fatih/color" - diff "github.com/r3labs/diff/v2" + "github.com/expr-lang/expr" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/pkg/dumps" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" - "github.com/crowdsecurity/crowdsec/pkg/types" ) type AssertFail struct { @@ -36,25 +31,19 @@ type ParserAssert struct { NbAssert int Fails []AssertFail Success bool - TestData *ParserResults -} - -type ParserResult struct { - Evt types.Event - Success bool + TestData *dumps.ParserResults } -type ParserResults map[string]map[string][]ParserResult func NewParserAssert(file string) *ParserAssert { - ParserAssert := &ParserAssert{ File: file, NbAssert: 0, Success: false, Fails: make([]AssertFail, 0), AutoGenAssert: false, - TestData: &ParserResults{}, + TestData: &dumps.ParserResults{}, } + return ParserAssert } @@ -63,80 +52,98 @@ func (p *ParserAssert) AutoGenFromFile(filename string) (string, error) { if err != nil { return "", err } + ret := p.AutoGenParserAssert() + return ret, nil } func (p *ParserAssert) LoadTest(filename string) error { - var err error - parserDump, err := LoadParserDump(filename) + parserDump, err := dumps.LoadParserDump(filename) if err != nil { return fmt.Errorf("loading parser dump file: %+v", err) } + p.TestData = parserDump + return nil } func (p *ParserAssert) AssertFile(testFile string) error { - file, err := os.Open(p.File) - if err != nil { - return fmt.Errorf("failed to open") + return errors.New("failed to open") } if err := p.LoadTest(testFile); err != nil { - return fmt.Errorf("unable to load parser dump file '%s': %s", testFile, err) + return fmt.Errorf("unable to load parser dump file '%s': %w", testFile, err) } + scanner := bufio.NewScanner(file) scanner.Split(bufio.ScanLines) + nbLine := 0 + for scanner.Scan() { - nbLine += 1 + nbLine++ + if scanner.Text() == "" { continue } + ok, err := p.Run(scanner.Text()) if err != nil { return fmt.Errorf("unable 
to run assert '%s': %+v", scanner.Text(), err) } - p.NbAssert += 1 + + p.NbAssert++ + if !ok { log.Debugf("%s is FALSE", scanner.Text()) - //fmt.SPrintf(" %s '%s'\n", emoji.RedSquare, scanner.Text()) failedAssert := &AssertFail{ File: p.File, Line: nbLine, Expression: scanner.Text(), Debug: make(map[string]string), } - variableRE := regexp.MustCompile(`(?P[^ =]+) == .*`) + match := variableRE.FindStringSubmatch(scanner.Text()) + + var variable string + if len(match) == 0 { log.Infof("Couldn't get variable of line '%s'", scanner.Text()) + variable = scanner.Text() + } else { + variable = match[1] } - variable := match[1] + result, err := p.EvalExpression(variable) if err != nil { log.Errorf("unable to evaluate variable '%s': %s", variable, err) continue } + failedAssert.Debug[variable] = result p.Fails = append(p.Fails, *failedAssert) + continue } - //fmt.Printf(" %s '%s'\n", emoji.GreenSquare, scanner.Text()) - + // fmt.Printf(" %s '%s'\n", emoji.GreenSquare, scanner.Text()) } + file.Close() + if p.NbAssert == 0 { assertData, err := p.AutoGenFromFile(testFile) if err != nil { - return fmt.Errorf("couldn't generate assertion: %s", err) + return fmt.Errorf("couldn't generate assertion: %w", err) } + p.AutoGenAssertData = assertData p.AutoGenAssert = true } + if len(p.Fails) == 0 { p.Success = true } @@ -145,27 +152,29 @@ func (p *ParserAssert) AssertFile(testFile string) error { } func (p *ParserAssert) RunExpression(expression string) (interface{}, error) { - var err error - //debug doesn't make much sense with the ability to evaluate "on the fly" - //var debugFilter *exprhelpers.ExprDebugger - var runtimeFilter *vm.Program + // debug doesn't make much sense with the ability to evaluate "on the fly" + // var debugFilter *exprhelpers.ExprDebugger var output interface{} env := map[string]interface{}{"results": *p.TestData} - if runtimeFilter, err = expr.Compile(expression, exprhelpers.GetExprOptions(env)...); err != nil { + runtimeFilter, err := expr.Compile(expression, exprhelpers.GetExprOptions(env)...) 
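Both assert engines feed each line of an .assert file through expr-lang/expr, compiling it against an env that exposes the dump as results; exprhelpers.GetExprOptions bundles expr.Env with CrowdSec's helper functions, but plain expr.Env is enough for a sketch. A self-contained example of the same compile-and-run cycle, using a toy map in place of the parser dump:

```go
package main

import (
	"fmt"

	"github.com/expr-lang/expr"
)

func main() {
	// Toy stand-in for the parser dump: stage -> parser -> per-line results.
	env := map[string]interface{}{
		"results": map[string]map[string][]bool{
			"s00-raw": {"crowdsecurity/syslog-logs": {true}},
		},
	}

	// Each assert line is compiled once...
	program, err := expr.Compile(`len(results["s00-raw"]["crowdsecurity/syslog-logs"]) == 1`, expr.Env(env))
	if err != nil {
		panic(err)
	}

	// ...then evaluated against the same env; a boolean result decides pass/fail.
	output, err := expr.Run(program, env)
	if err != nil {
		panic(err)
	}

	fmt.Println(output) // true
}
```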
+ if err != nil { + log.Errorf("failed to compile '%s' : %s", expression, err) return output, err } - //dump opcode in trace level + // dump opcode in trace level log.Tracef("%s", runtimeFilter.Disassemble()) - output, err = expr.Run(runtimeFilter, map[string]interface{}{"results": *p.TestData}) + output, err = expr.Run(runtimeFilter, env) if err != nil { log.Warningf("running : %s", expression) log.Warningf("runtime error : %s", err) + return output, fmt.Errorf("while running expression %s: %w", expression, err) } + return output, nil } @@ -174,10 +183,12 @@ func (p *ParserAssert) EvalExpression(expression string) (string, error) { if err != nil { return "", err } + ret, err := yaml.Marshal(output) if err != nil { return "", err } + return string(ret), nil } @@ -186,6 +197,7 @@ func (p *ParserAssert) Run(assert string) (bool, error) { if err != nil { return false, err } + switch out := output.(type) { case bool: return out, nil @@ -197,72 +209,89 @@ func (p *ParserAssert) Run(assert string) (bool, error) { func Escape(val string) string { val = strings.ReplaceAll(val, `\`, `\\`) val = strings.ReplaceAll(val, `"`, `\"`) + return val } func (p *ParserAssert) AutoGenParserAssert() string { - //attempt to autogen parser asserts - var ret string + // attempt to autogen parser asserts + ret := fmt.Sprintf("len(results) == %d\n", len(*p.TestData)) + + // sort map keys for consistent order + stages := maptools.SortedKeys(*p.TestData) - //sort map keys for consistent ordre - var stages []string - for stage := range *p.TestData { - stages = append(stages, stage) - } - sort.Strings(stages) - ret += fmt.Sprintf("len(results) == %d\n", len(*p.TestData)) for _, stage := range stages { parsers := (*p.TestData)[stage] - //sort map keys for consistent ordre - var pnames []string - for pname := range parsers { - pnames = append(pnames, pname) - } - sort.Strings(pnames) + + // sort map keys for consistent order + pnames := maptools.SortedKeys(parsers) + for _, parser := range pnames { presults := parsers[parser] ret += fmt.Sprintf(`len(results["%s"]["%s"]) == %d`+"\n", stage, parser, len(presults)) + for pidx, result := range presults { ret += fmt.Sprintf(`results["%s"]["%s"][%d].Success == %t`+"\n", stage, parser, pidx, result.Success) if !result.Success { continue } - for pkey, pval := range result.Evt.Parsed { + + for _, pkey := range maptools.SortedKeys(result.Evt.Parsed) { + pval := result.Evt.Parsed[pkey] if pval == "" { continue } + ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Parsed["%s"] == "%s"`+"\n", stage, parser, pidx, pkey, Escape(pval)) } - for mkey, mval := range result.Evt.Meta { + + for _, mkey := range maptools.SortedKeys(result.Evt.Meta) { + mval := result.Evt.Meta[mkey] if mval == "" { continue } + ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Meta["%s"] == "%s"`+"\n", stage, parser, pidx, mkey, Escape(mval)) } - for ekey, eval := range result.Evt.Enriched { + + for _, ekey := range maptools.SortedKeys(result.Evt.Enriched) { + eval := result.Evt.Enriched[ekey] if eval == "" { continue } + ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Enriched["%s"] == "%s"`+"\n", stage, parser, pidx, ekey, Escape(eval)) } - for ekey, eval := range result.Evt.Unmarshaled { - if eval == "" { + + for _, ukey := range maptools.SortedKeys(result.Evt.Unmarshaled) { + uval := result.Evt.Unmarshaled[ukey] + if uval == "" { continue } - base := fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Unmarshaled["%s"]`, stage, parser, pidx, ekey) - for _, line := range p.buildUnmarshaledAssert("", eval) { - ret += base + 
line + + base := fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Unmarshaled["%s"]`, stage, parser, pidx, ukey) + + for _, line := range p.buildUnmarshaledAssert(base, uval) { + ret += line } } + + ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Whitelisted == %t`+"\n", stage, parser, pidx, result.Evt.Whitelisted) + + if result.Evt.WhitelistReason != "" { + ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.WhitelistReason == "%s"`+"\n", stage, parser, pidx, Escape(result.Evt.WhitelistReason)) + } } } } + return ret } func (p *ParserAssert) buildUnmarshaledAssert(ekey string, eval interface{}) []string { ret := make([]string, 0) + switch val := eval.(type) { case map[string]interface{}: for k, v := range val { @@ -280,250 +309,11 @@ func (p *ParserAssert) buildUnmarshaledAssert(ekey string, eval interface{}) []s case int: ret = append(ret, fmt.Sprintf(`%s == %d`+"\n", ekey, val)) case float64: - ret = append(ret, fmt.Sprintf(`%s == %f`+"\n", ekey, val)) + ret = append(ret, fmt.Sprintf(`FloatApproxEqual(%s, %f)`+"\n", + ekey, val)) default: log.Warningf("unknown type '%T' for key '%s'", val, ekey) } - return ret -} - -func LoadParserDump(filepath string) (*ParserResults, error) { - var pdump ParserResults - - dumpData, err := os.Open(filepath) - if err != nil { - return nil, err - } - defer dumpData.Close() - - results, err := io.ReadAll(dumpData) - if err != nil { - return nil, err - } - if err := yaml.Unmarshal(results, &pdump); err != nil { - return nil, err - } - - /* we know that some variables should always be set, - let's check if they're present in last parser output of last stage */ - stages := make([]string, 0, len(pdump)) - for k := range pdump { - stages = append(stages, k) - } - sort.Strings(stages) - /*the very last one is set to 'success' which is just a bool indicating if the line was successfully parsed*/ - lastStage := stages[len(stages)-2] - - parsers := make([]string, 0, len(pdump[lastStage])) - for k := range pdump[lastStage] { - parsers = append(parsers, k) - } - sort.Strings(parsers) - lastParser := parsers[len(parsers)-1] - - for idx, result := range pdump[lastStage][lastParser] { - if result.Evt.StrTime == "" { - log.Warningf("Line %d/%d is missing evt.StrTime. 
It is most likely a mistake as it will prevent your logs to be processed in time-machine/forensic mode.", idx, len(pdump[lastStage][lastParser])) - } else { - log.Debugf("Line %d/%d has evt.StrTime set to '%s'", idx, len(pdump[lastStage][lastParser]), result.Evt.StrTime) - } - } - - return &pdump, nil -} - -type DumpOpts struct { - Details bool - SkipOk bool - ShowNotOkParsers bool -} - -func DumpTree(parser_results ParserResults, bucket_pour BucketPourInfo, opts DumpOpts) { - //note : we can use line -> time as the unique identifier (of acquisition) - - state := make(map[time.Time]map[string]map[string]ParserResult) - assoc := make(map[time.Time]string, 0) - - for stage, parsers := range parser_results { - for parser, results := range parsers { - for _, parser_res := range results { - evt := parser_res.Evt - if _, ok := state[evt.Line.Time]; !ok { - state[evt.Line.Time] = make(map[string]map[string]ParserResult) - assoc[evt.Line.Time] = evt.Line.Raw - } - if _, ok := state[evt.Line.Time][stage]; !ok { - state[evt.Line.Time][stage] = make(map[string]ParserResult) - } - state[evt.Line.Time][stage][parser] = ParserResult{Evt: evt, Success: parser_res.Success} - } - - } - } - - for bname, evtlist := range bucket_pour { - for _, evt := range evtlist { - if evt.Line.Raw == "" { - continue - } - //it might be bucket overflow being reprocessed, skip this - if _, ok := state[evt.Line.Time]; !ok { - state[evt.Line.Time] = make(map[string]map[string]ParserResult) - assoc[evt.Line.Time] = evt.Line.Raw - } - //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase - //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered - if _, ok := state[evt.Line.Time]["buckets"]; !ok { - state[evt.Line.Time]["buckets"] = make(map[string]ParserResult) - } - state[evt.Line.Time]["buckets"][bname] = ParserResult{Success: true} - } - } - yellow := color.New(color.FgYellow).SprintFunc() - red := color.New(color.FgRed).SprintFunc() - green := color.New(color.FgGreen).SprintFunc() - whitelistReason := "" - //get each line - for tstamp, rawstr := range assoc { - if opts.SkipOk { - if _, ok := state[tstamp]["buckets"]["OK"]; ok { - continue - } - } - fmt.Printf("line: %s\n", rawstr) - skeys := make([]string, 0, len(state[tstamp])) - for k := range state[tstamp] { - //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase - //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered - if k == "buckets" { - continue - } - skeys = append(skeys, k) - } - sort.Strings(skeys) - //iterate stage - var prev_item types.Event - - for _, stage := range skeys { - parsers := state[tstamp][stage] - - sep := "├" - presep := "|" - - fmt.Printf("\t%s %s\n", sep, stage) - - pkeys := make([]string, 0, len(parsers)) - for k := range parsers { - pkeys = append(pkeys, k) - } - sort.Strings(pkeys) - - for idx, parser := range pkeys { - res := parsers[parser].Success - sep := "├" - if idx == len(pkeys)-1 { - sep = "└" - } - created := 0 - updated := 0 - deleted := 0 - whitelisted := false - changeStr := "" - detailsDisplay := "" - - if res { - changelog, _ := diff.Diff(prev_item, parsers[parser].Evt) - for _, change := range changelog { - switch change.Type { - case "create": - created++ - detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), green(change.To)) - case "update": - detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s 
evt.%s : %s -> %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), change.From, yellow(change.To)) - if change.Path[0] == "Whitelisted" && change.To == true { - whitelisted = true - if whitelistReason == "" { - whitelistReason = parsers[parser].Evt.WhitelistReason - } - } - updated++ - case "delete": - deleted++ - detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s\n", presep, sep, change.Type, red(strings.Join(change.Path, "."))) - } - } - prev_item = parsers[parser].Evt - } - - if created > 0 { - changeStr += green(fmt.Sprintf("+%d", created)) - } - if updated > 0 { - if len(changeStr) > 0 { - changeStr += " " - } - changeStr += yellow(fmt.Sprintf("~%d", updated)) - } - if deleted > 0 { - if len(changeStr) > 0 { - changeStr += " " - } - changeStr += red(fmt.Sprintf("-%d", deleted)) - } - if whitelisted { - if len(changeStr) > 0 { - changeStr += " " - } - changeStr += red("[whitelisted]") - } - if changeStr == "" { - changeStr = yellow("unchanged") - } - if res { - fmt.Printf("\t%s\t%s %s %s (%s)\n", presep, sep, emoji.GreenCircle, parser, changeStr) - if opts.Details { - fmt.Print(detailsDisplay) - } - } else if opts.ShowNotOkParsers { - fmt.Printf("\t%s\t%s %s %s\n", presep, sep, emoji.RedCircle, parser) - - } - } - } - sep := "└" - if len(state[tstamp]["buckets"]) > 0 { - sep = "├" - } - //did the event enter the bucket pour phase ? - if _, ok := state[tstamp]["buckets"]["OK"]; ok { - fmt.Printf("\t%s-------- parser success %s\n", sep, emoji.GreenCircle) - } else if whitelistReason != "" { - fmt.Printf("\t%s-------- parser success, ignored by whitelist (%s) %s\n", sep, whitelistReason, emoji.GreenCircle) - } else { - fmt.Printf("\t%s-------- parser failure %s\n", sep, emoji.RedCircle) - } - //now print bucket info - if len(state[tstamp]["buckets"]) > 0 { - fmt.Printf("\t├ Scenarios\n") - } - bnames := make([]string, 0, len(state[tstamp]["buckets"])) - for k := range state[tstamp]["buckets"] { - //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase - //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered - if k == "OK" { - continue - } - bnames = append(bnames, k) - } - sort.Strings(bnames) - for idx, bname := range bnames { - sep := "├" - if idx == len(bnames)-1 { - sep = "└" - } - fmt.Printf("\t\t%s %s %s\n", sep, emoji.GreenCircle, bname) - } - fmt.Println() - } + return ret } diff --git a/pkg/hubtest/postoverflow.go b/pkg/hubtest/postoverflow.go new file mode 100644 index 00000000000..65fd0bfbc5d --- /dev/null +++ b/pkg/hubtest/postoverflow.go @@ -0,0 +1,100 @@ +package hubtest + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func (t *HubTestItem) installPostoverflowItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) + if err != nil { + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) + } + + sourceFilename := filepath.Base(sourcePath) + + // runtime/hub/postoverflows/s00-enrich/crowdsecurity/ + hubDirPostoverflowDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) + + // runtime/postoverflows/s00-enrich + itemTypeDirDest := fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, item.Stage) + + if err := createDirs([]string{hubDirPostoverflowDest, itemTypeDirDest}); err != nil { + return err + } + + // runtime/hub/postoverflows/s00-enrich/crowdsecurity/rdns.yaml + hubDirPostoverflowPath := filepath.Join(hubDirPostoverflowDest, 
sourceFilename) + if err := Copy(sourcePath, hubDirPostoverflowPath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirPostoverflowPath, err) + } + + // runtime/postoverflows/s00-enrich/rdns.yaml + postoverflowDirParserPath := filepath.Join(itemTypeDirDest, sourceFilename) + if err := os.Symlink(hubDirPostoverflowPath, postoverflowDirParserPath); err != nil { + if !os.IsExist(err) { + return fmt.Errorf("unable to symlink postoverflow '%s' to '%s': %w", hubDirPostoverflowPath, postoverflowDirParserPath, err) + } + } + + return nil +} + +func (t *HubTestItem) installPostoverflowCustomFrom(postoverflow string, customPath string) (bool, error) { + // we check if it's a custom postoverflow + customPostOverflowPath := filepath.Join(customPath, postoverflow) + if _, err := os.Stat(customPostOverflowPath); os.IsNotExist(err) { + return false, nil + } + + customPostOverflowPathSplit := strings.Split(customPostOverflowPath, "/") + customPostoverflowName := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-1] + // because the path is postoverflows/<stage>/<author>/parser.yaml and we want the stage + customPostoverflowStage := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-3] + + // check if the stage exists + hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("postoverflows/%s", customPostoverflowStage)) + if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { + return false, fmt.Errorf("stage '%s' extracted from '%s' doesn't exist in the hub", customPostoverflowStage, hubStagePath) + } + + stageDirDest := fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, customPostoverflowStage) + if err := os.MkdirAll(stageDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder '%s': %w", stageDirDest, err) + } + + customPostoverflowDest := filepath.Join(stageDirDest, customPostoverflowName) + // if the path to the postoverflow exists, copy it + if err := Copy(customPostOverflowPath, customPostoverflowDest); err != nil { + return false, fmt.Errorf("unable to copy custom postoverflow '%s' to '%s': %w", customPostOverflowPath, customPostoverflowDest, err) + } + + return true, nil +} + +func (t *HubTestItem) installPostoverflowCustom(postoverflow string) error { + for _, customPath := range t.CustomItemsLocation { + found, err := t.installPostoverflowCustomFrom(postoverflow, customPath) + if err != nil { + return err + } + + if found { + return nil + } + } + + return fmt.Errorf("couldn't find custom postoverflow '%s' in the following locations: %+v", postoverflow, t.CustomItemsLocation) +} + +func (t *HubTestItem) installPostoverflow(name string) error { + if hubPostOverflow := t.HubIndex.GetItem(cwhub.POSTOVERFLOWS, name); hubPostOverflow != nil { + return t.installPostoverflowItem(hubPostOverflow) + } + + return t.installPostoverflowCustom(name) +} diff --git a/pkg/hubtest/regexp.go b/pkg/hubtest/regexp.go new file mode 100644 index 00000000000..8b2fcc928dd --- /dev/null +++ b/pkg/hubtest/regexp.go @@ -0,0 +1,11 @@ +package hubtest + +import ( + "regexp" +) + +var ( + variableRE = regexp.MustCompile(`(?P<variable>[^ =]+) == .*`) + parserResultRE = regexp.MustCompile(`^results\["[^"]+"\]\["(?P<parser>[^"]+)"\]\[[0-9]+\]\.Evt\..*`) + scenarioResultRE = regexp.MustCompile(`^results\[[0-9]+\].Overflow.Alert.GetScenario\(\) == "(?P<scenario>[^"]+)"`) +) diff --git a/pkg/hubtest/scenario.go b/pkg/hubtest/scenario.go new file mode 100644 index 00000000000..7f61e48accf --- /dev/null +++ b/pkg/hubtest/scenario.go @@ -0,0 +1,89 @@ +package hubtest + +import ( + "fmt" + "os" + "path/filepath" + +
"github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func (t *HubTestItem) installScenarioItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) + if err != nil { + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) + } + + sourceFilename := filepath.Base(sourcePath) + + // runtime/hub/scenarios/crowdsecurity/ + hubDirScenarioDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) + + // runtime/parsers/scenarios/ + itemTypeDirDest := fmt.Sprintf("%s/scenarios/", t.RuntimePath) + + if err := createDirs([]string{hubDirScenarioDest, itemTypeDirDest}); err != nil { + return err + } + + // runtime/hub/scenarios/crowdsecurity/ssh-bf.yaml + hubDirScenarioPath := filepath.Join(hubDirScenarioDest, sourceFilename) + if err := Copy(sourcePath, hubDirScenarioPath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirScenarioPath, err) + } + + // runtime/scenarios/ssh-bf.yaml + scenarioDirParserPath := filepath.Join(itemTypeDirDest, sourceFilename) + if err := os.Symlink(hubDirScenarioPath, scenarioDirParserPath); err != nil { + if !os.IsExist(err) { + return fmt.Errorf("unable to symlink scenario '%s' to '%s': %w", hubDirScenarioPath, scenarioDirParserPath, err) + } + } + + return nil +} + +func (t *HubTestItem) installScenarioCustomFrom(scenario string, customPath string) (bool, error) { + // we check if its a custom scenario + customScenarioPath := filepath.Join(customPath, scenario) + if _, err := os.Stat(customScenarioPath); os.IsNotExist(err) { + return false, nil + } + + itemTypeDirDest := fmt.Sprintf("%s/scenarios/", t.RuntimePath) + if err := os.MkdirAll(itemTypeDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder '%s': %w", itemTypeDirDest, err) + } + + scenarioFileName := filepath.Base(customScenarioPath) + + scenarioFileDest := filepath.Join(itemTypeDirDest, scenarioFileName) + if err := Copy(customScenarioPath, scenarioFileDest); err != nil { + return false, fmt.Errorf("unable to copy scenario from '%s' to '%s': %w", customScenarioPath, scenarioFileDest, err) + } + + return true, nil +} + +func (t *HubTestItem) installScenarioCustom(scenario string) error { + for _, customPath := range t.CustomItemsLocation { + found, err := t.installScenarioCustomFrom(scenario, customPath) + if err != nil { + return err + } + + if found { + return nil + } + } + + return fmt.Errorf("couldn't find custom scenario '%s' in the following location: %+v", scenario, t.CustomItemsLocation) +} + +func (t *HubTestItem) installScenario(name string) error { + if item := t.HubIndex.GetItem(cwhub.SCENARIOS, name); item != nil { + return t.installScenarioItem(item) + } + + return t.installScenarioCustom(name) +} diff --git a/pkg/hubtest/scenario_assert.go b/pkg/hubtest/scenario_assert.go index 2e2a4e9c8be..f32abf9e110 100644 --- a/pkg/hubtest/scenario_assert.go +++ b/pkg/hubtest/scenario_assert.go @@ -2,18 +2,18 @@ package hubtest import ( "bufio" + "errors" "fmt" "io" "os" - "regexp" "sort" "strings" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/pkg/dumps" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -26,11 +26,10 @@ type ScenarioAssert struct { Fails []AssertFail Success bool TestData *BucketResults - PourData *BucketPourInfo + PourData 
*dumps.BucketPourInfo } type BucketResults []types.Event -type BucketPourInfo map[string][]types.Event func NewScenarioAssert(file string) *ScenarioAssert { ScenarioAssert := &ScenarioAssert{ @@ -40,8 +39,9 @@ func NewScenarioAssert(file string) *ScenarioAssert { Fails: make([]AssertFail, 0), AutoGenAssert: false, TestData: &BucketResults{}, - PourData: &BucketPourInfo{}, + PourData: &dumps.BucketPourInfo{}, } + return ScenarioAssert } @@ -50,51 +50,61 @@ func (s *ScenarioAssert) AutoGenFromFile(filename string) (string, error) { if err != nil { return "", err } + ret := s.AutoGenScenarioAssert() + return ret, nil } func (s *ScenarioAssert) LoadTest(filename string, bucketpour string) error { - var err error bucketDump, err := LoadScenarioDump(filename) if err != nil { return fmt.Errorf("loading scenario dump file '%s': %+v", filename, err) } + s.TestData = bucketDump if bucketpour != "" { - pourDump, err := LoadBucketPourDump(bucketpour) + pourDump, err := dumps.LoadBucketPourDump(bucketpour) if err != nil { return fmt.Errorf("loading bucket pour dump file '%s': %+v", bucketpour, err) } + s.PourData = pourDump } + return nil } func (s *ScenarioAssert) AssertFile(testFile string) error { file, err := os.Open(s.File) - if err != nil { - return fmt.Errorf("failed to open") + return errors.New("failed to open") } if err := s.LoadTest(testFile, ""); err != nil { - return fmt.Errorf("unable to load parser dump file '%s': %s", testFile, err) + return fmt.Errorf("unable to load parser dump file '%s': %w", testFile, err) } + scanner := bufio.NewScanner(file) scanner.Split(bufio.ScanLines) + nbLine := 0 + for scanner.Scan() { - nbLine += 1 + nbLine++ + if scanner.Text() == "" { continue } + ok, err := s.Run(scanner.Text()) if err != nil { return fmt.Errorf("unable to run assert '%s': %+v", scanner.Text(), err) } - s.NbAssert += 1 + + s.NbAssert++ + if !ok { log.Debugf("%s is FALSE", scanner.Text()) failedAssert := &AssertFail{ @@ -103,31 +113,38 @@ func (s *ScenarioAssert) AssertFile(testFile string) error { Expression: scanner.Text(), Debug: make(map[string]string), } - variableRE := regexp.MustCompile(`(?P<variable>[^ ]+) == .*`) + match := variableRE.FindStringSubmatch(scanner.Text()) + if len(match) == 0 { log.Infof("Couldn't get variable of line '%s'", scanner.Text()) continue } + variable := match[1] + result, err := s.EvalExpression(variable) if err != nil { log.Errorf("unable to evaluate variable '%s': %s", variable, err) continue } + failedAssert.Debug[variable] = result s.Fails = append(s.Fails, *failedAssert) + continue } - //fmt.Printf(" %s '%s'\n", emoji.GreenSquare, scanner.Text()) - + // fmt.Printf(" %s '%s'\n", emoji.GreenSquare, scanner.Text()) } + file.Close() + if s.NbAssert == 0 { assertData, err := s.AutoGenFromFile(testFile) if err != nil { - return fmt.Errorf("couldn't generate assertion: %s", err) + return fmt.Errorf("couldn't generate assertion: %w", err) } + s.AutoGenAssertData = assertData s.AutoGenAssert = true } @@ -140,30 +157,31 @@ func (s *ScenarioAssert) AssertFile(testFile string) error { } func (s *ScenarioAssert) RunExpression(expression string) (interface{}, error) { - var err error - //debug doesn't make much sense with the ability to evaluate "on the fly" - //var debugFilter *exprhelpers.ExprDebugger - var runtimeFilter *vm.Program + // debug doesn't make much sense with the ability to evaluate "on the fly" + // var debugFilter *exprhelpers.ExprDebugger var output interface{} env := map[string]interface{}{"results": *s.TestData} - if runtimeFilter, err = 
expr.Compile(expression, exprhelpers.GetExprOptions(env)...); err != nil { + runtimeFilter, err := expr.Compile(expression, exprhelpers.GetExprOptions(env)...) + if err != nil { return nil, err } // if debugFilter, err = exprhelpers.NewDebugger(assert, expr.Env(env)); err != nil { // log.Warningf("Failed building debugher for %s : %s", assert, err) // } - //dump opcode in trace level + // dump opcode in trace level log.Tracef("%s", runtimeFilter.Disassemble()) output, err = expr.Run(runtimeFilter, map[string]interface{}{"results": *s.TestData}) if err != nil { log.Warningf("running : %s", expression) log.Warningf("runtime error : %s", err) + return nil, fmt.Errorf("while running expression %s: %w", expression, err) } + return output, nil } @@ -172,10 +190,12 @@ func (s *ScenarioAssert) EvalExpression(expression string) (string, error) { if err != nil { return "", err } + ret, err := yaml.Marshal(output) if err != nil { return "", err } + return string(ret), nil } @@ -184,6 +204,7 @@ func (s *ScenarioAssert) Run(assert string) (bool, error) { if err != nil { return false, err } + switch out := output.(type) { case bool: return out, nil @@ -193,9 +214,9 @@ func (s *ScenarioAssert) Run(assert string) (bool, error) { } func (s *ScenarioAssert) AutoGenScenarioAssert() string { - //attempt to autogen parser asserts - var ret string - ret += fmt.Sprintf(`len(results) == %d`+"\n", len(*s.TestData)) + // attempt to autogen scenario asserts + ret := fmt.Sprintf(`len(results) == %d`+"\n", len(*s.TestData)) + for eventIndex, event := range *s.TestData { for ipSrc, source := range event.Overflow.Sources { ret += fmt.Sprintf(`"%s" in results[%d].Overflow.GetSources()`+"\n", ipSrc, eventIndex) @@ -204,15 +225,18 @@ func (s *ScenarioAssert) AutoGenScenarioAssert() string { ret += fmt.Sprintf(`results[%d].Overflow.Sources["%s"].GetScope() == "%s"`+"\n", eventIndex, ipSrc, *source.Scope) ret += fmt.Sprintf(`results[%d].Overflow.Sources["%s"].GetValue() == "%s"`+"\n", eventIndex, ipSrc, *source.Value) } + for evtIndex, evt := range event.Overflow.Alert.Events { for _, meta := range evt.Meta { - ret += fmt.Sprintf(`results[%d].Overflow.Alert.Events[%d].GetMeta("%s") == "%s"`+"\n", eventIndex, evtIndex, meta.Key, meta.Value) + ret += fmt.Sprintf(`results[%d].Overflow.Alert.Events[%d].GetMeta("%s") == "%s"`+"\n", eventIndex, evtIndex, meta.Key, Escape(meta.Value)) } } + ret += fmt.Sprintf(`results[%d].Overflow.Alert.GetScenario() == "%s"`+"\n", eventIndex, *event.Overflow.Alert.Scenario) ret += fmt.Sprintf(`results[%d].Overflow.Alert.Remediation == %t`+"\n", eventIndex, event.Overflow.Alert.Remediation) ret += fmt.Sprintf(`results[%d].Overflow.Alert.GetEventsCount() == %d`+"\n", eventIndex, *event.Overflow.Alert.EventsCount) } + return ret } @@ -228,9 +252,7 @@ func (b BucketResults) Swap(i, j int) { b[i], b[j] = b[j], b[i] } -func LoadBucketPourDump(filepath string) (*BucketPourInfo, error) { - var bucketDump BucketPourInfo - +func LoadScenarioDump(filepath string) (*BucketResults, error) { dumpData, err := os.Open(filepath) if err != nil { return nil, err @@ -242,27 +264,8 @@ func LoadBucketPourDump(filepath string) (*BucketPourInfo, error) { return nil, err } - if err := yaml.Unmarshal(results, &bucketDump); err != nil { - return nil, err - } - - return &bucketDump, nil -} - -func LoadScenarioDump(filepath string) (*BucketResults, error) { var bucketDump BucketResults - dumpData, err := os.Open(filepath) - if err != nil { - return nil, err - } - defer dumpData.Close() - - results, err := io.ReadAll(dumpData) 
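Editor's note for readers following the assert flow in this patch: EvalExpression is RunExpression plus a YAML dump of the result, which is what populates the Debug map on failed asserts. A minimal, self-contained sketch of that pattern (evalToString is a hypothetical name, not the crowdsec API):

```go
package main

import (
	"fmt"

	"github.com/expr-lang/expr"
	"gopkg.in/yaml.v3"
)

// evalToString runs an expression against an env, then YAML-marshals the
// result so it can be printed next to the failed assertion.
func evalToString(expression string, env map[string]interface{}) (string, error) {
	prog, err := expr.Compile(expression) // no Env option: dynamic mode
	if err != nil {
		return "", err
	}

	output, err := expr.Run(prog, env)
	if err != nil {
		return "", fmt.Errorf("while running expression %s: %w", expression, err)
	}

	ret, err := yaml.Marshal(output)
	if err != nil {
		return "", err
	}

	return string(ret), nil
}

func main() {
	env := map[string]interface{}{
		"results": []map[string]string{{"scenario": "example/ssh-bf"}},
	}

	out, err := evalToString(`results[0]`, env)
	if err != nil {
		panic(err)
	}

	fmt.Print(out) // scenario: example/ssh-bf
}
```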
- if err != nil { - return nil, err - } - if err := yaml.Unmarshal(results, &bucketDump); err != nil { return nil, err } diff --git a/pkg/hubtest/utils.go b/pkg/hubtest/utils.go index 73de3510b9d..b42a73461f3 100644 --- a/pkg/hubtest/utils.go +++ b/pkg/hubtest/utils.go @@ -1,21 +1,43 @@ package hubtest import ( + "errors" "fmt" + "net" "os" "path/filepath" + "time" + + log "github.com/sirupsen/logrus" ) -func Copy(sourceFile string, destinationFile string) error { - input, err := os.ReadFile(sourceFile) +func IsAlive(target string) (bool, error) { + start := time.Now() + for { + conn, err := net.Dial("tcp", target) + if err == nil { + log.Debugf("'%s' is up after %s", target, time.Since(start)) + conn.Close() + return true, nil + } + time.Sleep(500 * time.Millisecond) + if time.Since(start) > 10*time.Second { + return false, fmt.Errorf("took more than 10s for %s to be available", target) + } + } +} + +func Copy(src string, dst string) error { + content, err := os.ReadFile(src) if err != nil { return err } - err = os.WriteFile(destinationFile, input, 0644) + err = os.WriteFile(dst, content, 0o644) if err != nil { return err } + return nil } @@ -32,16 +54,20 @@ func checkPathNotContained(path string, subpath string) error { } current := absSubPath + for { if current == absPath { - return fmt.Errorf("cannot copy a folder onto itself") + return errors.New("cannot copy a folder onto itself") } + up := filepath.Dir(current) if current == up { break } + current = up } + return nil } @@ -60,11 +86,12 @@ func CopyDir(src string, dest string) error { if err != nil { return err } + if !file.IsDir() { - return fmt.Errorf("Source " + file.Name() + " is not a directory!") + return errors.New("Source " + file.Name() + " is not a directory!") } - err = os.MkdirAll(dest, 0755) + err = os.MkdirAll(dest, 0o755) if err != nil { return err } @@ -75,32 +102,15 @@ func CopyDir(src string, dest string) error { } for _, f := range files { - if f.IsDir() { - - err = CopyDir(src+"/"+f.Name(), dest+"/"+f.Name()) - if err != nil { + if err = CopyDir(filepath.Join(src, f.Name()), filepath.Join(dest, f.Name())); err != nil { return err } - - } - - if !f.IsDir() { - - content, err := os.ReadFile(src + "/" + f.Name()) - if err != nil { + } else { + if err = Copy(filepath.Join(src, f.Name()), filepath.Join(dest, f.Name())); err != nil { return err - } - - err = os.WriteFile(dest+"/"+f.Name(), content, 0755) - if err != nil { - return err - - } - } - } return nil diff --git a/pkg/hubtest/utils_test.go b/pkg/hubtest/utils_test.go index de4f1aac386..ce86785af9e 100644 --- a/pkg/hubtest/utils_test.go +++ b/pkg/hubtest/utils_test.go @@ -3,16 +3,16 @@ package hubtest import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCheckPathNotContained(t *testing.T) { - assert.Nil(t, checkPathNotContained("/foo", "/bar")) - assert.Nil(t, checkPathNotContained("/foo/bar", "/foo")) - assert.Nil(t, checkPathNotContained("/foo/bar", "/")) - assert.Nil(t, checkPathNotContained("/path/to/somewhere", "/path/to/somewhere-else")) - assert.Nil(t, checkPathNotContained("~/.local/path/to/somewhere", "~/.local/path/to/somewhere-else")) - assert.NotNil(t, checkPathNotContained("/foo", "/foo/bar")) - assert.NotNil(t, checkPathNotContained("/", "/foo")) - assert.NotNil(t, checkPathNotContained("/", "/foo/bar/baz")) + require.NoError(t, checkPathNotContained("/foo", "/bar")) + require.NoError(t, checkPathNotContained("/foo/bar", "/foo")) + require.NoError(t, checkPathNotContained("/foo/bar", "/")) + 
require.NoError(t, checkPathNotContained("/path/to/somewhere", "/path/to/somewhere-else")) + require.NoError(t, checkPathNotContained("~/.local/path/to/somewhere", "~/.local/path/to/somewhere-else")) + require.Error(t, checkPathNotContained("/foo", "/foo/bar")) + require.Error(t, checkPathNotContained("/", "/foo")) + require.Error(t, checkPathNotContained("/", "/foo/bar/baz")) } diff --git a/pkg/leakybucket/bayesian.go b/pkg/leakybucket/bayesian.go index bd9aaed96b4..357d51f597b 100644 --- a/pkg/leakybucket/bayesian.go +++ b/pkg/leakybucket/bayesian.go @@ -3,8 +3,9 @@ package leakybucket import ( "fmt" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -107,9 +108,9 @@ func (b *BayesianEvent) bayesianUpdate(c *BayesianBucket, msg types.Event, l *Le } l.logger.Debugf("running condition expression: %s", b.rawCondition.ConditionalFilterName) - ret, err := expr.Run(b.conditionalFilterRuntime, map[string]interface{}{"evt": &msg, "queue": l.Queue, "leaky": l}) + ret, err := exprhelpers.Run(b.conditionalFilterRuntime, map[string]interface{}{"evt": &msg, "queue": l.Queue, "leaky": l}, l.logger, l.BucketConfig.Debug) if err != nil { - return fmt.Errorf("unable to run conditional filter: %s", err) + return fmt.Errorf("unable to run conditional filter: %w", err) } l.logger.Tracef("bayesian bucket expression %s returned : %v", b.rawCondition.ConditionalFilterName, ret) @@ -151,7 +152,7 @@ func (b *BayesianEvent) compileCondition() error { conditionalExprCacheLock.Unlock() //release the lock during compile same as coditional bucket - compiledExpr, err = expr.Compile(b.rawCondition.ConditionalFilterName, exprhelpers.GetExprOptions(map[string]interface{}{"queue": &Queue{}, "leaky": &Leaky{}, "evt": &types.Event{}})...) + compiledExpr, err = expr.Compile(b.rawCondition.ConditionalFilterName, exprhelpers.GetExprOptions(map[string]interface{}{"queue": &types.Queue{}, "leaky": &Leaky{}, "evt": &types.Event{}})...) 
if err != nil { return fmt.Errorf("bayesian condition compile error: %w", err) } diff --git a/pkg/leakybucket/blackhole.go b/pkg/leakybucket/blackhole.go index 3a2740c465d..b12f169acd9 100644 --- a/pkg/leakybucket/blackhole.go +++ b/pkg/leakybucket/blackhole.go @@ -31,8 +31,8 @@ func NewBlackhole(bucketFactory *BucketFactory) (*Blackhole, error) { }, nil } -func (bl *Blackhole) OnBucketOverflow(bucketFactory *BucketFactory) func(*Leaky, types.RuntimeAlert, *Queue) (types.RuntimeAlert, *Queue) { - return func(leaky *Leaky, alert types.RuntimeAlert, queue *Queue) (types.RuntimeAlert, *Queue) { +func (bl *Blackhole) OnBucketOverflow(bucketFactory *BucketFactory) func(*Leaky, types.RuntimeAlert, *types.Queue) (types.RuntimeAlert, *types.Queue) { + return func(leaky *Leaky, alert types.RuntimeAlert, queue *types.Queue) (types.RuntimeAlert, *types.Queue) { var blackholed = false var tmp []HiddenKey // search if we are blackholed and refresh the slice diff --git a/pkg/leakybucket/bucket.go b/pkg/leakybucket/bucket.go index 4589be32aff..e981551af8f 100644 --- a/pkg/leakybucket/bucket.go +++ b/pkg/leakybucket/bucket.go @@ -6,15 +6,16 @@ import ( "sync/atomic" "time" - "github.com/crowdsecurity/go-cs-lib/pkg/trace" - - "github.com/crowdsecurity/crowdsec/pkg/time/rate" - "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/davecgh/go-spew/spew" "github.com/mohae/deepcopy" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/pkg/time/rate" + "github.com/crowdsecurity/crowdsec/pkg/types" ) // those constants are now defined in types/constants @@ -30,13 +31,13 @@ type Leaky struct { //the limiter is what holds the proper "leaky aspect", it determines when/if we can pour objects Limiter rate.RateLimiter `json:"-"` SerializedState rate.Lstate - //Queue is used to held the cache of objects in the bucket, it is used to know 'how many' objects we have in buffer. - Queue *Queue + //Queue is used to hold the cache of objects in the bucket, it is used to know 'how many' objects we have in buffer. 
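An aside on the Limiter/Queue pair this struct carries: capacity is the burst size and leakspeed the drain rate. A rough sketch of those semantics using golang.org/x/time/rate (crowdsec ships its own rate package and overflow logic; this only shows the token-bucket intuition):

```go
package main

import (
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// capacity 5, leakspeed 10s: one event "leaks" every 10 seconds and a
	// burst of 5 fills the bucket; the next event has nowhere to go.
	limiter := rate.NewLimiter(rate.Every(10*time.Second), 5)

	for i := 1; i <= 6; i++ {
		if limiter.Allow() {
			fmt.Printf("event %d: poured\n", i)
		} else {
			fmt.Printf("event %d: bucket full -> would overflow\n", i)
		}
	}
}
```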
+ Queue *types.Queue //Leaky buckets are receiving message through a chan In chan *types.Event `json:"-"` //Leaky buckets are pushing their overflows through a chan - Out chan *Queue `json:"-"` - // shared for all buckets (the idea is to kill this afterwards) + Out chan *types.Queue `json:"-"` + // shared for all buckets (the idea is to kill this afterward) AllOut chan types.Event `json:"-"` //max capacity (for burst) Capacity int @@ -159,9 +160,9 @@ func FromFactory(bucketFactory BucketFactory) *Leaky { Name: bucketFactory.Name, Limiter: limiter, Uuid: seed.Generate(), - Queue: NewQueue(Qsize), + Queue: types.NewQueue(Qsize), CacheSize: bucketFactory.CacheSize, - Out: make(chan *Queue, 1), + Out: make(chan *types.Queue, 1), Suicide: make(chan bool, 1), AllOut: bucketFactory.ret, Capacity: bucketFactory.Capacity, @@ -216,7 +217,7 @@ func LeakRoutine(leaky *Leaky) error { defer BucketsCurrentCount.With(prometheus.Labels{"name": leaky.Name}).Dec() /*todo : we create a logger at runtime while we want leakroutine to be up asap, might not be a good idea*/ - leaky.logger = leaky.BucketConfig.logger.WithFields(log.Fields{"capacity": leaky.Capacity, "partition": leaky.Mapkey, "bucket_id": leaky.Uuid}) + leaky.logger = leaky.BucketConfig.logger.WithFields(log.Fields{"partition": leaky.Mapkey, "bucket_id": leaky.Uuid}) //We copy the processors, as they are coming from the BucketFactory, and thus are shared between buckets //If we don't copy, processors using local cache (such as Uniq) are subject to race conditions @@ -332,7 +333,7 @@ func LeakRoutine(leaky *Leaky) error { } if leaky.logger.Level >= log.TraceLevel { - /*don't sdump if it's not going to printed, it's expensive*/ + /*don't sdump if it's not going to be printed, it's expensive*/ leaky.logger.Tracef("Overflow event: %s", spew.Sdump(types.Event{Overflow: alert})) } @@ -374,7 +375,7 @@ func Pour(leaky *Leaky, msg types.Event) { } } -func (leaky *Leaky) overflow(ofw *Queue) { +func (leaky *Leaky) overflow(ofw *types.Queue) { close(leaky.Signal) alert, err := NewAlert(leaky, ofw) if err != nil { diff --git a/pkg/leakybucket/buckets_test.go b/pkg/leakybucket/buckets_test.go index e08887be818..1da906cb555 100644 --- a/pkg/leakybucket/buckets_test.go +++ b/pkg/leakybucket/buckets_test.go @@ -8,19 +8,23 @@ import ( "html/template" "io" "os" + "path/filepath" "reflect" "sync" "testing" "time" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" - "github.com/crowdsecurity/crowdsec/pkg/parser" - "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/davecgh/go-spew/spew" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" yaml "gopkg.in/yaml.v2" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" + "github.com/crowdsecurity/crowdsec/pkg/parser" + "github.com/crowdsecurity/crowdsec/pkg/types" ) type TestFile struct { @@ -33,33 +37,57 @@ func TestBucket(t *testing.T) { envSetting = os.Getenv("TEST_ONLY") tomb = &tomb.Tomb{} ) - err := exprhelpers.Init(nil) + + testdata := "./tests" + + hubCfg := &csconfig.LocalHubCfg{ + HubDir: filepath.Join(testdata, "hub"), + HubIndexFile: filepath.Join(testdata, "hub", "index.json"), + InstallDataDir: testdata, + } + + hub, err := cwhub.NewHub(hubCfg, nil, nil) + require.NoError(t, err) + + err = hub.Load() + require.NoError(t, err) + + err = exprhelpers.Init(nil) if err != nil { - log.Fatalf("exprhelpers init failed: 
%s", err) + t.Fatalf("exprhelpers init failed: %s", err) } if envSetting != "" { - if err := testOneBucket(t, envSetting, tomb); err != nil { + if err := testOneBucket(t, hub, envSetting, tomb); err != nil { t.Fatalf("Test '%s' failed : %s", envSetting, err) } } else { wg := new(sync.WaitGroup) - fds, err := os.ReadDir("./tests/") + + fds, err := os.ReadDir(testdata) if err != nil { t.Fatalf("Unable to read test directory : %s", err) } + for _, fd := range fds { - fname := "./tests/" + fd.Name() + if fd.Name() == "hub" { + continue + } + + fname := filepath.Join(testdata, fd.Name()) log.Infof("Running test on %s", fname) tomb.Go(func() error { wg.Add(1) defer wg.Done() - if err := testOneBucket(t, fname, tomb); err != nil { + + if err := testOneBucket(t, hub, fname, tomb); err != nil { t.Fatalf("Test '%s' failed : %s", fname, err) } + return nil }) } + wg.Wait() } } @@ -68,16 +96,16 @@ func TestBucket(t *testing.T) { // we want to avoid the death of the tomb because all existing buckets have been destroyed. func watchTomb(tomb *tomb.Tomb) { for { - if tomb.Alive() == false { + if !tomb.Alive() { log.Warning("Tomb is dead") break } + time.Sleep(100 * time.Millisecond) } } -func testOneBucket(t *testing.T, dir string, tomb *tomb.Tomb) error { - +func testOneBucket(t *testing.T, hub *cwhub.Hub, dir string, tomb *tomb.Tomb) error { var ( holders []BucketFactory @@ -85,9 +113,9 @@ func testOneBucket(t *testing.T, dir string, tomb *tomb.Tomb) error { stagecfg string stages []parser.Stagefile err error - buckets *Buckets ) - buckets = NewBuckets() + + buckets := NewBuckets() /*load the scenarios*/ stagecfg = dir + "/scenarios.yaml" @@ -97,53 +125,59 @@ func testOneBucket(t *testing.T, dir string, tomb *tomb.Tomb) error { tmpl, err := template.New("test").Parse(string(stagefiles)) if err != nil { - return fmt.Errorf("failed to parse template %s : %s", stagefiles, err) + return fmt.Errorf("failed to parse template %s: %w", stagefiles, err) } + var out bytes.Buffer + err = tmpl.Execute(&out, map[string]string{"TestDirectory": dir}) if err != nil { panic(err) } + if err := yaml.UnmarshalStrict(out.Bytes(), &stages); err != nil { - log.Fatalf("failed unmarshaling %s : %s", stagecfg, err) + t.Fatalf("failed to parse %s : %s", stagecfg, err) } + files := []string{} for _, x := range stages { files = append(files, x.Filename) } - cscfg := &csconfig.CrowdsecServiceCfg{ - DataDir: "tests", - } - holders, response, err := LoadBuckets(cscfg, files, tomb, buckets, false) + cscfg := &csconfig.CrowdsecServiceCfg{} + + holders, response, err := LoadBuckets(cscfg, hub, files, tomb, buckets, false) if err != nil { t.Fatalf("failed loading bucket : %s", err) } + tomb.Go(func() error { watchTomb(tomb) return nil }) - if !testFile(t, dir+"/test.json", dir+"/in-buckets_state.json", holders, response, buckets) { + + if !testFile(t, filepath.Join(dir, "test.json"), filepath.Join(dir, "in-buckets_state.json"), holders, response, buckets) { return fmt.Errorf("tests from %s failed", dir) } + return nil } func testFile(t *testing.T, file string, bs string, holders []BucketFactory, response chan types.Event, buckets *Buckets) bool { - var results []types.Event var dump bool - //should we restore + // should we restore if _, err := os.Stat(bs); err == nil { dump = true + if err := LoadBucketsState(bs, buckets, holders); err != nil { t.Fatalf("Failed to load bucket state : %s", err) } } /* now we can load the test files */ - //process the yaml + // process the yaml yamlFile, err := os.Open(file) if err != nil { 
t.Errorf("yamlFile.Get err #%v ", err) @@ -165,9 +199,11 @@ func testFile(t *testing.T, file string, bs string, holders []BucketFactory, res //just to avoid any race during ingestion of funny scenarios time.Sleep(50 * time.Millisecond) var ts time.Time + if err := ts.UnmarshalText([]byte(in.MarshaledTime)); err != nil { - t.Fatalf("Failed to unmarshal time from input event : %s", err) + t.Fatalf("Failed to parse time from input event : %s", err) } + if latest_ts.IsZero() { latest_ts = ts } else if ts.After(latest_ts) { @@ -176,10 +212,12 @@ func testFile(t *testing.T, file string, bs string, holders []BucketFactory, res in.ExpectMode = types.TIMEMACHINE log.Infof("Buckets input : %s", spew.Sdump(in)) + ok, err := PourItemToHolders(in, holders, buckets) if err != nil { t.Fatalf("Failed to pour : %s", err) } + if !ok { log.Warning("Event wasn't poured") } diff --git a/pkg/leakybucket/conditional.go b/pkg/leakybucket/conditional.go index cd5ec40a0a7..a203a639743 100644 --- a/pkg/leakybucket/conditional.go +++ b/pkg/leakybucket/conditional.go @@ -4,8 +4,9 @@ import ( "fmt" "sync" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -33,7 +34,7 @@ func (c *ConditionalOverflow) OnBucketInit(g *BucketFactory) error { } else { conditionalExprCacheLock.Unlock() //release the lock during compile - compiledExpr, err = expr.Compile(g.ConditionalOverflow, exprhelpers.GetExprOptions(map[string]interface{}{"queue": &Queue{}, "leaky": &Leaky{}, "evt": &types.Event{}})...) + compiledExpr, err = expr.Compile(g.ConditionalOverflow, exprhelpers.GetExprOptions(map[string]interface{}{"queue": &types.Queue{}, "leaky": &Leaky{}, "evt": &types.Event{}})...) 
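The test harness above replays events in time-machine mode: each event's MarshaledTime is parsed into a time.Time and the latest timestamp is tracked so buckets can be expired at the end of the replay. A short sketch of just that parsing step (the input values are made up):

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// MarshaledTime values as they would appear in serialized events
	inputs := []string{
		"2024-01-02T15:04:05Z",
		"2024-01-02T15:04:01Z",
	}

	var latest time.Time

	for _, raw := range inputs {
		var ts time.Time
		// time.Time implements encoding.TextUnmarshaler (RFC 3339)
		if err := ts.UnmarshalText([]byte(raw)); err != nil {
			panic(err)
		}
		if latest.IsZero() || ts.After(latest) {
			latest = ts
		}
	}

	fmt.Println("latest event time:", latest)
}
```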
if err != nil { return fmt.Errorf("conditional compile error : %w", err) } @@ -50,12 +51,14 @@ func (c *ConditionalOverflow) AfterBucketPour(b *BucketFactory) func(types.Event var condition, ok bool if c.ConditionalFilterRuntime != nil { l.logger.Debugf("Running condition expression : %s", c.ConditionalFilter) - ret, err := expr.Run(c.ConditionalFilterRuntime, map[string]interface{}{"evt": &msg, "queue": l.Queue, "leaky": l}) + + ret, err := exprhelpers.Run(c.ConditionalFilterRuntime, + map[string]interface{}{"evt": &msg, "queue": l.Queue, "leaky": l}, + l.logger, b.Debug) if err != nil { l.logger.Errorf("unable to run conditional filter : %s", err) return &msg } - l.logger.Debugf("Conditional bucket expression returned : %v", ret) if condition, ok = ret.(bool); !ok { diff --git a/pkg/leakybucket/manager_load.go b/pkg/leakybucket/manager_load.go index 337f6fd3f6e..b8310b8cb17 100644 --- a/pkg/leakybucket/manager_load.go +++ b/pkg/leakybucket/manager_load.go @@ -11,9 +11,9 @@ import ( "sync" "time" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" "github.com/davecgh/go-spew/spew" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/goombaio/namegenerator" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" @@ -22,7 +22,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/alertcontext" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/constraint" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -30,120 +30,174 @@ import ( // BucketFactory struct holds all fields for any bucket configuration. This is to have a // generic struct for buckets. This can be seen as a bucket factory. type BucketFactory struct { - FormatVersion string `yaml:"format"` - Author string `yaml:"author"` - Description string `yaml:"description"` - References []string `yaml:"references"` - Type string `yaml:"type"` //Type can be : leaky, counter, trigger. It determines the main bucket characteristics - Name string `yaml:"name"` //Name of the bucket, used later in log and user-messages. Should be unique - Capacity int `yaml:"capacity"` //Capacity is applicable to leaky buckets and determines the "burst" capacity - LeakSpeed string `yaml:"leakspeed"` //Leakspeed is a float representing how many events per second leak out of the bucket - Duration string `yaml:"duration"` //Duration allows 'counter' buckets to have a fixed life-time - Filter string `yaml:"filter"` //Filter is an expr that determines if an event is elligible for said bucket. Filter is evaluated against the Event struct - GroupBy string `yaml:"groupby,omitempty"` //groupy is an expr that allows to determine the partitions of the bucket. 
A common example is the source_ip - Distinct string `yaml:"distinct"` //Distinct, when present, adds a `Pour()` processor that will only pour uniq items (based on distinct expr result) - Debug bool `yaml:"debug"` //Debug, when set to true, will enable debugging for _this_ scenario specifically - Labels map[string]interface{} `yaml:"labels"` //Labels is K:V list aiming at providing context the overflow - Blackhole string `yaml:"blackhole,omitempty"` //Blackhole is a duration that, if present, will prevent same bucket partition to overflow more often than $duration - logger *log.Entry `yaml:"-"` //logger is bucket-specific logger (used by Debug as well) - Reprocess bool `yaml:"reprocess"` //Reprocess, if true, will for the bucket to be re-injected into processing chain - CacheSize int `yaml:"cache_size"` //CacheSize, if > 0, limits the size of in-memory cache of the bucket - Profiling bool `yaml:"profiling"` //Profiling, if true, will make the bucket record pours/overflows/etc. - OverflowFilter string `yaml:"overflow_filter"` //OverflowFilter if present, is a filter that must return true for the overflow to go through - ConditionalOverflow string `yaml:"condition"` //condition if present, is an expression that must return true for the bucket to overflow - BayesianPrior float32 `yaml:"bayesian_prior"` - BayesianThreshold float32 `yaml:"bayesian_threshold"` - BayesianConditions []RawBayesianCondition `yaml:"bayesian_conditions"` //conditions for the bayesian bucket - ScopeType types.ScopeType `yaml:"scope,omitempty"` //to enforce a different remediation than blocking an IP. Will default this to IP - BucketName string `yaml:"-"` - Filename string `yaml:"-"` - RunTimeFilter *vm.Program `json:"-"` - ExprDebugger *exprhelpers.ExprDebugger `yaml:"-" json:"-"` // used to debug expression by printing the content of each variable of the expression - RunTimeGroupBy *vm.Program `json:"-"` - Data []*types.DataSource `yaml:"data,omitempty"` - DataDir string `yaml:"-"` - CancelOnFilter string `yaml:"cancel_on,omitempty"` //a filter that, if matched, kills the bucket - leakspeed time.Duration //internal representation of `Leakspeed` - duration time.Duration //internal representation of `Duration` - ret chan types.Event //the bucket-specific output chan for overflows - processors []Processor //processors is the list of hooks for pour/overflow/create (cf. uniq, blackhole etc.) - output bool //?? - ScenarioVersion string `yaml:"version,omitempty"` - hash string `yaml:"-"` - Simulated bool `yaml:"simulated"` //Set to true if the scenario instantiating the bucket was in the exclusion list - tomb *tomb.Tomb `yaml:"-"` - wgPour *sync.WaitGroup `yaml:"-"` - wgDumpState *sync.WaitGroup `yaml:"-"` + FormatVersion string `yaml:"format"` + Author string `yaml:"author"` + Description string `yaml:"description"` + References []string `yaml:"references"` + Type string `yaml:"type"` // Type can be: leaky, counter, trigger. It determines the main bucket characteristics + Name string `yaml:"name"` // Name of the bucket, used later in log and user-messages. Should be unique + Capacity int `yaml:"capacity"` // Capacity is applicable to leaky buckets and determines the "burst" capacity + LeakSpeed string `yaml:"leakspeed"` // Leakspeed is a float representing how many events per second leak out of the bucket + Duration string `yaml:"duration"` // Duration allows 'counter' buckets to have a fixed life-time + Filter string `yaml:"filter"` // Filter is an expr that determines if an event is eligible for said bucket. 
Filter is evaluated against the Event struct + GroupBy string `yaml:"groupby,omitempty"` // groupby is an expr that determines the partitions of the bucket. A common example is the source_ip + Distinct string `yaml:"distinct"` // Distinct, when present, adds a `Pour()` processor that will only pour unique items (based on the distinct expr result) + Debug bool `yaml:"debug"` // Debug, when set to true, will enable debugging for _this_ scenario specifically + Labels map[string]interface{} `yaml:"labels"` // Labels is a K:V list aiming at providing context to the overflow + Blackhole string `yaml:"blackhole,omitempty"` // Blackhole is a duration that, if present, will prevent the same bucket partition from overflowing more often than $duration + logger *log.Entry // logger is the bucket-specific logger (used by Debug as well) + Reprocess bool `yaml:"reprocess"` // Reprocess, if true, will force the bucket to be re-injected into the processing chain + CacheSize int `yaml:"cache_size"` // CacheSize, if > 0, limits the size of the in-memory cache of the bucket + Profiling bool `yaml:"profiling"` // Profiling, if true, will make the bucket record pours/overflows/etc. + OverflowFilter string `yaml:"overflow_filter"` // OverflowFilter, if present, is a filter that must return true for the overflow to go through + ConditionalOverflow string `yaml:"condition"` // condition, if present, is an expression that must return true for the bucket to overflow + BayesianPrior float32 `yaml:"bayesian_prior"` + BayesianThreshold float32 `yaml:"bayesian_threshold"` + BayesianConditions []RawBayesianCondition `yaml:"bayesian_conditions"` // conditions for the bayesian bucket + ScopeType types.ScopeType `yaml:"scope,omitempty"` // to enforce a different remediation than blocking an IP. Will default this to IP + BucketName string `yaml:"-"` + Filename string `yaml:"-"` + RunTimeFilter *vm.Program `json:"-"` + RunTimeGroupBy *vm.Program `json:"-"` + Data []*types.DataSource `yaml:"data,omitempty"` + DataDir string `yaml:"-"` + CancelOnFilter string `yaml:"cancel_on,omitempty"` // a filter that, if matched, kills the bucket + leakspeed time.Duration // internal representation of `Leakspeed` + duration time.Duration // internal representation of `Duration` + ret chan types.Event // the bucket-specific output chan for overflows + processors []Processor // processors is the list of hooks for pour/overflow/create (cf. uniq, blackhole etc.) + output bool // ?? 
+ ScenarioVersion string `yaml:"version,omitempty"` + hash string + Simulated bool `yaml:"simulated"` // Set to true if the scenario instantiating the bucket was in the exclusion list + tomb *tomb.Tomb + wgPour *sync.WaitGroup + wgDumpState *sync.WaitGroup orderEvent bool } // we use one NameGenerator for all the future buckets var seed namegenerator.Generator = namegenerator.NewNameGenerator(time.Now().UTC().UnixNano()) +func validateLeakyType(bucketFactory *BucketFactory) error { + if bucketFactory.Capacity <= 0 { // capacity must be a positive int + return fmt.Errorf("bad capacity for leaky '%d'", bucketFactory.Capacity) + } + + if bucketFactory.LeakSpeed == "" { + return errors.New("leakspeed can't be empty for leaky") + } + + if bucketFactory.leakspeed == 0 { + return fmt.Errorf("bad leakspeed for leaky '%s'", bucketFactory.LeakSpeed) + } + + return nil +} + +func validateCounterType(bucketFactory *BucketFactory) error { + if bucketFactory.Duration == "" { + return errors.New("duration can't be empty for counter") + } + + if bucketFactory.duration == 0 { + return fmt.Errorf("bad duration for counter bucket '%d'", bucketFactory.duration) + } + + if bucketFactory.Capacity != -1 { + return errors.New("counter bucket must have -1 capacity") + } + + return nil +} + +func validateTriggerType(bucketFactory *BucketFactory) error { + if bucketFactory.Capacity != 0 { + return errors.New("trigger bucket must have 0 capacity") + } + + return nil +} + +func validateConditionalType(bucketFactory *BucketFactory) error { + if bucketFactory.ConditionalOverflow == "" { + return errors.New("conditional bucket must have a condition") + } + + if bucketFactory.Capacity != -1 { + bucketFactory.logger.Warnf("Using a value different than -1 as capacity for conditional bucket, this may lead to unexpected overflows") + } + + if bucketFactory.LeakSpeed == "" { + return errors.New("leakspeed can't be empty for conditional bucket") + } + + if bucketFactory.leakspeed == 0 { + return fmt.Errorf("bad leakspeed for conditional bucket '%s'", bucketFactory.LeakSpeed) + } + + return nil +} + +func validateBayesianType(bucketFactory *BucketFactory) error { + if bucketFactory.BayesianConditions == nil { + return errors.New("bayesian bucket must have bayesian conditions") + } + + if bucketFactory.BayesianPrior == 0 { + return errors.New("bayesian bucket must have a valid, non-zero prior") + } + + if bucketFactory.BayesianThreshold == 0 { + return errors.New("bayesian bucket must have a valid, non-zero threshold") + } + + if bucketFactory.BayesianPrior > 1 { + return errors.New("bayesian bucket must have a valid, non-zero prior") + } + + if bucketFactory.BayesianThreshold > 1 { + return errors.New("bayesian bucket must have a valid, non-zero threshold") + } + + if bucketFactory.Capacity != -1 { + return errors.New("bayesian bucket must have capacity -1") + } + + return nil +} + func ValidateFactory(bucketFactory *BucketFactory) error { if bucketFactory.Name == "" { - return fmt.Errorf("bucket must have name") + return errors.New("bucket must have name") } + if bucketFactory.Description == "" { - return fmt.Errorf("description is mandatory") + return errors.New("description is mandatory") } - if bucketFactory.Type == "leaky" { - if bucketFactory.Capacity <= 0 { //capacity must be a positive int - return fmt.Errorf("bad capacity for leaky '%d'", bucketFactory.Capacity) - } - if bucketFactory.LeakSpeed == "" { - return fmt.Errorf("leakspeed can't be empty for leaky") - } - if bucketFactory.leakspeed == 0 { - return 
fmt.Errorf("bad leakspeed for leaky '%s'", bucketFactory.LeakSpeed) - } - } else if bucketFactory.Type == "counter" { - if bucketFactory.Duration == "" { - return fmt.Errorf("duration can't be empty for counter") - } - if bucketFactory.duration == 0 { - return fmt.Errorf("bad duration for counter bucket '%d'", bucketFactory.duration) - } - if bucketFactory.Capacity != -1 { - return fmt.Errorf("counter bucket must have -1 capacity") - } - } else if bucketFactory.Type == "trigger" { - if bucketFactory.Capacity != 0 { - return fmt.Errorf("trigger bucket must have 0 capacity") - } - } else if bucketFactory.Type == "conditional" { - if bucketFactory.ConditionalOverflow == "" { - return fmt.Errorf("conditional bucket must have a condition") - } - if bucketFactory.Capacity != -1 { - bucketFactory.logger.Warnf("Using a value different than -1 as capacity for conditional bucket, this may lead to unexpected overflows") - } - if bucketFactory.LeakSpeed == "" { - return fmt.Errorf("leakspeed can't be empty for conditional bucket") - } - if bucketFactory.leakspeed == 0 { - return fmt.Errorf("bad leakspeed for conditional bucket '%s'", bucketFactory.LeakSpeed) - } - } else if bucketFactory.Type == "bayesian" { - if bucketFactory.BayesianConditions == nil { - return fmt.Errorf("bayesian bucket must have bayesian conditions") - } - if bucketFactory.BayesianPrior == 0 { - return fmt.Errorf("bayesian bucket must have a valid, non-zero prior") + + switch bucketFactory.Type { + case "leaky": + if err := validateLeakyType(bucketFactory); err != nil { + return err } - if bucketFactory.BayesianThreshold == 0 { - return fmt.Errorf("bayesian bucket must have a valid, non-zero threshold") + case "counter": + if err := validateCounterType(bucketFactory); err != nil { + return err } - if bucketFactory.BayesianPrior > 1 { - return fmt.Errorf("bayesian bucket must have a valid, non-zero prior") + case "trigger": + if err := validateTriggerType(bucketFactory); err != nil { + return err } - if bucketFactory.BayesianThreshold > 1 { - return fmt.Errorf("bayesian bucket must have a valid, non-zero threshold") + case "conditional": + if err := validateConditionalType(bucketFactory); err != nil { + return err } - if bucketFactory.Capacity != -1 { - return fmt.Errorf("bayesian bucket must have capacity -1") + case "bayesian": + if err := validateBayesianType(bucketFactory); err != nil { + return err } - } else { + default: return fmt.Errorf("unknown bucket type '%s'", bucketFactory.Type) } @@ -156,106 +210,121 @@ func ValidateFactory(bucketFactory *BucketFactory) error { runTimeFilter *vm.Program err error ) + if bucketFactory.ScopeType.Filter != "" { if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil { - return fmt.Errorf("Error compiling the scope filter: %s", err) + return fmt.Errorf("error compiling the scope filter: %w", err) } + bucketFactory.ScopeType.RunTimeFilter = runTimeFilter } default: - //Compile the scope filter + // Compile the scope filter var ( runTimeFilter *vm.Program err error ) + if bucketFactory.ScopeType.Filter != "" { if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil { - return fmt.Errorf("Error compiling the scope filter: %s", err) + return fmt.Errorf("error compiling the scope filter: %w", err) } + bucketFactory.ScopeType.RunTimeFilter = runTimeFilter } } + return nil } -func 
LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, files []string, tomb *tomb.Tomb, buckets *Buckets, orderEvent bool) ([]BucketFactory, chan types.Event, error) { +func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []string, tomb *tomb.Tomb, buckets *Buckets, orderEvent bool) ([]BucketFactory, chan types.Event, error) { var ( ret = []BucketFactory{} response chan types.Event ) response = make(chan types.Event, 1) + for _, f := range files { log.Debugf("Loading '%s'", f) + if !strings.HasSuffix(f, ".yaml") && !strings.HasSuffix(f, ".yml") { log.Debugf("Skipping %s : not a yaml file", f) continue } - //process the yaml + // process the yaml bucketConfigurationFile, err := os.Open(f) if err != nil { log.Errorf("Can't access leaky configuration file %s", f) return nil, nil, err } + + defer bucketConfigurationFile.Close() dec := yaml.NewDecoder(bucketConfigurationFile) dec.SetStrict(true) + for { bucketFactory := BucketFactory{} + err = dec.Decode(&bucketFactory) if err != nil { if !errors.Is(err, io.EOF) { - log.Errorf("Bad yaml in %s : %v", f, err) - return nil, nil, fmt.Errorf("bad yaml in %s : %v", f, err) + log.Errorf("Bad yaml in %s: %v", f, err) + return nil, nil, fmt.Errorf("bad yaml in %s: %w", f, err) } + log.Tracef("End of yaml file") + break } - bucketFactory.DataDir = cscfg.DataDir - //check empty + + bucketFactory.DataDir = hub.GetDataDir() + // check empty if bucketFactory.Name == "" { log.Errorf("Won't load nameless bucket") - return nil, nil, fmt.Errorf("nameless bucket") + return nil, nil, errors.New("nameless bucket") } - //check compat + // check compat if bucketFactory.FormatVersion == "" { log.Tracef("no version in %s : %s, assuming '1.0'", bucketFactory.Name, f) bucketFactory.FormatVersion = "1.0" } - ok, err := cwversion.Satisfies(bucketFactory.FormatVersion, cwversion.Constraint_scenario) + + ok, err := constraint.Satisfies(bucketFactory.FormatVersion, constraint.Scenario) if err != nil { - log.Fatalf("Failed to check version : %s", err) + return nil, nil, fmt.Errorf("failed to check version: %w", err) } + if !ok { - log.Errorf("can't load %s : %s doesn't satisfy scenario format %s, skip", bucketFactory.Name, bucketFactory.FormatVersion, cwversion.Constraint_scenario) + log.Errorf("can't load %s : %s doesn't satisfy scenario format %s, skip", bucketFactory.Name, bucketFactory.FormatVersion, constraint.Scenario) continue } bucketFactory.Filename = filepath.Clean(f) bucketFactory.BucketName = seed.Generate() bucketFactory.ret = response - hubItem, err := cwhub.GetItemByPath(cwhub.SCENARIOS, bucketFactory.Filename) - if err != nil { - log.Errorf("scenario %s (%s) couldn't be find in hub (ignore if in unit tests)", bucketFactory.Name, bucketFactory.Filename) + + hubItem := hub.GetItemByPath(bucketFactory.Filename) + if hubItem == nil { + log.Errorf("scenario %s (%s) could not be found in hub (ignore if in unit tests)", bucketFactory.Name, bucketFactory.Filename) } else { if cscfg.SimulationConfig != nil { bucketFactory.Simulated = cscfg.SimulationConfig.IsSimulated(hubItem.Name) } - if hubItem != nil { - bucketFactory.ScenarioVersion = hubItem.LocalVersion - bucketFactory.hash = hubItem.LocalHash - } else { - log.Errorf("scenario %s (%s) couldn't be find in hub (ignore if in unit tests)", bucketFactory.Name, bucketFactory.Filename) - } + + bucketFactory.ScenarioVersion = hubItem.State.LocalVersion + bucketFactory.hash = hubItem.State.LocalHash } bucketFactory.wgDumpState = buckets.wgDumpState bucketFactory.wgPour = buckets.wgPour + err = 
LoadBucket(&bucketFactory, tomb) if err != nil { - log.Errorf("Failed to load bucket %s : %v", bucketFactory.Name, err) - return nil, nil, fmt.Errorf("loading of %s failed : %v", bucketFactory.Name, err) + log.Errorf("Failed to load bucket %s: %v", bucketFactory.Name, err) + return nil, nil, fmt.Errorf("loading of %s failed: %w", bucketFactory.Name, err) } bucketFactory.orderEvent = orderEvent @@ -265,73 +334,70 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, files []string, tomb *tomb. } if err := alertcontext.NewAlertContext(cscfg.ContextToSend, cscfg.ConsoleContextValueLength); err != nil { - return nil, nil, fmt.Errorf("unable to load alert context: %s", err) + return nil, nil, fmt.Errorf("unable to load alert context: %w", err) } log.Infof("Loaded %d scenarios", len(ret)) + return ret, response, nil } /* Init recursively process yaml files from a directory and loads them as BucketFactory */ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { var err error + if bucketFactory.Debug { - var clog = log.New() + clog := log.New() if err := types.ConfigureLogger(clog); err != nil { - log.Fatalf("While creating bucket-specific logger : %s", err) + return fmt.Errorf("while creating bucket-specific logger: %w", err) } + clog.SetLevel(log.DebugLevel) bucketFactory.logger = clog.WithFields(log.Fields{ "cfg": bucketFactory.BucketName, "name": bucketFactory.Name, - "file": bucketFactory.Filename, }) } else { /* else bind it to the default one (might find something more elegant here)*/ bucketFactory.logger = log.WithFields(log.Fields{ "cfg": bucketFactory.BucketName, "name": bucketFactory.Name, - "file": bucketFactory.Filename, }) } if bucketFactory.LeakSpeed != "" { if bucketFactory.leakspeed, err = time.ParseDuration(bucketFactory.LeakSpeed); err != nil { - return fmt.Errorf("bad leakspeed '%s' in %s : %v", bucketFactory.LeakSpeed, bucketFactory.Filename, err) + return fmt.Errorf("bad leakspeed '%s' in %s: %w", bucketFactory.LeakSpeed, bucketFactory.Filename, err) } } else { bucketFactory.leakspeed = time.Duration(0) } + if bucketFactory.Duration != "" { if bucketFactory.duration, err = time.ParseDuration(bucketFactory.Duration); err != nil { - return fmt.Errorf("invalid Duration '%s' in %s : %v", bucketFactory.Duration, bucketFactory.Filename, err) + return fmt.Errorf("invalid Duration '%s' in %s: %w", bucketFactory.Duration, bucketFactory.Filename, err) } } if bucketFactory.Filter == "" { bucketFactory.logger.Warning("Bucket without filter, abort.") - return fmt.Errorf("bucket without filter directive") + return errors.New("bucket without filter directive") } + bucketFactory.RunTimeFilter, err = expr.Compile(bucketFactory.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { - return fmt.Errorf("invalid filter '%s' in %s : %v", bucketFactory.Filter, bucketFactory.Filename, err) - } - if bucketFactory.Debug { - bucketFactory.ExprDebugger, err = exprhelpers.NewDebugger(bucketFactory.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) - if err != nil { - log.Errorf("unable to build debug filter for '%s' : %s", bucketFactory.Filter, err) - } + return fmt.Errorf("invalid filter '%s' in %s: %w", bucketFactory.Filter, bucketFactory.Filename, err) } if bucketFactory.GroupBy != "" { bucketFactory.RunTimeGroupBy, err = expr.Compile(bucketFactory.GroupBy, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) 
if err != nil { - return fmt.Errorf("invalid groupby '%s' in %s : %v", bucketFactory.GroupBy, bucketFactory.Filename, err) + return fmt.Errorf("invalid groupby '%s' in %s: %w", bucketFactory.GroupBy, bucketFactory.Filename, err) } } bucketFactory.logger.Infof("Adding %s bucket", bucketFactory.Type) - //return the Holder corresponding to the type of bucket + // return the Holder corresponding to the type of bucket bucketFactory.processors = []Processor{} switch bucketFactory.Type { case "leaky": @@ -345,7 +411,7 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { case "bayesian": bucketFactory.processors = append(bucketFactory.processors, &DumbProcessor{}) default: - return fmt.Errorf("invalid type '%s' in %s : %v", bucketFactory.Type, bucketFactory.Filename, err) + return fmt.Errorf("invalid type '%s' in %s: %w", bucketFactory.Type, bucketFactory.Filename, err) } if bucketFactory.Distinct != "" { @@ -360,21 +426,25 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { if bucketFactory.OverflowFilter != "" { bucketFactory.logger.Tracef("Adding an overflow filter") + filovflw, err := NewOverflowFilter(bucketFactory) if err != nil { bucketFactory.logger.Errorf("Error creating overflow_filter : %s", err) - return fmt.Errorf("error creating overflow_filter : %s", err) + return fmt.Errorf("error creating overflow_filter: %w", err) } + bucketFactory.processors = append(bucketFactory.processors, filovflw) } if bucketFactory.Blackhole != "" { bucketFactory.logger.Tracef("Adding blackhole.") + blackhole, err := NewBlackhole(bucketFactory) if err != nil { bucketFactory.logger.Errorf("Error creating blackhole : %s", err) - return fmt.Errorf("error creating blackhole : %s", err) + return fmt.Errorf("error creating blackhole : %w", err) } + bucketFactory.processors = append(bucketFactory.processors, blackhole) } @@ -388,87 +458,98 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { bucketFactory.processors = append(bucketFactory.processors, &BayesianBucket{}) } - if len(bucketFactory.Data) > 0 { - for _, data := range bucketFactory.Data { - if data.DestPath == "" { - bucketFactory.logger.Errorf("no dest_file provided for '%s'", bucketFactory.Name) - continue - } - err = exprhelpers.FileInit(bucketFactory.DataDir, data.DestPath, data.Type) - if err != nil { - bucketFactory.logger.Errorf("unable to init data for file '%s': %s", data.DestPath, err) - } - if data.Type == "regexp" { //cache only makes sense for regexp - exprhelpers.RegexpCacheInit(data.DestPath, *data) - } + for _, data := range bucketFactory.Data { + if data.DestPath == "" { + bucketFactory.logger.Errorf("no dest_file provided for '%s'", bucketFactory.Name) + continue + } + + err = exprhelpers.FileInit(bucketFactory.DataDir, data.DestPath, data.Type) + if err != nil { + bucketFactory.logger.Errorf("unable to init data for file '%s': %s", data.DestPath, err) + } + + if data.Type == "regexp" { // cache only makes sense for regexp + exprhelpers.RegexpCacheInit(data.DestPath, *data) } } bucketFactory.output = false if err := ValidateFactory(bucketFactory); err != nil { - return fmt.Errorf("invalid bucket from %s : %v", bucketFactory.Filename, err) + return fmt.Errorf("invalid bucket from %s: %w", bucketFactory.Filename, err) } + bucketFactory.tomb = tomb return nil - } func LoadBucketsState(file string, buckets *Buckets, bucketFactories []BucketFactory) error { var state map[string]Leaky + body, err := os.ReadFile(file) if err != nil { - return fmt.Errorf("can't state file %s : 
%s", file, err) + return fmt.Errorf("can't read state file %s: %w", file, err) } + if err := json.Unmarshal(body, &state); err != nil { - return fmt.Errorf("can't unmarshal state file %s : %s", file, err) + return fmt.Errorf("can't parse state file %s: %w", file, err) } + for k, v := range state { var tbucket *Leaky + log.Debugf("Reloading bucket %s", k) + val, ok := buckets.Bucket_map.Load(k) if ok { - log.Fatalf("key %s already exists : %+v", k, val) + return fmt.Errorf("key %s already exists: %+v", k, val) } - //find back our holder + // find back our holder found := false + for _, h := range bucketFactories { - if h.Name == v.Name { - log.Debugf("found factory %s/%s -> %s", h.Author, h.Name, h.Description) - //check in which mode the bucket was - if v.Mode == types.TIMEMACHINE { - tbucket = NewTimeMachine(h) - } else if v.Mode == types.LIVE { - tbucket = NewLeaky(h) - } else { - log.Errorf("Unknown bucket type : %d", v.Mode) - } - /*Trying to restore queue state*/ - tbucket.Queue = v.Queue - /*Trying to set the limiter to the saved values*/ - tbucket.Limiter.Load(v.SerializedState) - tbucket.In = make(chan *types.Event) - tbucket.Mapkey = k - tbucket.Signal = make(chan bool, 1) - tbucket.First_ts = v.First_ts - tbucket.Last_ts = v.Last_ts - tbucket.Ovflw_ts = v.Ovflw_ts - tbucket.Total_count = v.Total_count - buckets.Bucket_map.Store(k, tbucket) - h.tomb.Go(func() error { - return LeakRoutine(tbucket) - }) - <-tbucket.Signal - found = true - break + if h.Name != v.Name { + continue + } + + log.Debugf("found factory %s/%s -> %s", h.Author, h.Name, h.Description) + // check in which mode the bucket was + if v.Mode == types.TIMEMACHINE { + tbucket = NewTimeMachine(h) + } else if v.Mode == types.LIVE { + tbucket = NewLeaky(h) + } else { + log.Errorf("Unknown bucket type : %d", v.Mode) } + /*Trying to restore queue state*/ + tbucket.Queue = v.Queue + /*Trying to set the limiter to the saved values*/ + tbucket.Limiter.Load(v.SerializedState) + tbucket.In = make(chan *types.Event) + tbucket.Mapkey = k + tbucket.Signal = make(chan bool, 1) + tbucket.First_ts = v.First_ts + tbucket.Last_ts = v.Last_ts + tbucket.Ovflw_ts = v.Ovflw_ts + tbucket.Total_count = v.Total_count + buckets.Bucket_map.Store(k, tbucket) + h.tomb.Go(func() error { + return LeakRoutine(tbucket) + }) + <-tbucket.Signal + + found = true + + break } + if !found { - log.Fatalf("Unable to find holder for bucket %s : %s", k, spew.Sdump(v)) + return fmt.Errorf("unable to find holder for bucket %s: %s", k, spew.Sdump(v)) } } log.Infof("Restored %d buckets from dump", len(state)) - return nil + return nil } diff --git a/pkg/leakybucket/manager_run.go b/pkg/leakybucket/manager_run.go index 388227a41e5..2858d8b5635 100644 --- a/pkg/leakybucket/manager_run.go +++ b/pkg/leakybucket/manager_run.go @@ -9,11 +9,11 @@ import ( "sync" "time" - "github.com/antonmedv/expr" "github.com/mohae/deepcopy" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -85,7 +85,7 @@ func DumpBucketsStateAt(deadline time.Time, outputdir string, buckets *Buckets) defer buckets.wgDumpState.Done() if outputdir == "" { - return "", fmt.Errorf("empty output dir for dump bucket state") + return "", errors.New("empty output dir for dump bucket state") } tmpFd, err := os.CreateTemp(os.TempDir(), "crowdsec-buckets-dump-") if err != nil { @@ -132,11 +132,11 @@ func DumpBucketsStateAt(deadline time.Time, outputdir string, buckets *Buckets) 
}) bbuckets, err := json.MarshalIndent(serialized, "", " ") if err != nil { - return "", fmt.Errorf("Failed to unmarshal buckets : %s", err) + return "", fmt.Errorf("failed to serialize buckets: %s", err) } size, err := tmpFd.Write(bbuckets) if err != nil { - return "", fmt.Errorf("failed to write temp file : %s", err) + return "", fmt.Errorf("failed to write temp file: %s", err) } log.Infof("Serialized %d live buckets (+%d expired) in %d bytes to %s", len(serialized), discard, size, tmpFd.Name()) serialized = nil @@ -203,7 +203,7 @@ func PourItemToBucket(bucket *Leaky, holder BucketFactory, buckets *Buckets, par var d time.Time err = d.UnmarshalText([]byte(parsed.MarshaledTime)) if err != nil { - holder.logger.Warningf("Failed unmarshaling event time (%s) : %v", parsed.MarshaledTime, err) + holder.logger.Warningf("Failed to parse event time (%s) : %v", parsed.MarshaledTime, err) } if d.After(lastTs.Add(bucket.Duration)) { bucket.logger.Tracef("bucket is expired (curr event: %s, bucket deadline: %s), kill", d, lastTs.Add(bucket.Duration)) @@ -297,15 +297,17 @@ func PourItemToHolders(parsed types.Event, holders []BucketFactory, buckets *Buc evt := deepcopy.Copy(parsed) BucketPourCache["OK"] = append(BucketPourCache["OK"], evt.(types.Event)) } - //find the relevant holders (scenarios) - for idx := 0; idx < len(holders); idx++ { + for idx := range holders { //for idx, holder := range holders { //evaluate bucket's condition if holders[idx].RunTimeFilter != nil { holders[idx].logger.Tracef("event against holder %d/%d", idx, len(holders)) - output, err := expr.Run(holders[idx].RunTimeFilter, map[string]interface{}{"evt": &parsed}) + output, err := exprhelpers.Run(holders[idx].RunTimeFilter, + map[string]interface{}{"evt": &parsed}, + holders[idx].logger, + holders[idx].Debug) if err != nil { holders[idx].logger.Errorf("failed parsing : %v", err) return false, fmt.Errorf("leaky failed : %s", err) @@ -315,10 +317,6 @@ func PourItemToHolders(parsed types.Event, holders []BucketFactory, buckets *Buc holders[idx].logger.Errorf("unexpected non-bool return : %T", output) holders[idx].logger.Fatalf("Filter issue") } - - if holders[idx].Debug { - holders[idx].ExprDebugger.Run(holders[idx].logger, condition, map[string]interface{}{"evt": &parsed}) - } if !condition { holders[idx].logger.Debugf("Event leaving node : ko (filter mismatch)") continue @@ -328,7 +326,7 @@ func PourItemToHolders(parsed types.Event, holders []BucketFactory, buckets *Buc //groupby determines the partition key for the specific bucket var groupby string if holders[idx].RunTimeGroupBy != nil { - tmpGroupBy, err := expr.Run(holders[idx].RunTimeGroupBy, map[string]interface{}{"evt": &parsed}) + tmpGroupBy, err := exprhelpers.Run(holders[idx].RunTimeGroupBy, map[string]interface{}{"evt": &parsed}, holders[idx].logger, holders[idx].Debug) if err != nil { holders[idx].logger.Errorf("failed groupby : %v", err) return false, errors.New("leaky failed :/") diff --git a/pkg/leakybucket/manager_run_test.go b/pkg/leakybucket/manager_run_test.go index 27b665f750c..f3fe08b697a 100644 --- a/pkg/leakybucket/manager_run_test.go +++ b/pkg/leakybucket/manager_run_test.go @@ -5,9 +5,10 @@ import ( "testing" "time" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/crowdsec/pkg/types" ) func expectBucketCount(buckets *Buckets, expected int) error { @@ -20,7 +21,6 @@ func expectBucketCount(buckets *Buckets, expected int) error { return fmt.Errorf("expected %d live buckets, got %d", 
expected, count) } return nil - } func TestGCandDump(t *testing.T) { @@ -29,7 +29,7 @@ func TestGCandDump(t *testing.T) { tomb = &tomb.Tomb{} ) - var Holders = []BucketFactory{ + Holders := []BucketFactory{ //one overflowing soon + bh { Name: "test_counter_fast", @@ -80,7 +80,7 @@ func TestGCandDump(t *testing.T) { log.Printf("Pouring to bucket") - var in = types.Event{Parsed: map[string]string{"something": "something"}} + in := types.Event{Parsed: map[string]string{"something": "something"}} //pour an item that will go to leaky + counter ok, err := PourItemToHolders(in, Holders, buckets) if err != nil { @@ -156,7 +156,7 @@ func TestShutdownBuckets(t *testing.T) { log.Printf("Pouring to bucket") - var in = types.Event{Parsed: map[string]string{"something": "something"}} + in := types.Event{Parsed: map[string]string{"something": "something"}} //pour an item that will go to leaky + counter ok, err := PourItemToHolders(in, Holders, buckets) if err != nil { @@ -178,5 +178,4 @@ func TestShutdownBuckets(t *testing.T) { if err := expectBucketCount(buckets, 2); err != nil { t.Fatal(err) } - } diff --git a/pkg/leakybucket/overflow_filter.go b/pkg/leakybucket/overflow_filter.go index c716c22d3fb..01dd491ed41 100644 --- a/pkg/leakybucket/overflow_filter.go +++ b/pkg/leakybucket/overflow_filter.go @@ -3,8 +3,8 @@ package leakybucket import ( "fmt" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -28,7 +28,7 @@ func NewOverflowFilter(g *BucketFactory) (*OverflowFilter, error) { u := OverflowFilter{} u.Filter = g.OverflowFilter - u.FilterRuntime, err = expr.Compile(u.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"queue": &Queue{}, "signal": &types.RuntimeAlert{}, "leaky": &Leaky{}})...) + u.FilterRuntime, err = expr.Compile(u.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"queue": &types.Queue{}, "signal": &types.RuntimeAlert{}, "leaky": &Leaky{}})...) 
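The recurring change in the hunks above swaps a bare expr.Run call plus a separately compiled ExprDebugger for a single exprhelpers.Run(program, env, logger, debug) call, which is why the Debug-only blocks at each call site disappear. A minimal sketch of such a wrapper, assuming it merely routes debug output through the caller's logger (the real pkg/exprhelpers implementation is not part of this diff and is richer):

    // sketch only: an illustrative stand-in for exprhelpers.Run, not the actual code
    package exprhelpers

    import (
        "github.com/expr-lang/expr"
        "github.com/expr-lang/expr/vm"
        log "github.com/sirupsen/logrus"
    )

    // Run evaluates a compiled program; when debug is set it traces the result
    // on the supplied logger instead of requiring a second "debugger" program.
    func Run(program *vm.Program, env map[string]interface{}, logger *log.Entry, debug bool) (interface{}, error) {
        output, err := expr.Run(program, env)
        if debug {
            // assumed debug behavior: the upstream helper traces evaluation in
            // more detail, here we only log the final outcome
            logger.Debugf("eval returned %v (err: %v)", output, err)
        }
        return output, err
    }

Keeping one code path whether debugging is on or off is what lets the patch delete the per-call-site Debug branches without losing the debug output.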
if err != nil { g.logger.Errorf("Unable to compile filter : %v", err) return nil, fmt.Errorf("unable to compile filter : %v", err) @@ -36,10 +36,10 @@ func NewOverflowFilter(g *BucketFactory) (*OverflowFilter, error) { return &u, nil } -func (u *OverflowFilter) OnBucketOverflow(Bucket *BucketFactory) func(*Leaky, types.RuntimeAlert, *Queue) (types.RuntimeAlert, *Queue) { - return func(l *Leaky, s types.RuntimeAlert, q *Queue) (types.RuntimeAlert, *Queue) { - el, err := expr.Run(u.FilterRuntime, map[string]interface{}{ - "queue": q, "signal": s, "leaky": l}) +func (u *OverflowFilter) OnBucketOverflow(Bucket *BucketFactory) func(*Leaky, types.RuntimeAlert, *types.Queue) (types.RuntimeAlert, *types.Queue) { + return func(l *Leaky, s types.RuntimeAlert, q *types.Queue) (types.RuntimeAlert, *types.Queue) { + el, err := exprhelpers.Run(u.FilterRuntime, map[string]interface{}{ + "queue": q, "signal": s, "leaky": l}, l.logger, Bucket.Debug) if err != nil { l.logger.Errorf("Failed running overflow filter: %s", err) return s, q diff --git a/pkg/leakybucket/overflows.go b/pkg/leakybucket/overflows.go index 5f5f0484bec..39b0e6a0ec4 100644 --- a/pkg/leakybucket/overflows.go +++ b/pkg/leakybucket/overflows.go @@ -1,120 +1,149 @@ package leakybucket import ( + "errors" "fmt" "net" "sort" "strconv" - "github.com/antonmedv/expr" "github.com/davecgh/go-spew/spew" "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/alertcontext" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) // SourceFromEvent extracts and formats a valid models.Source object from an Event func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, error) { - srcs := make(map[string]models.Source) /*if it's already an overflow, we have properly formatted sources. 
we can just twitch them to reflect the requested scope*/ if evt.Type == types.OVFLW { + return overflowEventSources(evt, leaky) + } - for k, v := range evt.Overflow.Sources { + return eventSources(evt, leaky) +} - /*the scopes are already similar, nothing to do*/ - if leaky.scopeType.Scope == *v.Scope { - srcs[k] = v - continue - } +func overflowEventSources(evt types.Event, leaky *Leaky) (map[string]models.Source, error) { + srcs := make(map[string]models.Source) - /*The bucket requires a decision on scope Range */ - if leaky.scopeType.Scope == types.Range { - /*the original bucket was target IPs, check that we do have range*/ - if *v.Scope == types.Ip { - src := models.Source{} - src.AsName = v.AsName - src.AsNumber = v.AsNumber - src.Cn = v.Cn - src.Latitude = v.Latitude - src.Longitude = v.Longitude - src.Range = v.Range - src.Value = new(string) - src.Scope = new(string) - *src.Scope = leaky.scopeType.Scope - *src.Value = "" - if v.Range != "" { - *src.Value = v.Range - } - if leaky.scopeType.RunTimeFilter != nil { - retValue, err := expr.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}) - if err != nil { - return srcs, fmt.Errorf("while running scope filter: %w", err) - } - value, ok := retValue.(string) - if !ok { - value = "" - } - src.Value = &value + for k, v := range evt.Overflow.Sources { + /*the scopes are already similar, nothing to do*/ + if leaky.scopeType.Scope == *v.Scope { + srcs[k] = v + continue + } + + /*The bucket requires a decision on scope Range */ + if leaky.scopeType.Scope == types.Range { + /*the original bucket was target IPs, check that we do have range*/ + if *v.Scope == types.Ip { + src := models.Source{} + src.AsName = v.AsName + src.AsNumber = v.AsNumber + src.Cn = v.Cn + src.Latitude = v.Latitude + src.Longitude = v.Longitude + src.Range = v.Range + src.Value = new(string) + src.Scope = new(string) + *src.Scope = leaky.scopeType.Scope + *src.Value = "" + + if v.Range != "" { + *src.Value = v.Range + } + + if leaky.scopeType.RunTimeFilter != nil { + retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) + if err != nil { + return srcs, fmt.Errorf("while running scope filter: %w", err) } - if *src.Value != "" { - srcs[*src.Value] = src - } else { - log.Warningf("bucket %s requires scope Range, but none was provided. It seems that the %s wasn't enriched to include its range.", leaky.Name, *v.Value) + + value, ok := retValue.(string) + if !ok { + value = "" } + + src.Value = &value + } + + if *src.Value != "" { + srcs[*src.Value] = src } else { - log.Warningf("bucket %s requires scope Range, but can't extrapolate from %s (%s)", - leaky.Name, *v.Scope, *v.Value) + log.Warningf("bucket %s requires scope Range, but none was provided. 
It seems that the %s wasn't enriched to include its range.", leaky.Name, *v.Value) } + } else { + log.Warningf("bucket %s requires scope Range, but can't extrapolate from %s (%s)", + leaky.Name, *v.Scope, *v.Value) } } - return srcs, nil } + + return srcs, nil +} + +func eventSources(evt types.Event, leaky *Leaky) (map[string]models.Source, error) { + srcs := make(map[string]models.Source) + src := models.Source{} + switch leaky.scopeType.Scope { case types.Range, types.Ip: v, ok := evt.Meta["source_ip"] if !ok { return srcs, fmt.Errorf("scope is %s but Meta[source_ip] doesn't exist", leaky.scopeType.Scope) } + if net.ParseIP(v) == nil { return srcs, fmt.Errorf("scope is %s but '%s' isn't a valid ip", leaky.scopeType.Scope, v) } + src.IP = v src.Scope = &leaky.scopeType.Scope + if v, ok := evt.Enriched["ASNumber"]; ok { src.AsNumber = v } else if v, ok := evt.Enriched["ASNNumber"]; ok { src.AsNumber = v } + if v, ok := evt.Enriched["IsoCode"]; ok { src.Cn = v } + if v, ok := evt.Enriched["ASNOrg"]; ok { src.AsName = v } + if v, ok := evt.Enriched["Latitude"]; ok { l, err := strconv.ParseFloat(v, 32) if err != nil { log.Warningf("bad latitude %s : %s", v, err) } + src.Latitude = float32(l) } + if v, ok := evt.Enriched["Longitude"]; ok { l, err := strconv.ParseFloat(v, 32) if err != nil { log.Warningf("bad longitude %s : %s", v, err) } + src.Longitude = float32(l) } + if v, ok := evt.Meta["SourceRange"]; ok && v != "" { _, ipNet, err := net.ParseCIDR(v) if err != nil { - return srcs, fmt.Errorf("Declared range %s of %s can't be parsed", v, src.IP) + return srcs, fmt.Errorf("declared range %s of %s can't be parsed", v, src.IP) } + if ipNet != nil { src.Range = ipNet.String() leaky.logger.Tracef("Valid range from %s : %s", src.IP, src.Range) @@ -124,8 +153,9 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e src.Value = &src.IP } else if leaky.scopeType.Scope == types.Range { src.Value = &src.Range + if leaky.scopeType.RunTimeFilter != nil { - retValue, err := expr.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}) + retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) if err != nil { return srcs, fmt.Errorf("while running scope filter: %w", err) } @@ -134,15 +164,18 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e if !ok { value = "" } + src.Value = &value } } + srcs[*src.Value] = src default: if leaky.scopeType.RunTimeFilter == nil { - return srcs, fmt.Errorf("empty scope information") + return srcs, errors.New("empty scope information") } - retValue, err := expr.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}) + + retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) if err != nil { return srcs, fmt.Errorf("while running scope filter: %w", err) } @@ -151,30 +184,34 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e if !ok { value = "" } + src.Value = &value src.Scope = new(string) *src.Scope = leaky.scopeType.Scope srcs[*src.Value] = src } + return srcs, nil } // EventsFromQueue iterates the queue to collect & prepare meta-datas from alert -func EventsFromQueue(queue *Queue) []*models.Event { - +func EventsFromQueue(queue *types.Queue) []*models.Event { events := []*models.Event{} for _, evt := range queue.Queue { if evt.Meta == nil { continue } + meta := models.Meta{} - //we 
want consistence + // we want consistency skeys := make([]string, 0, len(evt.Meta)) for k := range evt.Meta { skeys = append(skeys, k) } + sort.Strings(skeys) + for _, k := range skeys { v := evt.Meta[k] subMeta := models.MetaItems0{Key: k, Value: v} @@ -185,15 +222,16 @@ func EventsFromQueue(queue *types.Queue) []*models.Event { ovflwEvent := models.Event{ Meta: meta, } - //either MarshaledTime is present and is extracted from log + // either MarshaledTime is present and is extracted from log if evt.MarshaledTime != "" { tmpTimeStamp := evt.MarshaledTime ovflwEvent.Timestamp = &tmpTimeStamp - } else if !evt.Time.IsZero() { //or .Time has been set during parse as time.Now().UTC() + } else if !evt.Time.IsZero() { // or .Time has been set during parse as time.Now().UTC() ovflwEvent.Timestamp = new(string) + raw, err := evt.Time.MarshalText() if err != nil { - log.Warningf("while marshaling time '%s' : %s", evt.Time.String(), err) + log.Warningf("while serializing time '%s' : %s", evt.Time.String(), err) } else { *ovflwEvent.Timestamp = string(raw) } @@ -203,14 +241,16 @@ func EventsFromQueue(queue *types.Queue) []*models.Event { events = append(events, &ovflwEvent) } + return events } // alertFormatSource iterates over the queue to collect sources -func alertFormatSource(leaky *Leaky, queue *Queue) (map[string]models.Source, string, error) { - var sources = make(map[string]models.Source) var source_type string + sources := make(map[string]models.Source) + log.Debugf("Formatting (%s) - scope Info : scope_type:%s / scope_filter:%s", leaky.Name, leaky.scopeType.Scope, leaky.scopeType.Filter) for _, evt := range queue.Queue { @@ -218,22 +258,26 @@ func alertFormatSource(leaky *Leaky, queue *Queue) (map[string]models.Source, st if err != nil { return nil, "", fmt.Errorf("while extracting scope from bucket %s: %w", leaky.Name, err) } + for key, src := range srcs { if source_type == types.Undefined { source_type = *src.Scope } + if *src.Scope != source_type { return nil, "", fmt.Errorf("event has multiple source types : %s != %s", *src.Scope, source_type) } + sources[key] = src } } + return sources, source_type, nil } // NewAlert will generate a RuntimeAlert and its APIAlert(s) from a bucket that overflowed -func NewAlert(leaky *Leaky, queue *Queue) (types.RuntimeAlert, error) { +func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { var runtimeAlert types.RuntimeAlert leaky.logger.Tracef("Overflow (start: %s, end: %s)", leaky.First_ts, leaky.Ovflw_ts) @@ -242,12 +286,14 @@ func NewAlert(leaky *Leaky, queue *Queue) (types.RuntimeAlert, error) { */ start_at, err := leaky.First_ts.MarshalText() if err != nil { - log.Warningf("failed to marshal start ts %s : %s", leaky.First_ts.String(), err) + log.Warningf("failed to serialize start ts %s : %s", leaky.First_ts.String(), err) } + stop_at, err := leaky.Ovflw_ts.MarshalText() if err != nil { - log.Warningf("failed to marshal ovflw ts %s : %s", leaky.First_ts.String(), err) + log.Warningf("failed to serialize ovflw ts %s : %s", leaky.Ovflw_ts.String(), err) } + capacity := int32(leaky.Capacity) EventsCount := int32(leaky.Total_count) leakSpeed := leaky.Leakspeed.String() @@ -265,20 +311,22 @@ func NewAlert(leaky *Leaky, queue *Queue) (types.RuntimeAlert, error) { StopAt: &stopAt, Simulated: &leaky.Simulated, } + if leaky.BucketConfig == nil { - return runtimeAlert, fmt.Errorf("leaky.BucketConfig is nil") + return runtimeAlert,
errors.New("leaky.BucketConfig is nil") } - //give information about the bucket + // give information about the bucket runtimeAlert.Mapkey = leaky.Mapkey - //Get the sources from Leaky/Queue + // Get the sources from Leaky/Queue sources, source_scope, err := alertFormatSource(leaky, queue) if err != nil { return runtimeAlert, fmt.Errorf("unable to collect sources from bucket: %w", err) } + runtimeAlert.Sources = sources - //Include source info in format string + // Include source info in format string sourceStr := "UNKNOWN" if len(sources) > 1 { sourceStr = fmt.Sprintf("%d sources", len(sources)) @@ -290,20 +338,23 @@ func NewAlert(leaky *Leaky, queue *Queue) (types.RuntimeAlert, error) { } *apiAlert.Message = fmt.Sprintf("%s %s performed '%s' (%d events over %s) at %s", source_scope, sourceStr, leaky.Name, leaky.Total_count, leaky.Ovflw_ts.Sub(leaky.First_ts), leaky.Last_ts) - //Get the events from Leaky/Queue + // Get the events from Leaky/Queue apiAlert.Events = EventsFromQueue(queue) + var warnings []error + apiAlert.Meta, warnings = alertcontext.EventToContext(leaky.Queue.GetQueue()) for _, w := range warnings { log.Warningf("while extracting context from bucket %s : %s", leaky.Name, w) } - //Loop over the Sources and generate appropriate number of ApiAlerts + // Loop over the Sources and generate appropriate number of ApiAlerts for _, srcValue := range sources { newApiAlert := apiAlert srcCopy := srcValue newApiAlert.Source = &srcCopy - if v, ok := leaky.BucketConfig.Labels["remediation"]; ok && v == true { + + if v, ok := leaky.BucketConfig.Labels["remediation"]; ok && v == true { //nolint:revive newApiAlert.Remediation = true } @@ -312,6 +363,7 @@ func NewAlert(leaky *Leaky, queue *Queue) (types.RuntimeAlert, error) { log.Errorf("->%s", spew.Sdump(newApiAlert)) log.Fatalf("error : %s", err) } + runtimeAlert.APIAlerts = append(runtimeAlert.APIAlerts, newApiAlert) } @@ -322,5 +374,6 @@ func NewAlert(leaky *Leaky, queue *Queue) (types.RuntimeAlert, error) { if leaky.Reprocess { runtimeAlert.Reprocess = true } + return runtimeAlert, nil } diff --git a/pkg/leakybucket/processor.go b/pkg/leakybucket/processor.go index 18dc287d810..81af3000c1c 100644 --- a/pkg/leakybucket/processor.go +++ b/pkg/leakybucket/processor.go @@ -5,7 +5,7 @@ import "github.com/crowdsecurity/crowdsec/pkg/types" type Processor interface { OnBucketInit(Bucket *BucketFactory) error OnBucketPour(Bucket *BucketFactory) func(types.Event, *Leaky) *types.Event - OnBucketOverflow(Bucket *BucketFactory) func(*Leaky, types.RuntimeAlert, *Queue) (types.RuntimeAlert, *Queue) + OnBucketOverflow(Bucket *BucketFactory) func(*Leaky, types.RuntimeAlert, *types.Queue) (types.RuntimeAlert, *types.Queue) AfterBucketPour(Bucket *BucketFactory) func(types.Event, *Leaky) *types.Event } @@ -23,8 +23,8 @@ func (d *DumbProcessor) OnBucketPour(bucketFactory *BucketFactory) func(types.Ev } } -func (d *DumbProcessor) OnBucketOverflow(b *BucketFactory) func(*Leaky, types.RuntimeAlert, *Queue) (types.RuntimeAlert, *Queue) { - return func(leaky *Leaky, alert types.RuntimeAlert, queue *Queue) (types.RuntimeAlert, *Queue) { +func (d *DumbProcessor) OnBucketOverflow(b *BucketFactory) func(*Leaky, types.RuntimeAlert, *types.Queue) (types.RuntimeAlert, *types.Queue) { + return func(leaky *Leaky, alert types.RuntimeAlert, queue *types.Queue) (types.RuntimeAlert, *types.Queue) { return alert, queue } } diff --git a/pkg/leakybucket/reset_filter.go b/pkg/leakybucket/reset_filter.go index 9b64681ab94..452ccc085b1 100644 --- 
a/pkg/leakybucket/reset_filter.go +++ b/pkg/leakybucket/reset_filter.go @@ -3,8 +3,8 @@ package leakybucket import ( "sync" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -19,14 +19,13 @@ import ( // Thus, if the bucket receives a request that matches fetching a static resource (here css), it cancels itself type CancelOnFilter struct { - CancelOnFilter *vm.Program - CancelOnFilterDebug *exprhelpers.ExprDebugger + CancelOnFilter *vm.Program + Debug bool } var cancelExprCacheLock sync.Mutex var cancelExprCache map[string]struct { - CancelOnFilter *vm.Program - CancelOnFilterDebug *exprhelpers.ExprDebugger + CancelOnFilter *vm.Program } func (u *CancelOnFilter) OnBucketPour(bucketFactory *BucketFactory) func(types.Event, *Leaky) *types.Event { @@ -34,15 +33,11 @@ func (u *CancelOnFilter) OnBucketPour(bucketFactory *BucketFactory) func(types.E var condition, ok bool if u.CancelOnFilter != nil { leaky.logger.Tracef("running cancel_on filter") - output, err := expr.Run(u.CancelOnFilter, map[string]interface{}{"evt": &msg}) + output, err := exprhelpers.Run(u.CancelOnFilter, map[string]interface{}{"evt": &msg}, leaky.logger, u.Debug) if err != nil { leaky.logger.Warningf("cancel_on error : %s", err) return &msg } - //only run debugger expression if condition is false - if u.CancelOnFilterDebug != nil { - u.CancelOnFilterDebug.Run(leaky.logger, condition, map[string]interface{}{"evt": &msg}) - } if condition, ok = output.(bool); !ok { leaky.logger.Warningf("cancel_on, unexpected non-bool return : %T", output) return &msg @@ -58,8 +53,8 @@ func (u *CancelOnFilter) OnBucketPour(bucketFactory *BucketFactory) func(types.E } } -func (u *CancelOnFilter) OnBucketOverflow(bucketFactory *BucketFactory) func(*Leaky, types.RuntimeAlert, *Queue) (types.RuntimeAlert, *Queue) { - return func(leaky *Leaky, alert types.RuntimeAlert, queue *Queue) (types.RuntimeAlert, *Queue) { +func (u *CancelOnFilter) OnBucketOverflow(bucketFactory *BucketFactory) func(*Leaky, types.RuntimeAlert, *types.Queue) (types.RuntimeAlert, *types.Queue) { + return func(leaky *Leaky, alert types.RuntimeAlert, queue *types.Queue) (types.RuntimeAlert, *types.Queue) { return alert, queue } } @@ -73,14 +68,12 @@ func (u *CancelOnFilter) AfterBucketPour(bucketFactory *BucketFactory) func(type func (u *CancelOnFilter) OnBucketInit(bucketFactory *BucketFactory) error { var err error var compiledExpr struct { - CancelOnFilter *vm.Program - CancelOnFilterDebug *exprhelpers.ExprDebugger + CancelOnFilter *vm.Program } if cancelExprCache == nil { cancelExprCache = make(map[string]struct { - CancelOnFilter *vm.Program - CancelOnFilterDebug *exprhelpers.ExprDebugger + CancelOnFilter *vm.Program }) } @@ -88,30 +81,23 @@ func (u *CancelOnFilter) OnBucketInit(bucketFactory *BucketFactory) error { if compiled, ok := cancelExprCache[bucketFactory.CancelOnFilter]; ok { cancelExprCacheLock.Unlock() u.CancelOnFilter = compiled.CancelOnFilter - u.CancelOnFilterDebug = compiled.CancelOnFilterDebug return nil - } else { - cancelExprCacheLock.Unlock() - //release the lock during compile + } - compiledExpr.CancelOnFilter, err = expr.Compile(bucketFactory.CancelOnFilter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) 
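The reset_filter.go rewrite in this hunk keeps the same compile-once cache for cancel_on expressions but flattens the if/else into an early return and stores a plain Debug bool instead of a compiled debugger. Reduced to its essentials, the caching pattern looks like the sketch below (compileCached and the package-level names are illustrative, not code from this patch):

    // sketch of the cancel_on compile cache, assuming a map guarded by a sync.Mutex
    package main

    import (
        "sync"

        "github.com/expr-lang/expr"
        "github.com/expr-lang/expr/vm"
    )

    var (
        cacheLock sync.Mutex
        cache     = map[string]*vm.Program{}
    )

    func compileCached(filter string) (*vm.Program, error) {
        cacheLock.Lock()
        if prog, ok := cache[filter]; ok {
            cacheLock.Unlock()
            return prog, nil // hit: buckets sharing a filter reuse one compiled program
        }
        // drop the lock while compiling; two goroutines may compile the same
        // expression concurrently, but the programs are identical so the last
        // write simply wins
        cacheLock.Unlock()

        prog, err := expr.Compile(filter)
        if err != nil {
            return nil, err
        }

        cacheLock.Lock()
        cache[filter] = prog
        cacheLock.Unlock()

        return prog, nil
    }

Releasing the lock during compilation trades a rare duplicate compile for never blocking other buckets on a slow expression, which is the same trade-off the patched code makes.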
- if err != nil { - bucketFactory.logger.Errorf("reset_filter compile error : %s", err) - return err - } - u.CancelOnFilter = compiledExpr.CancelOnFilter - if bucketFactory.Debug { - compiledExpr.CancelOnFilterDebug, err = exprhelpers.NewDebugger(bucketFactory.CancelOnFilter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})..., - ) - if err != nil { - bucketFactory.logger.Errorf("reset_filter debug error : %s", err) - return err - } - u.CancelOnFilterDebug = compiledExpr.CancelOnFilterDebug - } - cancelExprCacheLock.Lock() - cancelExprCache[bucketFactory.CancelOnFilter] = compiledExpr - cancelExprCacheLock.Unlock() + cancelExprCacheLock.Unlock() + //release the lock during compile + + compiledExpr.CancelOnFilter, err = expr.Compile(bucketFactory.CancelOnFilter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) + if err != nil { + bucketFactory.logger.Errorf("reset_filter compile error : %s", err) + return err } - return err + u.CancelOnFilter = compiledExpr.CancelOnFilter + if bucketFactory.Debug { + u.Debug = true + } + cancelExprCacheLock.Lock() + cancelExprCache[bucketFactory.CancelOnFilter] = compiledExpr + cancelExprCacheLock.Unlock() + return nil } diff --git a/pkg/leakybucket/tests/hub/index.json b/pkg/leakybucket/tests/hub/index.json new file mode 100644 index 00000000000..0967ef424bc --- /dev/null +++ b/pkg/leakybucket/tests/hub/index.json @@ -0,0 +1 @@ +{} diff --git a/pkg/leakybucket/timemachine.go b/pkg/leakybucket/timemachine.go index 6e84797d4ee..34073d1cc5c 100644 --- a/pkg/leakybucket/timemachine.go +++ b/pkg/leakybucket/timemachine.go @@ -3,8 +3,9 @@ package leakybucket import ( "time" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/types" ) func TimeMachinePour(l *Leaky, msg types.Event) { @@ -23,7 +24,7 @@ func TimeMachinePour(l *Leaky, msg types.Event) { err = d.UnmarshalText([]byte(msg.MarshaledTime)) if err != nil { - log.Warningf("Failed unmarshaling event time (%s) : %v", msg.MarshaledTime, err) + log.Warningf("Failed to parse event time (%s) : %v", msg.MarshaledTime, err) return } @@ -35,7 +36,7 @@ func TimeMachinePour(l *Leaky, msg types.Event) { } l.Last_ts = d l.mutex.Unlock() - if l.Limiter.AllowN(d, 1) { + if l.Limiter.AllowN(d, 1) || l.conditionalOverflow { l.logger.Tracef("Time-Pouring event %s (tokens:%f)", d, l.Limiter.GetTokensCount()) l.Queue.Add(msg) } else { diff --git a/pkg/leakybucket/trigger.go b/pkg/leakybucket/trigger.go index d50d7ecc732..d13e57856f9 100644 --- a/pkg/leakybucket/trigger.go +++ b/pkg/leakybucket/trigger.go @@ -3,8 +3,9 @@ package leakybucket import ( "time" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/types" ) type Trigger struct { @@ -15,25 +16,31 @@ func (t *Trigger) OnBucketPour(b *BucketFactory) func(types.Event, *Leaky) *type // Pour makes the bucket overflow all the time // TriggerPour unconditionally overflows return func(msg types.Event, l *Leaky) *types.Event { + now := time.Now().UTC() + if l.Mode == types.TIMEMACHINE { var d time.Time + err := d.UnmarshalText([]byte(msg.MarshaledTime)) if err != nil { - log.Warningf("Failed unmarshaling event time (%s) : %v", msg.MarshaledTime, err) - d = time.Now().UTC() + log.Warningf("Failed to parse event time (%s) : %v", msg.MarshaledTime, err) + + d = now } + l.logger.Debugf("yay timemachine overflow time : %s --> %s", d, msg.MarshaledTime) l.Last_ts = d l.First_ts = d 
l.Ovflw_ts = d } else { - l.Last_ts = time.Now().UTC() - l.First_ts = time.Now().UTC() - l.Ovflw_ts = time.Now().UTC() + l.Last_ts = now + l.First_ts = now + l.Ovflw_ts = now } + l.Total_count = 1 - l.logger.Infof("Bucket overflow") + l.logger.Debug("Bucket overflow") l.Queue.Add(msg) l.Out <- l.Queue diff --git a/pkg/leakybucket/uniq.go b/pkg/leakybucket/uniq.go index cb8bf63fe93..0cc0583390b 100644 --- a/pkg/leakybucket/uniq.go +++ b/pkg/leakybucket/uniq.go @@ -3,8 +3,8 @@ package leakybucket import ( "sync" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -39,16 +39,14 @@ func (u *Uniq) OnBucketPour(bucketFactory *BucketFactory) func(types.Event, *Lea leaky.logger.Debugf("Uniq(%s) : ok", element) u.KeyCache[element] = true return &msg - - } else { - leaky.logger.Debugf("Uniq(%s) : ko, discard event", element) - return nil } + leaky.logger.Debugf("Uniq(%s) : ko, discard event", element) + return nil } } -func (u *Uniq) OnBucketOverflow(bucketFactory *BucketFactory) func(*Leaky, types.RuntimeAlert, *Queue) (types.RuntimeAlert, *Queue) { - return func(leaky *Leaky, alert types.RuntimeAlert, queue *Queue) (types.RuntimeAlert, *Queue) { +func (u *Uniq) OnBucketOverflow(bucketFactory *BucketFactory) func(*Leaky, types.RuntimeAlert, *types.Queue) (types.RuntimeAlert, *types.Queue) { + return func(leaky *Leaky, alert types.RuntimeAlert, queue *types.Queue) (types.RuntimeAlert, *types.Queue) { return alert, queue } } diff --git a/pkg/longpollclient/client.go b/pkg/longpollclient/client.go index 587826452db..5a7af0bfa63 100644 --- a/pkg/longpollclient/client.go +++ b/pkg/longpollclient/client.go @@ -2,6 +2,7 @@ package longpollclient import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -45,7 +46,7 @@ type pollResponse struct { ErrorMessage string `json:"error"` } -var errUnauthorized = fmt.Errorf("user is not authorized to use PAPI") +var errUnauthorized = errors.New("user is not authorized to use PAPI") const timeoutMessage = "no events before timeout" @@ -73,11 +74,9 @@ func (c *LongPollClient) doQuery() (*http.Response, error) { } func (c *LongPollClient) poll() error { - logger := c.logger.WithField("method", "poll") resp, err := c.doQuery() - if err != nil { return err } @@ -94,7 +93,7 @@ func (c *LongPollClient) poll() error { logger.Errorf("failed to read response body: %s", err) return err } - logger.Errorf(string(bodyContent)) + logger.Error(string(bodyContent)) return errUnauthorized } return fmt.Errorf("unexpected status code: %d", resp.StatusCode) @@ -112,7 +111,7 @@ func (c *LongPollClient) poll() error { var pollResp pollResponse err = decoder.Decode(&pollResp) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { logger.Debugf("server closed connection") return nil } @@ -121,7 +120,7 @@ func (c *LongPollClient) poll() error { logger.Tracef("got response: %+v", pollResp) - if len(pollResp.ErrorMessage) > 0 { + if pollResp.ErrorMessage != "" { if pollResp.ErrorMessage == timeoutMessage { logger.Debugf("got timeout message") return nil @@ -158,7 +157,7 @@ func (c *LongPollClient) pollEvents() error { err := c.poll() if err != nil { c.logger.Errorf("failed to poll: %s", err) - if err == errUnauthorized { + if errors.Is(err, errUnauthorized) { c.t.Kill(err) close(c.c) return err @@ -193,32 +192,38 @@ func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) { } defer 
resp.Body.Close() decoder := json.NewDecoder(resp.Body) - var pollResp pollResponse - err = decoder.Decode(&pollResp) - if err != nil { - if err == io.EOF { - c.logger.Debugf("server closed connection") - return nil, nil + evts := []Event{} + for { + var pollResp pollResponse + err = decoder.Decode(&pollResp) + if err != nil { + if errors.Is(err, io.EOF) { + c.logger.Debugf("server closed connection") + break + } + log.Errorf("error decoding poll response: %v", err) + break } - return nil, fmt.Errorf("error decoding poll response: %v", err) - } - c.logger.Tracef("got response: %+v", pollResp) + c.logger.Tracef("got response: %+v", pollResp) - if len(pollResp.ErrorMessage) > 0 { - if pollResp.ErrorMessage == timeoutMessage { - c.logger.Debugf("got timeout message") - return nil, nil + if pollResp.ErrorMessage != "" { + if pollResp.ErrorMessage == timeoutMessage { + c.logger.Debugf("got timeout message") + break + } + log.Errorf("longpoll API error message: %s", pollResp.ErrorMessage) + break } - return nil, fmt.Errorf("longpoll API error message: %s", pollResp.ErrorMessage) + evts = append(evts, pollResp.Events...) } - return pollResp.Events, nil + return evts, nil } func NewLongPollClient(config LongPollClientConfig) (*LongPollClient, error) { var logger *log.Entry if config.Url == (url.URL{}) { - return nil, fmt.Errorf("url is required") + return nil, errors.New("url is required") } if config.Logger == nil { logger = log.WithField("component", "longpollclient") diff --git a/pkg/metabase/api.go b/pkg/metabase/api.go index 7235ff7f104..08e10188678 100644 --- a/pkg/metabase/api.go +++ b/pkg/metabase/api.go @@ -6,10 +6,10 @@ import ( "net/http" "time" - "github.com/crowdsecurity/go-cs-lib/pkg/version" - "github.com/dghubble/sling" log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" ) type MBClient struct { @@ -38,7 +38,7 @@ var ( func NewMBClient(url string) (*MBClient, error) { httpClient := &http.Client{Timeout: 20 * time.Second} return &MBClient{ - CTX: sling.New().Client(httpClient).Base(url).Set("User-Agent", fmt.Sprintf("crowdsec/%s", version.String())), + CTX: sling.New().Client(httpClient).Base(url).Set("User-Agent", useragent.Default()), Client: httpClient, }, nil } @@ -79,7 +79,7 @@ func (h *MBClient) Do(method string, route string, body interface{}) (interface{ return Success, Error, err } -// Set set headers as key:value +// Set headers as key:value func (h *MBClient) Set(key string, value string) { h.CTX = h.CTX.Set(key, value) } diff --git a/pkg/metabase/container.go b/pkg/metabase/container.go index d30fed2d5fa..8b3dd4084c0 100644 --- a/pkg/metabase/container.go +++ b/pkg/metabase/container.go @@ -4,7 +4,6 @@ import ( "bufio" "context" "fmt" - "runtime" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" @@ -13,7 +12,7 @@ import ( "github.com/docker/go-connections/nat" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/ptr" ) type Container struct { @@ -93,14 +92,6 @@ func (c *Container) Create() error { Tty: true, Env: env, } - os := runtime.GOOS - switch os { - case "linux": - case "windows", "darwin": - return fmt.Errorf("Mac and Windows are not supported yet") - default: - return fmt.Errorf("OS '%s' is not supported", os) - } log.Infof("creating container '%s'", c.Name) resp, err := c.CLI.ContainerCreate(ctx, dockerConfig, hostConfig, nil, nil, c.Name) @@ -161,15 +152,15 @@ func RemoveContainer(name string) error { return nil } 
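The PullOnce change above deserves a note: the old code decoded a single pollResponse and returned, so any further JSON documents already buffered on the same response body were silently dropped. The new loop drains the stream until EOF and aggregates the events. A self-contained sketch of that decode loop (the types and payload are simplified stand-ins for the client's pollResponse):

    package main

    import (
        "encoding/json"
        "errors"
        "fmt"
        "io"
        "strings"
    )

    type pollResponse struct {
        Events []string `json:"events"`
    }

    // drain decodes back-to-back JSON documents from one reader, collecting
    // events until the stream ends, mirroring the loop added to PullOnce.
    func drain(r io.Reader) []string {
        dec := json.NewDecoder(r)

        var evts []string

        for {
            var resp pollResponse
            if err := dec.Decode(&resp); err != nil {
                if !errors.Is(err, io.EOF) {
                    fmt.Println("decode error:", err)
                }
                break
            }
            evts = append(evts, resp.Events...)
        }

        return evts
    }

    func main() {
        body := `{"events":["a"]}{"events":["b","c"]}`
        fmt.Println(drain(strings.NewReader(body))) // [a b c]
    }

json.Decoder handles concatenated JSON values natively, so no framing beyond the loop itself is needed.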
-func RemoveImageContainer() error { +func RemoveImageContainer(image string) error { cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) if err != nil { return fmt.Errorf("failed to create docker client : %s", err) } ctx := context.Background() - log.Printf("Removing docker image '%s'", metabaseImage) - if _, err := cli.ImageRemove(ctx, metabaseImage, types.ImageRemoveOptions{}); err != nil { - return fmt.Errorf("failed to remove image container %s : %s", metabaseImage, err) + log.Printf("Removing docker image '%s'", image) + if _, err := cli.ImageRemove(ctx, image, types.ImageRemoveOptions{}); err != nil { + return fmt.Errorf("failed to remove image container %s : %s", image, err) } return nil } diff --git a/pkg/metabase/metabase.go b/pkg/metabase/metabase.go index cdbe65ec8df..324a05666a1 100644 --- a/pkg/metabase/metabase.go +++ b/pkg/metabase/metabase.go @@ -9,7 +9,7 @@ import ( "io" "net/http" "os" - "path" + "path/filepath" "runtime" "strings" "time" @@ -38,12 +38,12 @@ type Config struct { Password string `yaml:"password"` DBPath string `yaml:"metabase_db_path"` DockerGroupID string `yaml:"-"` + Image string `yaml:"image"` } var ( metabaseDefaultUser = "crowdsec@crowdsec.net" metabaseDefaultPassword = "!!Cr0wdS3c_M3t4b4s3??" - metabaseImage = "metabase/metabase:v0.41.5" containerSharedFolder = "/metabase-data" metabaseSQLiteDBURL = "https://crowdsec-statics-assets.s3-eu-west-1.amazonaws.com/metabase_sqlite.zip" ) @@ -63,19 +63,19 @@ func TestAvailability() error { } -func (m *Metabase) Init(containerName string) error { +func (m *Metabase) Init(containerName string, image string) error { var err error var DBConnectionURI string var remoteDBAddr string switch m.Config.Database.Type { case "mysql": - return fmt.Errorf("'mysql' is not supported yet for cscli dashboard") + return errors.New("'mysql' is not supported yet for cscli dashboard") //DBConnectionURI = fmt.Sprintf("MB_DB_CONNECTION_URI=mysql://%s:%d/%s?user=%s&password=%s&allowPublicKeyRetrieval=true", remoteDBAddr, m.Config.Database.Port, m.Config.Database.DbName, m.Config.Database.User, m.Config.Database.Password) case "sqlite": m.InternalDBURL = metabaseSQLiteDBURL case "postgresql", "postgres", "pgsql": - return fmt.Errorf("'postgresql' is not supported yet by cscli dashboard") + return errors.New("'postgresql' is not supported yet by cscli dashboard") default: return fmt.Errorf("database '%s' not supported", m.Config.Database.Type) } @@ -88,20 +88,19 @@ func (m *Metabase) Init(containerName string) error { if err != nil { return err } - m.Container, err = NewContainer(m.Config.ListenAddr, m.Config.ListenPort, m.Config.DBPath, containerName, metabaseImage, DBConnectionURI, m.Config.DockerGroupID) + m.Container, err = NewContainer(m.Config.ListenAddr, m.Config.ListenPort, m.Config.DBPath, containerName, image, DBConnectionURI, m.Config.DockerGroupID) if err != nil { return fmt.Errorf("container init: %w", err) } return nil } - func NewMetabase(configPath string, containerName string) (*Metabase, error) { m := &Metabase{} if err := m.LoadConfig(configPath); err != nil { return m, err } - if err := m.Init(containerName); err != nil { + if err := m.Init(containerName, m.Config.Image); err != nil { return m, err } return m, nil @@ -130,14 +129,18 @@ func (m *Metabase) LoadConfig(configPath string) error { if config.ListenURL == "" { return fmt.Errorf("'listen_url' not found in configuration file '%s'", configPath) } - + /* Default image for backporting */ + if config.Image == "" { + 
config.Image = "metabase/metabase:v0.41.5" + log.Warn("Image not found in configuration file, you are using an old dashboard setup (v0.41.5), please remove your dashboard and re-create it to use the latest version.") + } m.Config = config return nil } -func SetupMetabase(dbConfig *csconfig.DatabaseCfg, listenAddr string, listenPort string, username string, password string, mbDBPath string, dockerGroupID string, containerName string) (*Metabase, error) { +func SetupMetabase(dbConfig *csconfig.DatabaseCfg, listenAddr string, listenPort string, username string, password string, mbDBPath string, dockerGroupID string, containerName string, image string) (*Metabase, error) { metabase := &Metabase{ Config: &Config{ Database: dbConfig, @@ -148,9 +151,10 @@ func SetupMetabase(dbConfig *csconfig.DatabaseCfg, listenAddr string, listenPort ListenURL: fmt.Sprintf("http://%s:%s", listenAddr, listenPort), DBPath: mbDBPath, DockerGroupID: dockerGroupID, + Image: image, }, } - if err := metabase.Init(containerName); err != nil { + if err := metabase.Init(containerName, image); err != nil { return nil, fmt.Errorf("metabase setup init: %w", err) } @@ -307,7 +311,7 @@ func (m *Metabase) DumpConfig(path string) error { func (m *Metabase) DownloadDatabase(force bool) error { - metabaseDBSubpath := path.Join(m.Config.DBPath, "metabase.db") + metabaseDBSubpath := filepath.Join(m.Config.DBPath, "metabase.db") _, err := os.Stat(metabaseDBSubpath) if err == nil && !force { log.Printf("%s exists, skip.", metabaseDBSubpath) @@ -379,5 +383,5 @@ func (m *Metabase) ExtractDatabase(buf *bytes.Reader) error { } func RemoveDatabase(dataDir string) error { - return os.RemoveAll(path.Join(dataDir, "metabase.db")) + return os.RemoveAll(filepath.Join(dataDir, "metabase.db")) } diff --git a/pkg/models/add_alerts_request.go b/pkg/models/add_alerts_request.go index fd7246be066..a69934ef770 100644 --- a/pkg/models/add_alerts_request.go +++ b/pkg/models/add_alerts_request.go @@ -54,6 +54,11 @@ func (m AddAlertsRequest) ContextValidate(ctx context.Context, formats strfmt.Re for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/models/alert.go b/pkg/models/alert.go index ec769a1fbb1..895f5ad76e1 100644 --- a/pkg/models/alert.go +++ b/pkg/models/alert.go @@ -399,6 +399,11 @@ func (m *Alert) contextValidateDecisions(ctx context.Context, formats strfmt.Reg for i := 0; i < len(m.Decisions); i++ { if m.Decisions[i] != nil { + + if swag.IsZero(m.Decisions[i]) { // not required + return nil + } + if err := m.Decisions[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("decisions" + "." + strconv.Itoa(i)) @@ -419,6 +424,11 @@ func (m *Alert) contextValidateEvents(ctx context.Context, formats strfmt.Regist for i := 0; i < len(m.Events); i++ { if m.Events[i] != nil { + + if swag.IsZero(m.Events[i]) { // not required + return nil + } + if err := m.Events[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("events" + "." 
+ strconv.Itoa(i)) @@ -469,6 +479,7 @@ func (m *Alert) contextValidateMeta(ctx context.Context, formats strfmt.Registry func (m *Alert) contextValidateSource(ctx context.Context, formats strfmt.Registry) error { if m.Source != nil { + if err := m.Source.ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("source") diff --git a/pkg/models/all_metrics.go b/pkg/models/all_metrics.go new file mode 100644 index 00000000000..5865070e8ef --- /dev/null +++ b/pkg/models/all_metrics.go @@ -0,0 +1,234 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "strconv" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" +) + +// AllMetrics AllMetrics +// +// swagger:model AllMetrics +type AllMetrics struct { + + // lapi + Lapi *LapiMetrics `json:"lapi,omitempty"` + + // log processors metrics + LogProcessors []*LogProcessorsMetrics `json:"log_processors"` + + // remediation components metrics + RemediationComponents []*RemediationComponentsMetrics `json:"remediation_components"` +} + +// Validate validates this all metrics +func (m *AllMetrics) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateLapi(formats); err != nil { + res = append(res, err) + } + + if err := m.validateLogProcessors(formats); err != nil { + res = append(res, err) + } + + if err := m.validateRemediationComponents(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *AllMetrics) validateLapi(formats strfmt.Registry) error { + if swag.IsZero(m.Lapi) { // not required + return nil + } + + if m.Lapi != nil { + if err := m.Lapi.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("lapi") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("lapi") + } + return err + } + } + + return nil +} + +func (m *AllMetrics) validateLogProcessors(formats strfmt.Registry) error { + if swag.IsZero(m.LogProcessors) { // not required + return nil + } + + for i := 0; i < len(m.LogProcessors); i++ { + if swag.IsZero(m.LogProcessors[i]) { // not required + continue + } + + if m.LogProcessors[i] != nil { + if err := m.LogProcessors[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("log_processors" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("log_processors" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *AllMetrics) validateRemediationComponents(formats strfmt.Registry) error { + if swag.IsZero(m.RemediationComponents) { // not required + return nil + } + + for i := 0; i < len(m.RemediationComponents); i++ { + if swag.IsZero(m.RemediationComponents[i]) { // not required + continue + } + + if m.RemediationComponents[i] != nil { + if err := m.RemediationComponents[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("remediation_components" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("remediation_components" + "." 
+ strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +// ContextValidate validate this all metrics based on the context it is used +func (m *AllMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + if err := m.contextValidateLapi(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateLogProcessors(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateRemediationComponents(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *AllMetrics) contextValidateLapi(ctx context.Context, formats strfmt.Registry) error { + + if m.Lapi != nil { + + if swag.IsZero(m.Lapi) { // not required + return nil + } + + if err := m.Lapi.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("lapi") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("lapi") + } + return err + } + } + + return nil +} + +func (m *AllMetrics) contextValidateLogProcessors(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.LogProcessors); i++ { + + if m.LogProcessors[i] != nil { + + if swag.IsZero(m.LogProcessors[i]) { // not required + return nil + } + + if err := m.LogProcessors[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("log_processors" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("log_processors" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *AllMetrics) contextValidateRemediationComponents(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.RemediationComponents); i++ { + + if m.RemediationComponents[i] != nil { + + if swag.IsZero(m.RemediationComponents[i]) { // not required + return nil + } + + if err := m.RemediationComponents[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("remediation_components" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("remediation_components" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +// MarshalBinary interface implementation +func (m *AllMetrics) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *AllMetrics) UnmarshalBinary(b []byte) error { + var res AllMetrics + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/base_metrics.go b/pkg/models/base_metrics.go new file mode 100644 index 00000000000..94691ea233e --- /dev/null +++ b/pkg/models/base_metrics.go @@ -0,0 +1,215 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. 
+// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "strconv" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// BaseMetrics BaseMetrics +// +// swagger:model BaseMetrics +type BaseMetrics struct { + + // feature flags (expected to be empty for remediation components) + FeatureFlags []string `json:"feature_flags"` + + // metrics details + Metrics []*DetailedMetrics `json:"metrics"` + + // os + Os *OSversion `json:"os,omitempty"` + + // UTC timestamp of the startup of the software + // Required: true + UtcStartupTimestamp *int64 `json:"utc_startup_timestamp"` + + // version of the remediation component + // Required: true + // Max Length: 255 + Version *string `json:"version"` +} + +// Validate validates this base metrics +func (m *BaseMetrics) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateMetrics(formats); err != nil { + res = append(res, err) + } + + if err := m.validateOs(formats); err != nil { + res = append(res, err) + } + + if err := m.validateUtcStartupTimestamp(formats); err != nil { + res = append(res, err) + } + + if err := m.validateVersion(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *BaseMetrics) validateMetrics(formats strfmt.Registry) error { + if swag.IsZero(m.Metrics) { // not required + return nil + } + + for i := 0; i < len(m.Metrics); i++ { + if swag.IsZero(m.Metrics[i]) { // not required + continue + } + + if m.Metrics[i] != nil { + if err := m.Metrics[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("metrics" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("metrics" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *BaseMetrics) validateOs(formats strfmt.Registry) error { + if swag.IsZero(m.Os) { // not required + return nil + } + + if m.Os != nil { + if err := m.Os.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("os") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("os") + } + return err + } + } + + return nil +} + +func (m *BaseMetrics) validateUtcStartupTimestamp(formats strfmt.Registry) error { + + if err := validate.Required("utc_startup_timestamp", "body", m.UtcStartupTimestamp); err != nil { + return err + } + + return nil +} + +func (m *BaseMetrics) validateVersion(formats strfmt.Registry) error { + + if err := validate.Required("version", "body", m.Version); err != nil { + return err + } + + if err := validate.MaxLength("version", "body", *m.Version, 255); err != nil { + return err + } + + return nil +} + +// ContextValidate validate this base metrics based on the context it is used +func (m *BaseMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + if err := m.contextValidateMetrics(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateOs(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
+ } + return nil +} + +func (m *BaseMetrics) contextValidateMetrics(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.Metrics); i++ { + + if m.Metrics[i] != nil { + + if swag.IsZero(m.Metrics[i]) { // not required + return nil + } + + if err := m.Metrics[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("metrics" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("metrics" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *BaseMetrics) contextValidateOs(ctx context.Context, formats strfmt.Registry) error { + + if m.Os != nil { + + if swag.IsZero(m.Os) { // not required + return nil + } + + if err := m.Os.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("os") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("os") + } + return err + } + } + + return nil +} + +// MarshalBinary interface implementation +func (m *BaseMetrics) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *BaseMetrics) UnmarshalBinary(b []byte) error { + var res BaseMetrics + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/console_options.go b/pkg/models/console_options.go new file mode 100644 index 00000000000..87983ab1762 --- /dev/null +++ b/pkg/models/console_options.go @@ -0,0 +1,27 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/strfmt" +) + +// ConsoleOptions ConsoleOptions +// +// swagger:model ConsoleOptions +type ConsoleOptions []string + +// Validate validates this console options +func (m ConsoleOptions) Validate(formats strfmt.Registry) error { + return nil +} + +// ContextValidate validates this console options based on context it is used +func (m ConsoleOptions) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} diff --git a/pkg/models/detailed_metrics.go b/pkg/models/detailed_metrics.go new file mode 100644 index 00000000000..9e605ed8c88 --- /dev/null +++ b/pkg/models/detailed_metrics.go @@ -0,0 +1,173 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "strconv" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// DetailedMetrics DetailedMetrics +// +// swagger:model DetailedMetrics +type DetailedMetrics struct { + + // items + // Required: true + Items []*MetricsDetailItem `json:"items"` + + // meta + // Required: true + Meta *MetricsMeta `json:"meta"` +} + +// Validate validates this detailed metrics +func (m *DetailedMetrics) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateItems(formats); err != nil { + res = append(res, err) + } + + if err := m.validateMeta(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
+ } + return nil +} + +func (m *DetailedMetrics) validateItems(formats strfmt.Registry) error { + + if err := validate.Required("items", "body", m.Items); err != nil { + return err + } + + for i := 0; i < len(m.Items); i++ { + if swag.IsZero(m.Items[i]) { // not required + continue + } + + if m.Items[i] != nil { + if err := m.Items[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("items" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("items" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *DetailedMetrics) validateMeta(formats strfmt.Registry) error { + + if err := validate.Required("meta", "body", m.Meta); err != nil { + return err + } + + if m.Meta != nil { + if err := m.Meta.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("meta") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("meta") + } + return err + } + } + + return nil +} + +// ContextValidate validate this detailed metrics based on the context it is used +func (m *DetailedMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + if err := m.contextValidateItems(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateMeta(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *DetailedMetrics) contextValidateItems(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.Items); i++ { + + if m.Items[i] != nil { + + if swag.IsZero(m.Items[i]) { // not required + return nil + } + + if err := m.Items[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("items" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("items" + "." 
+ strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *DetailedMetrics) contextValidateMeta(ctx context.Context, formats strfmt.Registry) error { + + if m.Meta != nil { + + if err := m.Meta.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("meta") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("meta") + } + return err + } + } + + return nil +} + +// MarshalBinary interface implementation +func (m *DetailedMetrics) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *DetailedMetrics) UnmarshalBinary(b []byte) error { + var res DetailedMetrics + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/generate.go b/pkg/models/generate.go new file mode 100644 index 00000000000..502d6f3d2cf --- /dev/null +++ b/pkg/models/generate.go @@ -0,0 +1,4 @@ +package models + +//go:generate go run -mod=mod github.com/go-swagger/go-swagger/cmd/swagger@v0.31.0 generate model --spec=./localapi_swagger.yaml --target=../ + diff --git a/pkg/models/get_alerts_response.go b/pkg/models/get_alerts_response.go index 41b9d5afdbd..d4ea36e02c5 100644 --- a/pkg/models/get_alerts_response.go +++ b/pkg/models/get_alerts_response.go @@ -54,6 +54,11 @@ func (m GetAlertsResponse) ContextValidate(ctx context.Context, formats strfmt.R for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/models/get_decisions_response.go b/pkg/models/get_decisions_response.go index b65b950fc58..19437dc9b38 100644 --- a/pkg/models/get_decisions_response.go +++ b/pkg/models/get_decisions_response.go @@ -54,6 +54,11 @@ func (m GetDecisionsResponse) ContextValidate(ctx context.Context, formats strfm for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/models/helpers.go b/pkg/models/helpers.go index 8c082550d48..5bc3f2a28b3 100644 --- a/pkg/models/helpers.go +++ b/pkg/models/helpers.go @@ -1,27 +1,33 @@ package models -func (a *Alert) HasRemediation() bool { - return true -} +import ( + "fmt" + + "github.com/davecgh/go-spew/spew" + log "github.com/sirupsen/logrus" +) + +const ( + // these are duplicated from pkg/types + // TODO XXX: de-duplicate + Ip = "Ip" + Range = "Range" + CscliImportOrigin = "cscli-import" +) func (a *Alert) GetScope() string { - if a.Source.Scope == nil { - return "" - } - return *a.Source.Scope + return a.Source.GetScope() } func (a *Alert) GetValue() string { - if a.Source.Value == nil { - return "" - } - return *a.Source.Value + return a.Source.GetValue() } func (a *Alert) GetScenario() string { if a.Scenario == nil { return "" } + return *a.Scenario } @@ -29,6 +35,7 @@ func (a *Alert) GetEventsCount() int32 { if a.EventsCount == nil { return 0 } + return *a.EventsCount } @@ -38,6 +45,7 @@ func (e *Event) GetMeta(key string) string { return meta.Value } } + return "" } @@ -47,6 +55,7 @@ func (a *Alert) GetMeta(key string) string { return meta.Value } } + return "" } @@ -54,6 +63,7 @@ func (s Source) GetValue() 
string { if s.Value == nil { return "" } + return *s.Value } @@ -61,6 +71,7 @@ func (s Source) GetScope() string { if s.Scope == nil { return "" } + return *s.Scope } @@ -69,8 +80,88 @@ func (s Source) GetAsNumberName() string { if s.AsNumber != "0" { ret += s.AsNumber } + if s.AsName != "" { ret += " " + s.AsName } + return ret } + +func (s *Source) String() string { + if s == nil || s.Scope == nil || *s.Scope == "" { + return "empty source" + } + + cn := s.Cn + + if s.AsNumber != "" { + cn += "/" + s.AsNumber + } + + if cn != "" { + cn = " (" + cn + ")" + } + + switch *s.Scope { + case Ip: + return "ip " + *s.Value + cn + case Range: + return "range " + *s.Value + cn + default: + return *s.Scope + " " + *s.Value + } +} + +func (a *Alert) FormatAsStrings(machineID string, logger *log.Logger) []string { + src := a.Source.String() + + msg := "empty scenario" + if a.Scenario != nil && *a.Scenario != "" { + msg = *a.Scenario + } else if a.Message != nil && *a.Message != "" { + msg = *a.Message + } + + reason := fmt.Sprintf("%s by %s", msg, src) + + if len(a.Decisions) == 0 { + return []string{fmt.Sprintf("(%s) alert : %s", machineID, reason)} + } + + var retStr []string + + if a.Decisions[0].Origin != nil && *a.Decisions[0].Origin == CscliImportOrigin { + return []string{fmt.Sprintf("(%s) alert : %s", machineID, reason)} + } + + for i, decisionItem := range a.Decisions { + decision := "" + if a.Simulated != nil && *a.Simulated { + decision = "(simulated alert)" + } else if decisionItem.Simulated != nil && *decisionItem.Simulated { + decision = "(simulated decision)" + } + + if logger.GetLevel() >= log.DebugLevel { + /*spew is expensive*/ + logger.Debug(spew.Sdump(decisionItem)) + } + + if len(a.Decisions) > 1 { + reason = fmt.Sprintf("%s for %d/%d decisions", msg, i+1, len(a.Decisions)) + } + + origin := *decisionItem.Origin + if machineID != "" { + origin = machineID + "/" + origin + } + + decision += fmt.Sprintf("%s %s on %s %s", *decisionItem.Duration, + *decisionItem.Type, *decisionItem.Scope, *decisionItem.Value) + retStr = append(retStr, + fmt.Sprintf("(%s) %s : %s", origin, reason, decision)) + } + + return retStr +} diff --git a/pkg/models/hub_item.go b/pkg/models/hub_item.go new file mode 100644 index 00000000000..c2bac3702c2 --- /dev/null +++ b/pkg/models/hub_item.go @@ -0,0 +1,56 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" +) + +// HubItem HubItem +// +// swagger:model HubItem +type HubItem struct { + + // name of the hub item + Name string `json:"name,omitempty"` + + // status of the hub item (official, custom, tainted, etc.) 
+ Status string `json:"status,omitempty"` + + // version of the hub item + Version string `json:"version,omitempty"` +} + +// Validate validates this hub item +func (m *HubItem) Validate(formats strfmt.Registry) error { + return nil +} + +// ContextValidate validates this hub item based on context it is used +func (m *HubItem) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *HubItem) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *HubItem) UnmarshalBinary(b []byte) error { + var res HubItem + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/hub_items.go b/pkg/models/hub_items.go new file mode 100644 index 00000000000..82388d5b97e --- /dev/null +++ b/pkg/models/hub_items.go @@ -0,0 +1,83 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "strconv" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// HubItems HubItems +// +// swagger:model HubItems +type HubItems map[string][]HubItem + +// Validate validates this hub items +func (m HubItems) Validate(formats strfmt.Registry) error { + var res []error + + for k := range m { + + if err := validate.Required(k, "body", m[k]); err != nil { + return err + } + + for i := 0; i < len(m[k]); i++ { + + if err := m[k][i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName(k + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName(k + "." + strconv.Itoa(i)) + } + return err + } + + } + + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +// ContextValidate validate this hub items based on the context it is used +func (m HubItems) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + for k := range m { + + for i := 0; i < len(m[k]); i++ { + + if swag.IsZero(m[k][i]) { // not required + return nil + } + + if err := m[k][i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName(k + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName(k + "." + strconv.Itoa(i)) + } + return err + } + + } + + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} diff --git a/pkg/models/lapi_metrics.go b/pkg/models/lapi_metrics.go new file mode 100644 index 00000000000..b56d92ef1f8 --- /dev/null +++ b/pkg/models/lapi_metrics.go @@ -0,0 +1,157 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. 
+// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" +) + +// LapiMetrics LapiMetrics +// +// swagger:model LapiMetrics +type LapiMetrics struct { + BaseMetrics + + // console options + ConsoleOptions ConsoleOptions `json:"console_options,omitempty"` +} + +// UnmarshalJSON unmarshals this object from a JSON structure +func (m *LapiMetrics) UnmarshalJSON(raw []byte) error { + // AO0 + var aO0 BaseMetrics + if err := swag.ReadJSON(raw, &aO0); err != nil { + return err + } + m.BaseMetrics = aO0 + + // AO1 + var dataAO1 struct { + ConsoleOptions ConsoleOptions `json:"console_options,omitempty"` + } + if err := swag.ReadJSON(raw, &dataAO1); err != nil { + return err + } + + m.ConsoleOptions = dataAO1.ConsoleOptions + + return nil +} + +// MarshalJSON marshals this object to a JSON structure +func (m LapiMetrics) MarshalJSON() ([]byte, error) { + _parts := make([][]byte, 0, 2) + + aO0, err := swag.WriteJSON(m.BaseMetrics) + if err != nil { + return nil, err + } + _parts = append(_parts, aO0) + var dataAO1 struct { + ConsoleOptions ConsoleOptions `json:"console_options,omitempty"` + } + + dataAO1.ConsoleOptions = m.ConsoleOptions + + jsonDataAO1, errAO1 := swag.WriteJSON(dataAO1) + if errAO1 != nil { + return nil, errAO1 + } + _parts = append(_parts, jsonDataAO1) + return swag.ConcatJSON(_parts...), nil +} + +// Validate validates this lapi metrics +func (m *LapiMetrics) Validate(formats strfmt.Registry) error { + var res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.Validate(formats); err != nil { + res = append(res, err) + } + + if err := m.validateConsoleOptions(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *LapiMetrics) validateConsoleOptions(formats strfmt.Registry) error { + + if swag.IsZero(m.ConsoleOptions) { // not required + return nil + } + + if err := m.ConsoleOptions.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("console_options") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("console_options") + } + return err + } + + return nil +} + +// ContextValidate validate this lapi metrics based on the context it is used +func (m *LapiMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.ContextValidate(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateConsoleOptions(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
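+		// review note: failures from the embedded BaseMetrics and from
+		// console_options are accumulated and reported as one
+		// CompositeValidationError rather than failing on the first error.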
+	}
+	return nil
+}
+
+func (m *LapiMetrics) contextValidateConsoleOptions(ctx context.Context, formats strfmt.Registry) error {
+
+	if err := m.ConsoleOptions.ContextValidate(ctx, formats); err != nil {
+		if ve, ok := err.(*errors.Validation); ok {
+			return ve.ValidateName("console_options")
+		} else if ce, ok := err.(*errors.CompositeError); ok {
+			return ce.ValidateName("console_options")
+		}
+		return err
+	}
+
+	return nil
+}
+
+// MarshalBinary interface implementation
+func (m *LapiMetrics) MarshalBinary() ([]byte, error) {
+	if m == nil {
+		return nil, nil
+	}
+	return swag.WriteJSON(m)
+}
+
+// UnmarshalBinary interface implementation
+func (m *LapiMetrics) UnmarshalBinary(b []byte) error {
+	var res LapiMetrics
+	if err := swag.ReadJSON(b, &res); err != nil {
+		return err
+	}
+	*m = res
+	return nil
+}
diff --git a/pkg/models/localapi_swagger.yaml b/pkg/models/localapi_swagger.yaml
index 66132e5e36e..01bbe6f8bde 100644
--- a/pkg/models/localapi_swagger.yaml
+++ b/pkg/models/localapi_swagger.yaml
@@ -26,10 +26,10 @@ produces:
 paths:
   /decisions/stream:
     get:
-      description: Returns a list of new/expired decisions. Intended for bouncers that need to "stream" decisions
+      description: Returns a list of new/expired decisions. Intended for remediation components that need to "stream" decisions
       summary: getDecisionsStream
       tags:
-        - bouncers
+        - Remediation component
       operationId: getDecisionsStream
       deprecated: false
       produces:
@@ -39,7 +39,7 @@ paths:
         - name: startup
           in: query
           required: false
           type: boolean
-          description: 'If true, means that the bouncers is starting and a full list must be provided'
+          description: 'If true, means that the remediation component is starting and a full list must be provided'
         - name: scopes
           in: query
           required: false
@@ -73,10 +73,10 @@ paths:
       security:
         - APIKeyAuthorizer: []
     head:
-      description: Returns a list of new/expired decisions. Intended for bouncers that need to "stream" decisions
+      description: Returns a list of new/expired decisions. Intended for remediation components that need to "stream" decisions
       summary: GetDecisionsStream
       tags:
-        - bouncers
+        - Remediation component
       operationId: headDecisionsStream
       deprecated: false
       produces:
@@ -100,7 +100,7 @@ paths:
       description: Returns information about existing decisions
       summary: getDecisions
       tags:
-        - bouncers
+        - Remediation component
       operationId: getDecisions
       deprecated: false
       produces:
@@ -160,11 +160,13 @@ paths:
         description: "400 response"
         schema:
           $ref: "#/definitions/ErrorResponse"
+      security:
+        - APIKeyAuthorizer: []
     head:
       description: Returns information about existing decisions
       summary: GetDecisions
       tags:
-        - bouncers
+        - Remediation component
       operationId: headDecisions
       deprecated: false
       produces:
@@ -310,6 +312,9 @@ paths:
       '201':
         description: Watcher Created
         headers: {}
+      '202':
+        description: Watcher Validated
+        headers: {}
       '400':
         description: "400 response"
         schema:
@@ -684,6 +689,36 @@ paths:
           $ref: "#/definitions/ErrorResponse"
       security:
         - JWTAuthorizer: []
+  /usage-metrics:
+    post:
+      description: Post usage metrics from a log processor (LP) or a bouncer
+      summary: Send usage metrics
+      tags:
+        - Remediation component
+        - watchers
+      operationId: usage-metrics
+      produces:
+        - application/json
+      parameters:
+        - name: body
+          in: body
+          required: true
+          schema:
+            $ref: '#/definitions/AllMetrics'
+          description: 'All metrics'
+      responses:
+        '200':
+          description: successful operation
+          schema:
+            $ref: '#/definitions/SuccessResponse'
+          headers: {}
+        '400':
+          description: "400 response"
+          schema:
+            $ref: "#/definitions/ErrorResponse"
+      security:
+        - APIKeyAuthorizer: []
+        - JWTAuthorizer: []
 definitions:
   WatcherRegistrationRequest:
     title: WatcherRegistrationRequest
@@ -694,6 +729,10 @@ definitions:
       password:
         type: string
         format: password
+      registration_token:
+        type: string
+        minLength: 32
+        maxLength: 255
     required:
       - machine_id
       - password
@@ -994,6 +1033,193 @@ definitions:
         type: string
       value:
         type: string
+  RemediationComponentsMetrics:
+    title: RemediationComponentsMetrics
+    type: object
+    allOf:
+      - $ref: '#/definitions/BaseMetrics'
+      - properties:
+          type:
+            type: string
+            description: type of the remediation component
+          name:
+            type: string
+            description: name of the remediation component
+          last_pull:
+            type: integer
+            description: last pull date
+  LogProcessorsMetrics:
+    title: LogProcessorsMetrics
+    type: object
+    allOf:
+      - $ref: '#/definitions/BaseMetrics'
+      - properties:
+          hub_items:
+            $ref: '#/definitions/HubItems'
+          datasources:
+            type: object
+            description: Number of datasources per type
+            additionalProperties:
+              type: integer
+          name:
+            type: string
+            description: name of the log processor
+          last_push:
+            type: integer
+            description: last push date
+          last_update:
+            type: integer
+            description: last update date
+        required:
+          - hub_items
+          - datasources
+  LapiMetrics:
+    title: LapiMetrics
+    type: object
+    allOf:
+      - $ref: '#/definitions/BaseMetrics'
+      - properties:
+          console_options:
+            $ref: '#/definitions/ConsoleOptions'
+  AllMetrics:
+    title: AllMetrics
+    type: object
+    properties:
+      remediation_components:
+        type: array
+        items:
+          $ref: '#/definitions/RemediationComponentsMetrics'
+        description: remediation components metrics
+      log_processors:
+        type: array
+        items:
+          $ref: '#/definitions/LogProcessorsMetrics'
+        description: log processors metrics
+      lapi:
+        $ref: '#/definitions/LapiMetrics'
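Review note: to make the new schema concrete, here is a minimal sketch of a /usage-metrics payload built with the generated models. AllMetrics and BaseMetrics themselves are generated elsewhere in this PR, so their Go field names (RemediationComponents, UtcStartupTimestamp, ...) are assumed from the definitions above rather than quoted from this diff; the component name and values are hypothetical.

```go
package main

import (
	"fmt"

	"github.com/crowdsecurity/crowdsec/pkg/models"
	"github.com/go-openapi/strfmt"
	"github.com/go-openapi/swag"
)

func main() {
	payload := &models.AllMetrics{
		RemediationComponents: []*models.RemediationComponentsMetrics{{
			BaseMetrics: models.BaseMetrics{
				Version:             swag.String("v1.6.0"),
				UtcStartupTimestamp: swag.Int64(1716900000),
				Os:                  &models.OSversion{Name: swag.String("ubuntu"), Version: swag.String("24.04")},
				Metrics: []*models.DetailedMetrics{{
					Meta: &models.MetricsMeta{
						UtcNowTimestamp:   swag.Int64(1716900600),
						WindowSizeSeconds: swag.Int64(600),
					},
					Items: []*models.MetricsDetailItem{{
						Name:  swag.String("processed"),
						Unit:  swag.String("request"),
						Value: swag.Float64(42),
					}},
				}},
			},
			Name: "example-bouncer", // hypothetical component name
			Type: "firewall",
		}},
	}

	// Validate applies the required/maxLength rules generated from the spec.
	if err := payload.Validate(strfmt.Default); err != nil {
		fmt.Println("invalid payload:", err)
		return
	}

	body, _ := payload.MarshalBinary() // swag.WriteJSON under the hood
	fmt.Println(string(body))
}
```

A client would POST this body with either an API key or a JWT, matching the two security schemes declared on the operation above.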
+  BaseMetrics:
+    title: BaseMetrics
+    type: object
+    properties:
+      version:
+        type: string
+        description: version of the remediation component
+        maxLength: 255
+      os:
+        $ref: '#/definitions/OSversion'
+      metrics:
+        type: array
+        items:
+          $ref: '#/definitions/DetailedMetrics'
+        description: metrics details
+      feature_flags:
+        type: array
+        items:
+          type: string
+          description: feature flags (expected to be empty for remediation components)
+          maxLength: 255
+      utc_startup_timestamp:
+        type: integer
+        description: UTC timestamp of the startup of the software
+    required:
+      - version
+      - utc_startup_timestamp
+  OSversion:
+    title: OSversion
+    type: object
+    properties:
+      name:
+        type: string
+        description: name of the OS
+        maxLength: 255
+      version:
+        type: string
+        description: version of the OS
+        maxLength: 255
+    required:
+      - name
+      - version
+  DetailedMetrics:
+    type: object
+    title: DetailedMetrics
+    properties:
+      items:
+        type: array
+        items:
+          $ref: '#/definitions/MetricsDetailItem'
+      meta:
+        $ref: '#/definitions/MetricsMeta'
+    required:
+      - meta
+      - items
+  MetricsDetailItem:
+    title: MetricsDetailItem
+    type: object
+    properties:
+      name:
+        type: string
+        description: name of the metric
+        maxLength: 255
+      value:
+        type: number
+        description: value of the metric
+      unit:
+        type: string
+        description: unit of the metric
+        maxLength: 255
+      labels:
+        $ref: '#/definitions/MetricsLabels'
+        description: labels of the metric
+    required:
+      - name
+      - value
+      - unit
+  MetricsMeta:
+    title: MetricsMeta
+    type: object
+    properties:
+      window_size_seconds:
+        type: integer
+        description: Size, in seconds, of the window used to compute the metric
+      utc_now_timestamp:
+        type: integer
+        description: UTC timestamp of the current time
+    required:
+      - window_size_seconds
+      - utc_now_timestamp
+  MetricsLabels:
+    title: MetricsLabels
+    type: object
+    additionalProperties:
+      type: string
+      description: label of the metric
+      maxLength: 255
+  ConsoleOptions:
+    title: ConsoleOptions
+    type: array
+    items:
+      type: string
+      description: enabled console options
+  HubItems:
+    title: HubItems
+    type: object
+    additionalProperties:
+      type: array
+      items:
+        $ref: '#/definitions/HubItem'
+  HubItem:
+    title: HubItem
+    type: object
+    properties:
+      name:
+        type: string
+        description: name of the hub item
+      version:
+        type: string
+        description: version of the hub item
+      status:
+        type: string
+        description: status of the hub item (official, custom, tainted, etc.)
   ErrorResponse:
     type: "object"
     required:
@@ -1007,8 +1233,18 @@ definitions:
       description: "more detail on individual errors"
     title: "error response"
     description: "error response return by the API"
+  SuccessResponse:
+    type: "object"
+    required:
+      - "message"
+    properties:
+      message:
+        type: "string"
+        description: "message"
+    title: "success response"
+    description: "success response returned by the API"
 tags:
-  - name: bouncers
+  - name: Remediation component
     description: 'Operations about decisions : bans, captcha, rate-limit etc.'
   - name: watchers
     description: 'Operations about watchers : cscli & crowdsec'
diff --git a/pkg/models/log_processors_metrics.go b/pkg/models/log_processors_metrics.go
new file mode 100644
index 00000000000..05b688fb994
--- /dev/null
+++ b/pkg/models/log_processors_metrics.go
@@ -0,0 +1,219 @@
+// Code generated by go-swagger; DO NOT EDIT.
+
+package models
+
+// This file was generated by the swagger tool.
+// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// LogProcessorsMetrics LogProcessorsMetrics +// +// swagger:model LogProcessorsMetrics +type LogProcessorsMetrics struct { + BaseMetrics + + // Number of datasources per type + // Required: true + Datasources map[string]int64 `json:"datasources"` + + // hub items + // Required: true + HubItems HubItems `json:"hub_items"` + + // last push date + LastPush int64 `json:"last_push,omitempty"` + + // last update date + LastUpdate int64 `json:"last_update,omitempty"` + + // name of the log processor + Name string `json:"name,omitempty"` +} + +// UnmarshalJSON unmarshals this object from a JSON structure +func (m *LogProcessorsMetrics) UnmarshalJSON(raw []byte) error { + // AO0 + var aO0 BaseMetrics + if err := swag.ReadJSON(raw, &aO0); err != nil { + return err + } + m.BaseMetrics = aO0 + + // AO1 + var dataAO1 struct { + Datasources map[string]int64 `json:"datasources"` + + HubItems HubItems `json:"hub_items"` + + LastPush int64 `json:"last_push,omitempty"` + + LastUpdate int64 `json:"last_update,omitempty"` + + Name string `json:"name,omitempty"` + } + if err := swag.ReadJSON(raw, &dataAO1); err != nil { + return err + } + + m.Datasources = dataAO1.Datasources + + m.HubItems = dataAO1.HubItems + + m.LastPush = dataAO1.LastPush + + m.LastUpdate = dataAO1.LastUpdate + + m.Name = dataAO1.Name + + return nil +} + +// MarshalJSON marshals this object to a JSON structure +func (m LogProcessorsMetrics) MarshalJSON() ([]byte, error) { + _parts := make([][]byte, 0, 2) + + aO0, err := swag.WriteJSON(m.BaseMetrics) + if err != nil { + return nil, err + } + _parts = append(_parts, aO0) + var dataAO1 struct { + Datasources map[string]int64 `json:"datasources"` + + HubItems HubItems `json:"hub_items"` + + LastPush int64 `json:"last_push,omitempty"` + + LastUpdate int64 `json:"last_update,omitempty"` + + Name string `json:"name,omitempty"` + } + + dataAO1.Datasources = m.Datasources + + dataAO1.HubItems = m.HubItems + + dataAO1.LastPush = m.LastPush + + dataAO1.LastUpdate = m.LastUpdate + + dataAO1.Name = m.Name + + jsonDataAO1, errAO1 := swag.WriteJSON(dataAO1) + if errAO1 != nil { + return nil, errAO1 + } + _parts = append(_parts, jsonDataAO1) + return swag.ConcatJSON(_parts...), nil +} + +// Validate validates this log processors metrics +func (m *LogProcessorsMetrics) Validate(formats strfmt.Registry) error { + var res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.Validate(formats); err != nil { + res = append(res, err) + } + + if err := m.validateDatasources(formats); err != nil { + res = append(res, err) + } + + if err := m.validateHubItems(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
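+		// review note: datasources and hub_items are required by the spec;
+		// their presence checks run in the per-field validators above and any
+		// failures are aggregated here.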
+ } + return nil +} + +func (m *LogProcessorsMetrics) validateDatasources(formats strfmt.Registry) error { + + if err := validate.Required("datasources", "body", m.Datasources); err != nil { + return err + } + + return nil +} + +func (m *LogProcessorsMetrics) validateHubItems(formats strfmt.Registry) error { + + if err := validate.Required("hub_items", "body", m.HubItems); err != nil { + return err + } + + if m.HubItems != nil { + if err := m.HubItems.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("hub_items") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("hub_items") + } + return err + } + } + + return nil +} + +// ContextValidate validate this log processors metrics based on the context it is used +func (m *LogProcessorsMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.ContextValidate(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateHubItems(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *LogProcessorsMetrics) contextValidateHubItems(ctx context.Context, formats strfmt.Registry) error { + + if err := m.HubItems.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("hub_items") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("hub_items") + } + return err + } + + return nil +} + +// MarshalBinary interface implementation +func (m *LogProcessorsMetrics) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *LogProcessorsMetrics) UnmarshalBinary(b []byte) error { + var res LogProcessorsMetrics + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/meta.go b/pkg/models/meta.go index 6ad20856d6a..df5ae3c6285 100644 --- a/pkg/models/meta.go +++ b/pkg/models/meta.go @@ -56,6 +56,11 @@ func (m Meta) ContextValidate(ctx context.Context, formats strfmt.Registry) erro for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/models/metrics.go b/pkg/models/metrics.go index 573678d1f84..7fbb91c63e4 100644 --- a/pkg/models/metrics.go +++ b/pkg/models/metrics.go @@ -141,6 +141,11 @@ func (m *Metrics) contextValidateBouncers(ctx context.Context, formats strfmt.Re for i := 0; i < len(m.Bouncers); i++ { if m.Bouncers[i] != nil { + + if swag.IsZero(m.Bouncers[i]) { // not required + return nil + } + if err := m.Bouncers[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("bouncers" + "." + strconv.Itoa(i)) @@ -161,6 +166,11 @@ func (m *Metrics) contextValidateMachines(ctx context.Context, formats strfmt.Re for i := 0; i < len(m.Machines); i++ { if m.Machines[i] != nil { + + if swag.IsZero(m.Machines[i]) { // not required + return nil + } + if err := m.Machines[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("machines" + "." 
+ strconv.Itoa(i)) diff --git a/pkg/models/metrics_detail_item.go b/pkg/models/metrics_detail_item.go new file mode 100644 index 00000000000..bb237884fcf --- /dev/null +++ b/pkg/models/metrics_detail_item.go @@ -0,0 +1,168 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// MetricsDetailItem MetricsDetailItem +// +// swagger:model MetricsDetailItem +type MetricsDetailItem struct { + + // labels of the metric + Labels MetricsLabels `json:"labels,omitempty"` + + // name of the metric + // Required: true + // Max Length: 255 + Name *string `json:"name"` + + // unit of the metric + // Required: true + // Max Length: 255 + Unit *string `json:"unit"` + + // value of the metric + // Required: true + Value *float64 `json:"value"` +} + +// Validate validates this metrics detail item +func (m *MetricsDetailItem) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateLabels(formats); err != nil { + res = append(res, err) + } + + if err := m.validateName(formats); err != nil { + res = append(res, err) + } + + if err := m.validateUnit(formats); err != nil { + res = append(res, err) + } + + if err := m.validateValue(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *MetricsDetailItem) validateLabels(formats strfmt.Registry) error { + if swag.IsZero(m.Labels) { // not required + return nil + } + + if m.Labels != nil { + if err := m.Labels.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("labels") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("labels") + } + return err + } + } + + return nil +} + +func (m *MetricsDetailItem) validateName(formats strfmt.Registry) error { + + if err := validate.Required("name", "body", m.Name); err != nil { + return err + } + + if err := validate.MaxLength("name", "body", *m.Name, 255); err != nil { + return err + } + + return nil +} + +func (m *MetricsDetailItem) validateUnit(formats strfmt.Registry) error { + + if err := validate.Required("unit", "body", m.Unit); err != nil { + return err + } + + if err := validate.MaxLength("unit", "body", *m.Unit, 255); err != nil { + return err + } + + return nil +} + +func (m *MetricsDetailItem) validateValue(formats strfmt.Registry) error { + + if err := validate.Required("value", "body", m.Value); err != nil { + return err + } + + return nil +} + +// ContextValidate validate this metrics detail item based on the context it is used +func (m *MetricsDetailItem) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + if err := m.contextValidateLabels(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
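+		// review note: labels is the only field that needs context validation;
+		// name, unit and value are scalars already covered by Validate.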
+ } + return nil +} + +func (m *MetricsDetailItem) contextValidateLabels(ctx context.Context, formats strfmt.Registry) error { + + if swag.IsZero(m.Labels) { // not required + return nil + } + + if err := m.Labels.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("labels") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("labels") + } + return err + } + + return nil +} + +// MarshalBinary interface implementation +func (m *MetricsDetailItem) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *MetricsDetailItem) UnmarshalBinary(b []byte) error { + var res MetricsDetailItem + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/metrics_labels.go b/pkg/models/metrics_labels.go new file mode 100644 index 00000000000..176a15cce24 --- /dev/null +++ b/pkg/models/metrics_labels.go @@ -0,0 +1,42 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/validate" +) + +// MetricsLabels MetricsLabels +// +// swagger:model MetricsLabels +type MetricsLabels map[string]string + +// Validate validates this metrics labels +func (m MetricsLabels) Validate(formats strfmt.Registry) error { + var res []error + + for k := range m { + + if err := validate.MaxLength(k, "body", m[k], 255); err != nil { + return err + } + + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +// ContextValidate validates this metrics labels based on context it is used +func (m MetricsLabels) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} diff --git a/pkg/models/metrics_meta.go b/pkg/models/metrics_meta.go new file mode 100644 index 00000000000..b021617e4d9 --- /dev/null +++ b/pkg/models/metrics_meta.go @@ -0,0 +1,88 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// MetricsMeta MetricsMeta +// +// swagger:model MetricsMeta +type MetricsMeta struct { + + // UTC timestamp of the current time + // Required: true + UtcNowTimestamp *int64 `json:"utc_now_timestamp"` + + // Size, in seconds, of the window used to compute the metric + // Required: true + WindowSizeSeconds *int64 `json:"window_size_seconds"` +} + +// Validate validates this metrics meta +func (m *MetricsMeta) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateUtcNowTimestamp(formats); err != nil { + res = append(res, err) + } + + if err := m.validateWindowSizeSeconds(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
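+		// review note: only presence is enforced for window_size_seconds and
+		// utc_now_timestamp; the spec does not bound their values.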
+ } + return nil +} + +func (m *MetricsMeta) validateUtcNowTimestamp(formats strfmt.Registry) error { + + if err := validate.Required("utc_now_timestamp", "body", m.UtcNowTimestamp); err != nil { + return err + } + + return nil +} + +func (m *MetricsMeta) validateWindowSizeSeconds(formats strfmt.Registry) error { + + if err := validate.Required("window_size_seconds", "body", m.WindowSizeSeconds); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this metrics meta based on context it is used +func (m *MetricsMeta) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *MetricsMeta) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *MetricsMeta) UnmarshalBinary(b []byte) error { + var res MetricsMeta + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/o_sversion.go b/pkg/models/o_sversion.go new file mode 100644 index 00000000000..8f1f43ea9cc --- /dev/null +++ b/pkg/models/o_sversion.go @@ -0,0 +1,98 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// OSversion OSversion +// +// swagger:model OSversion +type OSversion struct { + + // name of the OS + // Required: true + // Max Length: 255 + Name *string `json:"name"` + + // version of the OS + // Required: true + // Max Length: 255 + Version *string `json:"version"` +} + +// Validate validates this o sversion +func (m *OSversion) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateName(formats); err != nil { + res = append(res, err) + } + + if err := m.validateVersion(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
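+		// review note: name and version are required and capped at 255
+		// characters, mirroring the maxLength constraints in the spec.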
+ } + return nil +} + +func (m *OSversion) validateName(formats strfmt.Registry) error { + + if err := validate.Required("name", "body", m.Name); err != nil { + return err + } + + if err := validate.MaxLength("name", "body", *m.Name, 255); err != nil { + return err + } + + return nil +} + +func (m *OSversion) validateVersion(formats strfmt.Registry) error { + + if err := validate.Required("version", "body", m.Version); err != nil { + return err + } + + if err := validate.MaxLength("version", "body", *m.Version, 255); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this o sversion based on context it is used +func (m *OSversion) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *OSversion) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *OSversion) UnmarshalBinary(b []byte) error { + var res OSversion + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/remediation_components_metrics.go b/pkg/models/remediation_components_metrics.go new file mode 100644 index 00000000000..ba3845d872a --- /dev/null +++ b/pkg/models/remediation_components_metrics.go @@ -0,0 +1,139 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" +) + +// RemediationComponentsMetrics RemediationComponentsMetrics +// +// swagger:model RemediationComponentsMetrics +type RemediationComponentsMetrics struct { + BaseMetrics + + // last pull date + LastPull int64 `json:"last_pull,omitempty"` + + // name of the remediation component + Name string `json:"name,omitempty"` + + // type of the remediation component + Type string `json:"type,omitempty"` +} + +// UnmarshalJSON unmarshals this object from a JSON structure +func (m *RemediationComponentsMetrics) UnmarshalJSON(raw []byte) error { + // AO0 + var aO0 BaseMetrics + if err := swag.ReadJSON(raw, &aO0); err != nil { + return err + } + m.BaseMetrics = aO0 + + // AO1 + var dataAO1 struct { + LastPull int64 `json:"last_pull,omitempty"` + + Name string `json:"name,omitempty"` + + Type string `json:"type,omitempty"` + } + if err := swag.ReadJSON(raw, &dataAO1); err != nil { + return err + } + + m.LastPull = dataAO1.LastPull + + m.Name = dataAO1.Name + + m.Type = dataAO1.Type + + return nil +} + +// MarshalJSON marshals this object to a JSON structure +func (m RemediationComponentsMetrics) MarshalJSON() ([]byte, error) { + _parts := make([][]byte, 0, 2) + + aO0, err := swag.WriteJSON(m.BaseMetrics) + if err != nil { + return nil, err + } + _parts = append(_parts, aO0) + var dataAO1 struct { + LastPull int64 `json:"last_pull,omitempty"` + + Name string `json:"name,omitempty"` + + Type string `json:"type,omitempty"` + } + + dataAO1.LastPull = m.LastPull + + dataAO1.Name = m.Name + + dataAO1.Type = m.Type + + jsonDataAO1, errAO1 := swag.WriteJSON(dataAO1) + if errAO1 != nil { + return nil, errAO1 + } + _parts = append(_parts, jsonDataAO1) + return swag.ConcatJSON(_parts...), nil +} + +// Validate validates this remediation components metrics +func (m *RemediationComponentsMetrics) Validate(formats strfmt.Registry) error { + var 
res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.Validate(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +// ContextValidate validate this remediation components metrics based on the context it is used +func (m *RemediationComponentsMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.ContextValidate(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +// MarshalBinary interface implementation +func (m *RemediationComponentsMetrics) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *RemediationComponentsMetrics) UnmarshalBinary(b []byte) error { + var res RemediationComponentsMetrics + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/success_response.go b/pkg/models/success_response.go new file mode 100644 index 00000000000..e8fc281c090 --- /dev/null +++ b/pkg/models/success_response.go @@ -0,0 +1,73 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// SuccessResponse success response +// +// success response return by the API +// +// swagger:model SuccessResponse +type SuccessResponse struct { + + // message + // Required: true + Message *string `json:"message"` +} + +// Validate validates this success response +func (m *SuccessResponse) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateMessage(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
+ } + return nil +} + +func (m *SuccessResponse) validateMessage(formats strfmt.Registry) error { + + if err := validate.Required("message", "body", m.Message); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this success response based on context it is used +func (m *SuccessResponse) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *SuccessResponse) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *SuccessResponse) UnmarshalBinary(b []byte) error { + var res SuccessResponse + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/watcher_registration_request.go b/pkg/models/watcher_registration_request.go index 8be802ea3e7..673f0d59b9e 100644 --- a/pkg/models/watcher_registration_request.go +++ b/pkg/models/watcher_registration_request.go @@ -27,6 +27,11 @@ type WatcherRegistrationRequest struct { // Required: true // Format: password Password *strfmt.Password `json:"password"` + + // registration token + // Max Length: 255 + // Min Length: 32 + RegistrationToken string `json:"registration_token,omitempty"` } // Validate validates this watcher registration request @@ -41,6 +46,10 @@ func (m *WatcherRegistrationRequest) Validate(formats strfmt.Registry) error { res = append(res, err) } + if err := m.validateRegistrationToken(formats); err != nil { + res = append(res, err) + } + if len(res) > 0 { return errors.CompositeValidationError(res...) } @@ -69,6 +78,22 @@ func (m *WatcherRegistrationRequest) validatePassword(formats strfmt.Registry) e return nil } +func (m *WatcherRegistrationRequest) validateRegistrationToken(formats strfmt.Registry) error { + if swag.IsZero(m.RegistrationToken) { // not required + return nil + } + + if err := validate.MinLength("registration_token", "body", m.RegistrationToken, 32); err != nil { + return err + } + + if err := validate.MaxLength("registration_token", "body", m.RegistrationToken, 255); err != nil { + return err + } + + return nil +} + // ContextValidate validates this watcher registration request based on context it is used func (m *WatcherRegistrationRequest) ContextValidate(ctx context.Context, formats strfmt.Registry) error { return nil diff --git a/pkg/modelscapi/add_signals_request.go b/pkg/modelscapi/add_signals_request.go index 62fe590cb79..7bfe6ae80e0 100644 --- a/pkg/modelscapi/add_signals_request.go +++ b/pkg/modelscapi/add_signals_request.go @@ -56,6 +56,11 @@ func (m AddSignalsRequest) ContextValidate(ctx context.Context, formats strfmt.R for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/add_signals_request_item.go b/pkg/modelscapi/add_signals_request_item.go index f9c865b4c68..5f63b542d5a 100644 --- a/pkg/modelscapi/add_signals_request_item.go +++ b/pkg/modelscapi/add_signals_request_item.go @@ -65,6 +65,9 @@ type AddSignalsRequestItem struct { // stop at // Required: true StopAt *string `json:"stop_at"` + + // UUID of the alert + UUID string `json:"uuid,omitempty"` } // Validate validates this add signals request item @@ -257,6 +260,11 @@ func (m *AddSignalsRequestItem) contextValidateContext(ctx context.Context, form for i 
:= 0; i < len(m.Context); i++ { if m.Context[i] != nil { + + if swag.IsZero(m.Context[i]) { // not required + return nil + } + if err := m.Context[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("context" + "." + strconv.Itoa(i)) @@ -289,6 +297,7 @@ func (m *AddSignalsRequestItem) contextValidateDecisions(ctx context.Context, fo func (m *AddSignalsRequestItem) contextValidateSource(ctx context.Context, formats strfmt.Registry) error { if m.Source != nil { + if err := m.Source.ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("source") diff --git a/pkg/modelscapi/add_signals_request_item_decisions.go b/pkg/modelscapi/add_signals_request_item_decisions.go index 54e123ab3f8..11ed27a496d 100644 --- a/pkg/modelscapi/add_signals_request_item_decisions.go +++ b/pkg/modelscapi/add_signals_request_item_decisions.go @@ -54,6 +54,11 @@ func (m AddSignalsRequestItemDecisions) ContextValidate(ctx context.Context, for for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/add_signals_request_item_decisions_item.go b/pkg/modelscapi/add_signals_request_item_decisions_item.go index 34dfeb5bce5..797c517e33f 100644 --- a/pkg/modelscapi/add_signals_request_item_decisions_item.go +++ b/pkg/modelscapi/add_signals_request_item_decisions_item.go @@ -49,6 +49,9 @@ type AddSignalsRequestItemDecisionsItem struct { // until Until string `json:"until,omitempty"` + // UUID of the decision + UUID string `json:"uuid,omitempty"` + // the value of the decision scope : an IP, a range, a username, etc // Required: true Value *string `json:"value"` diff --git a/pkg/modelscapi/centralapi_swagger.yaml b/pkg/modelscapi/centralapi_swagger.yaml new file mode 100644 index 00000000000..bd695894f2b --- /dev/null +++ b/pkg/modelscapi/centralapi_swagger.yaml @@ -0,0 +1,875 @@ +swagger: "2.0" +info: + description: + "API to manage machines using [crowdsec](https://github.com/crowdsecurity/crowdsec)\ + \ and bouncers.\n" + version: "2023-01-23T11:16:39Z" + title: "prod-capi-v3" + contact: + name: "Crowdsec team" + url: "https://github.com/crowdsecurity/crowdsec" + email: "support@crowdsec.net" +host: "api.crowdsec.net" +basePath: "/v3" +tags: + - name: "watchers" + description: "Operations about watchers: crowdsec & cscli" + - name: "bouncers" + description: "Operations about decisions : bans, captcha, rate-limit etc." 
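Review note: for context, a minimal sketch of consuming a CAPI decisions stream body with the types go-swagger generates from this spec (the generated files live in pkg/modelscapi elsewhere in this PR). The JSON below is a hypothetical payload shaped like GetDecisionsStreamResponse as defined further down, illustrative only.

```go
package main

import (
	"fmt"

	"github.com/crowdsecurity/crowdsec/pkg/modelscapi"
	"github.com/go-openapi/strfmt"
)

func main() {
	// hypothetical stream body matching GetDecisionsStreamResponse
	raw := []byte(`{
		"new": [
			{"scenario": "crowdsecurity/ssh-bf", "scope": "Ip",
			 "decisions": [{"value": "192.0.2.1", "duration": "4h"}]}
		],
		"deleted": [
			{"scope": "Ip", "decisions": ["198.51.100.7"]}
		]
	}`)

	var resp modelscapi.GetDecisionsStreamResponse
	if err := resp.UnmarshalBinary(raw); err != nil {
		panic(err)
	}
	// enforce the required fields declared in the definitions below
	if err := resp.Validate(strfmt.Default); err != nil {
		panic(err)
	}
	fmt.Printf("new: %d, deleted: %d\n", len(resp.New), len(resp.Deleted))
}
```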
+schemes: + - "https" +paths: + /decisions/delete: + post: + tags: + - "watchers" + summary: "delete decisions" + description: "delete provided decisions" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "DecisionsDeleteRequest" + required: true + schema: + $ref: "#/definitions/DecisionsDeleteRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /decisions/stream: + get: + tags: + - "bouncers" + - "watchers" + summary: "returns list of top decisions" + description: "returns list of top decisions to add or delete" + produces: + - "application/json" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/GetDecisionsStreamResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + "404": + description: "404 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + options: + consumes: + - "application/json" + produces: + - "application/json" + responses: + "200": + description: "200 response" + headers: + Access-Control-Allow-Origin: + type: "string" + Access-Control-Allow-Methods: + type: "string" + Access-Control-Allow-Headers: + type: "string" + /decisions/sync: + post: + tags: + - "watchers" + summary: "sync decisions" + description: "sync provided decisions" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "DecisionsSyncRequest" + required: true + schema: + $ref: "#/definitions/DecisionsSyncRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /metrics: + post: + tags: + - "watchers" + summary: "receive metrics about enrolled machines and bouncers in APIL" + description: "receive metrics about enrolled machines and bouncers in APIL" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "MetricsRequest" + required: true + schema: + $ref: "#/definitions/MetricsRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /signals: + post: + tags: + - "watchers" + summary: "Push signals" + description: "to push signals" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "AddSignalsRequest" + required: true + schema: + $ref: "#/definitions/AddSignalsRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /watchers: + post: + tags: + - "watchers" + summary: "Register watcher" + description: "Register a watcher" + consumes: + - "application/json" + produces: + - "application/json" + 
parameters: + - in: "body" + name: "RegisterRequest" + required: true + schema: + $ref: "#/definitions/RegisterRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + /watchers/enroll: + post: + tags: + - "watchers" + summary: "watcher enrollment" + description: "watcher enrollment : enroll watcher to crowdsec backoffice account" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "EnrollRequest" + required: true + schema: + $ref: "#/definitions/EnrollRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + "403": + description: "403 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /watchers/login: + post: + tags: + - "watchers" + summary: "watcher login" + description: "Sign-in to get a valid token" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "LoginRequest" + required: true + schema: + $ref: "#/definitions/LoginRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/LoginResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + "403": + description: "403 response" + schema: + $ref: "#/definitions/ErrorResponse" + /watchers/reset: + post: + tags: + - "watchers" + summary: "Reset Password" + description: "to reset a watcher password" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "ResetPasswordRequest" + required: true + schema: + $ref: "#/definitions/ResetPasswordRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + headers: + Content-type: + type: "string" + Access-Control-Allow-Origin: + type: "string" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + headers: + Content-type: + type: "string" + Access-Control-Allow-Origin: + type: "string" + "403": + description: "403 response" + schema: + $ref: "#/definitions/ErrorResponse" + "404": + description: "404 response" + headers: + Content-type: + type: "string" + Access-Control-Allow-Origin: + type: "string" + options: + consumes: + - "application/json" + produces: + - "application/json" + responses: + "200": + description: "200 response" + headers: + Access-Control-Allow-Origin: + type: "string" + Access-Control-Allow-Methods: + type: "string" + Access-Control-Allow-Headers: + type: "string" +securityDefinitions: + UserPoolAuthorizer: + type: "apiKey" + name: "Authorization" + in: "header" + x-amazon-apigateway-authtype: "cognito_user_pools" +definitions: + DecisionsDeleteRequest: + title: "delete decisions" + type: "array" + description: "delete decision model" + items: + $ref: "#/definitions/DecisionsDeleteRequestItem" + DecisionsSyncRequestItem: + type: "object" + required: + - "message" + - "scenario" + - 
"scenario_hash" + - "scenario_version" + - "source" + - "start_at" + - "stop_at" + properties: + scenario_trust: + type: "string" + scenario_hash: + type: "string" + scenario: + type: "string" + alert_id: + type: "integer" + created_at: + type: "string" + machine_id: + type: "string" + decisions: + $ref: "#/definitions/DecisionsSyncRequestItemDecisions" + source: + $ref: "#/definitions/DecisionsSyncRequestItemSource" + scenario_version: + type: "string" + message: + type: "string" + description: "a human readable message" + start_at: + type: "string" + stop_at: + type: "string" + title: "Signal" + AddSignalsRequestItem: + type: "object" + required: + - "message" + - "scenario" + - "scenario_hash" + - "scenario_version" + - "source" + - "start_at" + - "stop_at" + properties: + created_at: + type: "string" + machine_id: + type: "string" + source: + $ref: "#/definitions/AddSignalsRequestItemSource" + scenario_version: + type: "string" + message: + type: "string" + description: "a human readable message" + uuid: + type: "string" + description: "UUID of the alert" + start_at: + type: "string" + scenario_trust: + type: "string" + scenario_hash: + type: "string" + scenario: + type: "string" + alert_id: + type: "integer" + context: + type: "array" + items: + type: "object" + properties: + value: + type: "string" + key: + type: "string" + decisions: + $ref: "#/definitions/AddSignalsRequestItemDecisions" + stop_at: + type: "string" + title: "Signal" + DecisionsSyncRequest: + title: "sync decisions request" + type: "array" + description: "sync decision model" + items: + $ref: "#/definitions/DecisionsSyncRequestItem" + LoginRequest: + type: "object" + required: + - "machine_id" + - "password" + properties: + password: + type: "string" + description: "Password, should respect the password policy (link to add)" + machine_id: + type: "string" + description: "machine_id is a (username) generated by crowdsec" + minLength: 48 + maxLength: 48 + pattern: "^[a-zA-Z0-9]+$" + scenarios: + type: "array" + description: "all scenarios installed" + items: + type: "string" + title: "login request" + description: "Login request model" + GetDecisionsStreamResponseNewItem: + type: "object" + required: + - "scenario" + - "scope" + - "decisions" + properties: + scenario: + type: "string" + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + decisions: + type: array + items: + type: object + required: + - value + - duration + properties: + duration: + type: "string" + value: + type: "string" + description: + "the value of the decision scope : an IP, a range, a username,\ + \ etc" + title: "New Decisions" + GetDecisionsStreamResponseDeletedItem: + type: object + required: + - scope + - decisions + properties: + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + decisions: + type: array + items: + type: string + BlocklistLink: + type: object + required: + - name + - url + - remediation + - scope + - duration + properties: + name: + type: string + description: "the name of the blocklist" + url: + type: string + description: "the url from which the blocklist content can be downloaded" + remediation: + type: string + description: "the remediation that should be used for the blocklist" + scope: + type: string + description: "the scope of decisions in the blocklist" + duration: + type: string + AddSignalsRequestItemDecisionsItem: + type: "object" + required: + - "duration" + - "id" + - 
"origin" + - "scenario" + - "scope" + - "type" + - "value" + properties: + duration: + type: "string" + uuid: + type: "string" + description: "UUID of the decision" + scenario: + type: "string" + origin: + type: "string" + description: "the origin of the decision : cscli, crowdsec" + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + simulated: + type: "boolean" + until: + type: "string" + id: + type: "integer" + description: "(only relevant for GET ops) the unique id" + type: + type: "string" + description: + "the type of decision, might be 'ban', 'captcha' or something\ + \ custom. Ignored when watcher (cscli/crowdsec) is pushing to APIL." + value: + type: "string" + description: + "the value of the decision scope : an IP, a range, a username,\ + \ etc" + title: "Decision" + EnrollRequest: + type: "object" + required: + - "attachment_key" + properties: + name: + type: "string" + description: "The name that will be display in the console for the instance" + overwrite: + type: "boolean" + description: "To force enroll the instance" + attachment_key: + type: "string" + description: + "attachment_key is generated in your crowdsec backoffice account\ + \ and allows you to enroll your machines to your BO account" + pattern: "^[a-zA-Z0-9]+$" + tags: + type: "array" + description: "Tags to apply on the console for the instance" + items: + type: "string" + title: "enroll request" + description: "enroll request model" + ResetPasswordRequest: + type: "object" + required: + - "machine_id" + - "password" + properties: + password: + type: "string" + description: "Password, should respect the password policy (link to add)" + machine_id: + type: "string" + description: "machine_id is a (username) generated by crowdsec" + minLength: 48 + maxLength: 48 + pattern: "^[a-zA-Z0-9]+$" + title: "resetPassword" + description: "ResetPassword request model" + MetricsRequestBouncersItem: + type: "object" + properties: + last_pull: + type: "string" + description: "last bouncer pull date" + custom_name: + type: "string" + description: "bouncer name" + name: + type: "string" + description: "bouncer type (firewall, php...)" + version: + type: "string" + description: "bouncer version" + title: "MetricsBouncerInfo" + AddSignalsRequestItemSource: + type: "object" + required: + - "scope" + - "value" + properties: + scope: + type: "string" + description: "the scope of a source : ip,range,username,etc" + ip: + type: "string" + description: "provided as a convenience when the source is an IP" + latitude: + type: "number" + format: "float" + as_number: + type: "string" + description: "provided as a convenience when the source is an IP" + range: + type: "string" + description: "provided as a convenience when the source is an IP" + cn: + type: "string" + value: + type: "string" + description: "the value of a source : the ip, the range, the username,etc" + as_name: + type: "string" + description: "provided as a convenience when the source is an IP" + longitude: + type: "number" + format: "float" + title: "Source" + DecisionsSyncRequestItemDecisions: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/DecisionsSyncRequestItemDecisionsItem" + RegisterRequest: + type: "object" + required: + - "machine_id" + - "password" + properties: + password: + type: "string" + description: "Password, should respect the password policy (link to add)" + machine_id: + type: "string" + description: "machine_id is a (username) generated by crowdsec" + 
pattern: "^[a-zA-Z0-9]+$" + title: "register request" + description: "Register request model" + SuccessResponse: + type: "object" + required: + - "message" + properties: + message: + type: "string" + description: "message" + title: "success response" + description: "success response return by the API" + LoginResponse: + type: "object" + properties: + code: + type: "integer" + expire: + type: "string" + token: + type: "string" + title: "login response" + description: "Login request model" + DecisionsSyncRequestItemDecisionsItem: + type: "object" + required: + - "duration" + - "id" + - "origin" + - "scenario" + - "scope" + - "type" + - "value" + properties: + duration: + type: "string" + scenario: + type: "string" + origin: + type: "string" + description: "the origin of the decision : cscli, crowdsec" + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + simulated: + type: "boolean" + until: + type: "string" + id: + type: "integer" + description: "(only relevant for GET ops) the unique id" + type: + type: "string" + description: + "the type of decision, might be 'ban', 'captcha' or something\ + \ custom. Ignored when watcher (cscli/crowdsec) is pushing to APIL." + value: + type: "string" + description: + "the value of the decision scope : an IP, a range, a username,\ + \ etc" + title: "Decision" + GetDecisionsStreamResponse: + type: "object" + properties: + new: + $ref: "#/definitions/GetDecisionsStreamResponseNew" + deleted: + $ref: "#/definitions/GetDecisionsStreamResponseDeleted" + links: + $ref: "#/definitions/GetDecisionsStreamResponseLinks" + title: "get decisions stream response" + description: "get decision response model" + DecisionsSyncRequestItemSource: + type: "object" + required: + - "scope" + - "value" + properties: + scope: + type: "string" + description: "the scope of a source : ip,range,username,etc" + ip: + type: "string" + description: "provided as a convenience when the source is an IP" + latitude: + type: "number" + format: "float" + as_number: + type: "string" + description: "provided as a convenience when the source is an IP" + range: + type: "string" + description: "provided as a convenience when the source is an IP" + cn: + type: "string" + value: + type: "string" + description: "the value of a source : the ip, the range, the username,etc" + as_name: + type: "string" + description: "provided as a convenience when the source is an IP" + longitude: + type: "number" + format: "float" + title: "Source" + AddSignalsRequestItemDecisions: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/AddSignalsRequestItemDecisionsItem" + MetricsRequestMachinesItem: + type: "object" + properties: + last_update: + type: "string" + description: "last agent update date" + name: + type: "string" + description: "agent name" + last_push: + type: "string" + description: "last agent push date" + version: + type: "string" + description: "agent version" + title: "MetricsAgentInfo" + MetricsRequest: + type: "object" + required: + - "bouncers" + - "machines" + properties: + bouncers: + type: "array" + items: + $ref: "#/definitions/MetricsRequestBouncersItem" + machines: + type: "array" + items: + $ref: "#/definitions/MetricsRequestMachinesItem" + title: "metrics" + description: "push metrics model" + ErrorResponse: + type: "object" + required: + - "message" + properties: + message: + type: "string" + description: "Error message" + errors: + type: "string" + description: "more detail on individual errors" + 
title: "error response" + description: "error response return by the API" + AddSignalsRequest: + title: "add signals request" + type: "array" + description: "All signals request model" + items: + $ref: "#/definitions/AddSignalsRequestItem" + DecisionsDeleteRequestItem: + type: "string" + title: "decisionsIDs" + GetDecisionsStreamResponseNew: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/GetDecisionsStreamResponseNewItem" + GetDecisionsStreamResponseDeleted: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/GetDecisionsStreamResponseDeletedItem" + GetDecisionsStreamResponseLinks: + title: "Decisions list" + type: "object" + properties: + blocklists: + type: array + items: + $ref: "#/definitions/BlocklistLink" + diff --git a/pkg/modelscapi/decisions_delete_request.go b/pkg/modelscapi/decisions_delete_request.go index e8718835027..0c93558adf1 100644 --- a/pkg/modelscapi/decisions_delete_request.go +++ b/pkg/modelscapi/decisions_delete_request.go @@ -11,6 +11,7 @@ import ( "github.com/go-openapi/errors" "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" ) // DecisionsDeleteRequest delete decisions @@ -49,6 +50,10 @@ func (m DecisionsDeleteRequest) ContextValidate(ctx context.Context, formats str for i := 0; i < len(m); i++ { + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/decisions_sync_request.go b/pkg/modelscapi/decisions_sync_request.go index e3a95162519..c087d39ff62 100644 --- a/pkg/modelscapi/decisions_sync_request.go +++ b/pkg/modelscapi/decisions_sync_request.go @@ -56,6 +56,11 @@ func (m DecisionsSyncRequest) ContextValidate(ctx context.Context, formats strfm for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/decisions_sync_request_item.go b/pkg/modelscapi/decisions_sync_request_item.go index 5139ea2de4b..460fe4d430e 100644 --- a/pkg/modelscapi/decisions_sync_request_item.go +++ b/pkg/modelscapi/decisions_sync_request_item.go @@ -231,6 +231,7 @@ func (m *DecisionsSyncRequestItem) contextValidateDecisions(ctx context.Context, func (m *DecisionsSyncRequestItem) contextValidateSource(ctx context.Context, formats strfmt.Registry) error { if m.Source != nil { + if err := m.Source.ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("source") diff --git a/pkg/modelscapi/decisions_sync_request_item_decisions.go b/pkg/modelscapi/decisions_sync_request_item_decisions.go index 76316e43c5e..bdc8e77e2b6 100644 --- a/pkg/modelscapi/decisions_sync_request_item_decisions.go +++ b/pkg/modelscapi/decisions_sync_request_item_decisions.go @@ -54,6 +54,11 @@ func (m DecisionsSyncRequestItemDecisions) ContextValidate(ctx context.Context, for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/generate.go b/pkg/modelscapi/generate.go new file mode 100644 index 00000000000..66dc2a34b7e --- /dev/null +++ b/pkg/modelscapi/generate.go @@ -0,0 +1,4 @@ +package 
modelscapi + +//go:generate go run -mod=mod github.com/go-swagger/go-swagger/cmd/swagger@v0.31.0 generate model --spec=./centralapi_swagger.yaml --target=../ --model-package=modelscapi + diff --git a/pkg/modelscapi/get_decisions_stream_response.go b/pkg/modelscapi/get_decisions_stream_response.go index af19b85c4d3..5ebf29c5d93 100644 --- a/pkg/modelscapi/get_decisions_stream_response.go +++ b/pkg/modelscapi/get_decisions_stream_response.go @@ -144,6 +144,11 @@ func (m *GetDecisionsStreamResponse) contextValidateDeleted(ctx context.Context, func (m *GetDecisionsStreamResponse) contextValidateLinks(ctx context.Context, formats strfmt.Registry) error { if m.Links != nil { + + if swag.IsZero(m.Links) { // not required + return nil + } + if err := m.Links.ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("links") diff --git a/pkg/modelscapi/get_decisions_stream_response_deleted.go b/pkg/modelscapi/get_decisions_stream_response_deleted.go index d218bf87e4e..78292860f22 100644 --- a/pkg/modelscapi/get_decisions_stream_response_deleted.go +++ b/pkg/modelscapi/get_decisions_stream_response_deleted.go @@ -54,6 +54,11 @@ func (m GetDecisionsStreamResponseDeleted) ContextValidate(ctx context.Context, for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/get_decisions_stream_response_links.go b/pkg/modelscapi/get_decisions_stream_response_links.go index 85cc9af9b48..6b9054574f1 100644 --- a/pkg/modelscapi/get_decisions_stream_response_links.go +++ b/pkg/modelscapi/get_decisions_stream_response_links.go @@ -82,6 +82,11 @@ func (m *GetDecisionsStreamResponseLinks) contextValidateBlocklists(ctx context. for i := 0; i < len(m.Blocklists); i++ { if m.Blocklists[i] != nil { + + if swag.IsZero(m.Blocklists[i]) { // not required + return nil + } + if err := m.Blocklists[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("blocklists" + "." + strconv.Itoa(i)) diff --git a/pkg/modelscapi/get_decisions_stream_response_new.go b/pkg/modelscapi/get_decisions_stream_response_new.go index e9525bf6fa7..8e09f1b20e7 100644 --- a/pkg/modelscapi/get_decisions_stream_response_new.go +++ b/pkg/modelscapi/get_decisions_stream_response_new.go @@ -54,6 +54,11 @@ func (m GetDecisionsStreamResponseNew) ContextValidate(ctx context.Context, form for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/get_decisions_stream_response_new_item.go b/pkg/modelscapi/get_decisions_stream_response_new_item.go index a3592d0ab61..77cc06732ce 100644 --- a/pkg/modelscapi/get_decisions_stream_response_new_item.go +++ b/pkg/modelscapi/get_decisions_stream_response_new_item.go @@ -119,6 +119,11 @@ func (m *GetDecisionsStreamResponseNewItem) contextValidateDecisions(ctx context for i := 0; i < len(m.Decisions); i++ { if m.Decisions[i] != nil { + + if swag.IsZero(m.Decisions[i]) { // not required + return nil + } + if err := m.Decisions[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("decisions" + "." 
+ strconv.Itoa(i)) diff --git a/pkg/modelscapi/metrics_request.go b/pkg/modelscapi/metrics_request.go index d5b7d058fc1..5d663cf1750 100644 --- a/pkg/modelscapi/metrics_request.go +++ b/pkg/modelscapi/metrics_request.go @@ -126,6 +126,11 @@ func (m *MetricsRequest) contextValidateBouncers(ctx context.Context, formats st for i := 0; i < len(m.Bouncers); i++ { if m.Bouncers[i] != nil { + + if swag.IsZero(m.Bouncers[i]) { // not required + return nil + } + if err := m.Bouncers[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("bouncers" + "." + strconv.Itoa(i)) @@ -146,6 +151,11 @@ func (m *MetricsRequest) contextValidateMachines(ctx context.Context, formats st for i := 0; i < len(m.Machines); i++ { if m.Machines[i] != nil { + + if swag.IsZero(m.Machines[i]) { // not required + return nil + } + if err := m.Machines[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("machines" + "." + strconv.Itoa(i)) diff --git a/pkg/parser/README.md b/pkg/parser/README.md index 62a56e61820..0fcccc811e4 100644 --- a/pkg/parser/README.md +++ b/pkg/parser/README.md @@ -45,7 +45,7 @@ statics: > `filter: "Line.Src endsWith '/foobar'"` - - *optional* `filter` : an [expression](https://github.com/antonmedv/expr/blob/master/docs/Language-Definition.md) that will be evaluated against the runtime of a line (`Event`) + - *optional* `filter` : an [expression](https://github.com/antonmedv/expr/blob/master/docs/language-definition.md) that will be evaluated against the runtime of a line (`Event`) - if the `filter` is present and returns false, node is not evaluated - if `filter` is absent or present and returns true, node is evaluated diff --git a/pkg/parser/enrich.go b/pkg/parser/enrich.go index 5180b9a5fb9..661410d20d3 100644 --- a/pkg/parser/enrich.go +++ b/pkg/parser/enrich.go @@ -7,7 +7,7 @@ import ( ) /* should be part of a package shared with enrich/geoip.go */ -type EnrichFunc func(string, *types.Event, interface{}, *log.Entry) (map[string]string, error) +type EnrichFunc func(string, *types.Event, *log.Entry) (map[string]string, error) type InitFunc func(map[string]string) (interface{}, error) type EnricherCtx struct { @@ -16,59 +16,42 @@ type EnricherCtx struct { type Enricher struct { Name string - InitFunc InitFunc EnrichFunc EnrichFunc - Ctx interface{} } /* mimic plugin loading */ -func Loadplugin(path string) (EnricherCtx, error) { +func Loadplugin() (EnricherCtx, error) { enricherCtx := EnricherCtx{} enricherCtx.Registered = make(map[string]*Enricher) - enricherConfig := map[string]string{"datadir": path} - EnrichersList := []*Enricher{ { Name: "GeoIpCity", - InitFunc: GeoIPCityInit, EnrichFunc: GeoIpCity, }, { Name: "GeoIpASN", - InitFunc: GeoIPASNInit, EnrichFunc: GeoIpASN, }, { Name: "IpToRange", - InitFunc: IpToRangeInit, EnrichFunc: IpToRange, }, { Name: "reverse_dns", - InitFunc: reverseDNSInit, EnrichFunc: reverse_dns, }, { Name: "ParseDate", - InitFunc: parseDateInit, EnrichFunc: ParseDate, }, { Name: "UnmarshalJSON", - InitFunc: unmarshalInit, EnrichFunc: unmarshalJSON, }, } for _, enricher := range EnrichersList { - log.Debugf("Initiating enricher '%s'", enricher.Name) - pluginCtx, err := enricher.InitFunc(enricherConfig) - if err != nil { - log.Errorf("unable to register plugin '%s': %v", enricher.Name, err) - continue - } - enricher.Ctx = pluginCtx log.Infof("Successfully registered enricher '%s'", enricher.Name) enricherCtx.Registered[enricher.Name] = enricher } diff 
--git a/pkg/parser/enrich_date.go b/pkg/parser/enrich_date.go index 20828af9037..40c8de39da5 100644 --- a/pkg/parser/enrich_date.go +++ b/pkg/parser/enrich_date.go @@ -18,7 +18,7 @@ func parseDateWithFormat(date, format string) (string, time.Time) { } retstr, err := t.MarshalText() if err != nil { - log.Warningf("Failed marshaling '%v'", t) + log.Warningf("Failed to serialize '%v'", t) return "", time.Time{} } return string(retstr), t @@ -56,7 +56,7 @@ func GenDateParse(date string) (string, time.Time) { return "", time.Time{} } -func ParseDate(in string, p *types.Event, x interface{}, plog *log.Entry) (map[string]string, error) { +func ParseDate(in string, p *types.Event, plog *log.Entry) (map[string]string, error) { var ret = make(map[string]string) var strDate string @@ -98,14 +98,10 @@ func ParseDate(in string, p *types.Event, x interface{}, plog *log.Entry) (map[s now := time.Now().UTC() retstr, err := now.MarshalText() if err != nil { - plog.Warning("Failed marshaling current time") + plog.Warning("Failed to serialize current time") return ret, err } ret["MarshaledTime"] = string(retstr) return ret, nil } - -func parseDateInit(cfg map[string]string) (interface{}, error) { - return nil, nil -} diff --git a/pkg/parser/enrich_date_test.go b/pkg/parser/enrich_date_test.go index 0a9ac67f8e9..930633feb35 100644 --- a/pkg/parser/enrich_date_test.go +++ b/pkg/parser/enrich_date_test.go @@ -4,34 +4,33 @@ import ( "testing" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" - "github.com/crowdsecurity/go-cs-lib/pkg/ptr" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/crowdsec/pkg/types" ) func TestDateParse(t *testing.T) { tests := []struct { - name string - evt types.Event - expected_err *error - expected_strTime *string + name string + evt types.Event + expectedErr string + expected string }{ { name: "RFC3339", evt: types.Event{ StrTime: "2019-10-12T07:20:50.52Z", }, - expected_err: nil, - expected_strTime: ptr.Of("2019-10-12T07:20:50.52Z"), + expected: "2019-10-12T07:20:50.52Z", }, { name: "02/Jan/2006:15:04:05 -0700", evt: types.Event{ StrTime: "02/Jan/2006:15:04:05 -0700", }, - expected_err: nil, - expected_strTime: ptr.Of("2006-01-02T15:04:05-07:00"), + expected: "2006-01-02T15:04:05-07:00", }, { name: "Dec 17 08:17:43", @@ -39,31 +38,19 @@ func TestDateParse(t *testing.T) { StrTime: "2011 X 17 zz 08X17X43 oneone Dec", StrTimeFormat: "2006 X 2 zz 15X04X05 oneone Jan", }, - expected_err: nil, - expected_strTime: ptr.Of("2011-12-17T08:17:43Z"), + expected: "2011-12-17T08:17:43Z", }, } - logger := log.WithFields(log.Fields{ - "test": "test", - }) + logger := log.WithField("test", "test") for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { - strTime, err := ParseDate(tt.evt.StrTime, &tt.evt, nil, logger) - if tt.expected_err != nil { - if err != *tt.expected_err { - t.Errorf("%s: expected error %v, got %v", tt.name, tt.expected_err, err) - } - } else if err != nil { - t.Errorf("%s: expected no error, got %v", tt.name, err) - } - if err != nil { + strTime, err := ParseDate(tt.evt.StrTime, &tt.evt, logger) + cstest.RequireErrorContains(t, err, tt.expectedErr) + if tt.expectedErr != "" { return } - if tt.expected_strTime != nil && strTime["MarshaledTime"] != *tt.expected_strTime { - t.Errorf("expected strTime %s, got %s", *tt.expected_strTime, strTime["MarshaledTime"]) - } + assert.Equal(t, tt.expected, strTime["MarshaledTime"]) }) } } diff --git a/pkg/parser/enrich_dns.go b/pkg/parser/enrich_dns.go index 
f622e6c359a..1ff5b0f4f16 100644 --- a/pkg/parser/enrich_dns.go +++ b/pkg/parser/enrich_dns.go @@ -11,7 +11,7 @@ import ( /* All plugins must export a list of function pointers for exported symbols */ //var ExportedFuncs = []string{"reverse_dns"} -func reverse_dns(field string, p *types.Event, ctx interface{}, plog *log.Entry) (map[string]string, error) { +func reverse_dns(field string, p *types.Event, plog *log.Entry) (map[string]string, error) { ret := make(map[string]string) if field == "" { return nil, nil @@ -25,7 +25,3 @@ func reverse_dns(field string, p *types.Event, ctx interface{}, plog *log.Entry) ret["reverse_dns"] = rets[0] return ret, nil } - -func reverseDNSInit(cfg map[string]string) (interface{}, error) { - return nil, nil -} diff --git a/pkg/parser/enrich_geoip.go b/pkg/parser/enrich_geoip.go index 0a263c82793..1756927bc4b 100644 --- a/pkg/parser/enrich_geoip.go +++ b/pkg/parser/enrich_geoip.go @@ -6,53 +6,66 @@ import ( "strconv" "github.com/oschwald/geoip2-golang" - "github.com/oschwald/maxminddb-golang" log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) -func IpToRange(field string, p *types.Event, ctx interface{}, plog *log.Entry) (map[string]string, error) { - var dummy interface{} - ret := make(map[string]string) - +func IpToRange(field string, p *types.Event, plog *log.Entry) (map[string]string, error) { if field == "" { return nil, nil } - ip := net.ParseIP(field) - if ip == nil { - plog.Infof("Can't parse ip %s, no range enrich", field) - return nil, nil - } - net, ok, err := ctx.(*maxminddb.Reader).LookupNetwork(ip, &dummy) + + r, err := exprhelpers.GeoIPRangeEnrich(field) + if err != nil { - plog.Errorf("Failed to fetch network for %s : %v", ip.String(), err) + plog.Errorf("Unable to enrich ip '%s'", field) + return nil, nil //nolint:nilerr + } + + if r == nil { + plog.Debugf("No range found for ip '%s'", field) return nil, nil } + + record, ok := r.(*net.IPNet) + if !ok { - plog.Debugf("Unable to find range of %s", ip.String()) return nil, nil } - ret["SourceRange"] = net.String() + + ret := make(map[string]string) + ret["SourceRange"] = record.String() + return ret, nil } -func GeoIpASN(field string, p *types.Event, ctx interface{}, plog *log.Entry) (map[string]string, error) { - ret := make(map[string]string) +func GeoIpASN(field string, p *types.Event, plog *log.Entry) (map[string]string, error) { if field == "" { return nil, nil } - ip := net.ParseIP(field) - if ip == nil { - plog.Infof("Can't parse ip %s, no ASN enrich", ip) - return nil, nil - } - record, err := ctx.(*geoip2.Reader).ASN(ip) + r, err := exprhelpers.GeoIPASNEnrich(field) + if err != nil { - plog.Errorf("Unable to enrich ip '%s'", field) + plog.Debugf("Unable to enrich ip '%s'", field) return nil, nil //nolint:nilerr } + + if r == nil { + plog.Debugf("No ASN found for ip '%s'", field) + return nil, nil + } + + record, ok := r.(*geoip2.ASN) + + if !ok { + return nil, nil + } + + ret := make(map[string]string) + ret["ASNNumber"] = fmt.Sprintf("%d", record.AutonomousSystemNumber) ret["ASNumber"] = fmt.Sprintf("%d", record.AutonomousSystemNumber) ret["ASNOrg"] = record.AutonomousSystemOrganization @@ -62,21 +75,31 @@ func GeoIpASN(field string, p *types.Event, ctx interface{}, plog *log.Entry) (m return ret, nil } -func GeoIpCity(field string, p *types.Event, ctx interface{}, plog *log.Entry) (map[string]string, error) { - ret := make(map[string]string) +func GeoIpCity(field string, p *types.Event, plog 
*log.Entry) (map[string]string, error) { if field == "" { return nil, nil } - ip := net.ParseIP(field) - if ip == nil { - plog.Infof("Can't parse ip %s, no City enrich", ip) - return nil, nil - } - record, err := ctx.(*geoip2.Reader).City(ip) + + r, err := exprhelpers.GeoIPEnrich(field) + if err != nil { - plog.Debugf("Unable to enrich ip '%s'", ip) + plog.Debugf("Unable to enrich ip '%s'", field) return nil, nil //nolint:nilerr } + + if r == nil { + plog.Debugf("No city found for ip '%s'", field) + return nil, nil + } + + record, ok := r.(*geoip2.City) + + if !ok { + return nil, nil + } + + ret := make(map[string]string) + if record.Country.IsoCode != "" { ret["IsoCode"] = record.Country.IsoCode ret["IsInEU"] = strconv.FormatBool(record.Country.IsInEuropeanUnion) @@ -88,7 +111,7 @@ func GeoIpCity(field string, p *types.Event, ctx interface{}, plog *log.Entry) ( ret["IsInEU"] = strconv.FormatBool(record.RepresentedCountry.IsInEuropeanUnion) } else { ret["IsoCode"] = "" - ret["IsInEU"] = strconv.FormatBool(false) + ret["IsInEU"] = "false" } ret["Latitude"] = fmt.Sprintf("%f", record.Location.Latitude) @@ -98,33 +121,3 @@ func GeoIpCity(field string, p *types.Event, ctx interface{}, plog *log.Entry) ( return ret, nil } - -func GeoIPCityInit(cfg map[string]string) (interface{}, error) { - dbCityReader, err := geoip2.Open(cfg["datadir"] + "/GeoLite2-City.mmdb") - if err != nil { - log.Debugf("couldn't open geoip : %v", err) - return nil, err - } - - return dbCityReader, nil -} - -func GeoIPASNInit(cfg map[string]string) (interface{}, error) { - dbASReader, err := geoip2.Open(cfg["datadir"] + "/GeoLite2-ASN.mmdb") - if err != nil { - log.Debugf("couldn't open geoip : %v", err) - return nil, err - } - - return dbASReader, nil -} - -func IpToRangeInit(cfg map[string]string) (interface{}, error) { - ipToRangeReader, err := maxminddb.Open(cfg["datadir"] + "/GeoLite2-ASN.mmdb") - if err != nil { - log.Debugf("couldn't open geoip : %v", err) - return nil, err - } - - return ipToRangeReader, nil -} diff --git a/pkg/parser/enrich_unmarshal.go b/pkg/parser/enrich_unmarshal.go index dce9c75d466..dbdd9d3f583 100644 --- a/pkg/parser/enrich_unmarshal.go +++ b/pkg/parser/enrich_unmarshal.go @@ -8,16 +8,12 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -func unmarshalJSON(field string, p *types.Event, ctx interface{}, plog *log.Entry) (map[string]string, error) { +func unmarshalJSON(field string, p *types.Event, plog *log.Entry) (map[string]string, error) { err := json.Unmarshal([]byte(p.Line.Raw), &p.Unmarshaled) if err != nil { - plog.Errorf("could not unmarshal JSON: %s", err) + plog.Errorf("could not parse JSON: %s", err) return nil, err } plog.Tracef("unmarshaled JSON: %+v", p.Unmarshaled) return nil, nil } - -func unmarshalInit(cfg map[string]string) (interface{}, error) { - return nil, nil -} diff --git a/pkg/parser/grok_pattern.go b/pkg/parser/grok_pattern.go index 5b3204a4201..9c781d47aa6 100644 --- a/pkg/parser/grok_pattern.go +++ b/pkg/parser/grok_pattern.go @@ -3,7 +3,7 @@ package parser import ( "time" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr/vm" "github.com/crowdsecurity/grokky" ) diff --git a/pkg/parser/node.go b/pkg/parser/node.go index f3341cb2b36..26046ae4fd6 100644 --- a/pkg/parser/node.go +++ b/pkg/parser/node.go @@ -3,13 +3,12 @@ package parser import ( "errors" "fmt" - "net" "strings" "time" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" "github.com/davecgh/go-spew/spew" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" 
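With the InitFunc/Ctx pair gone, an enricher is now just a name plus an EnrichFunc of the new shape, and any shared state (GeoIP readers, etc.) lives behind exprhelpers instead of an opaque context value. A minimal sketch of a custom enricher under the new signature; the function and its output key are toy examples, not enrichers that exist in CrowdSec:

```go
package parser

import (
	"strings"

	log "github.com/sirupsen/logrus"

	"github.com/crowdsecurity/crowdsec/pkg/types"
)

// toUpper is a hypothetical enricher: it derives one Enriched field from the
// input value and, like the built-ins above, returns nil maps on empty input.
func toUpper(field string, p *types.Event, plog *log.Entry) (map[string]string, error) {
	if field == "" {
		return nil, nil
	}

	plog.Tracef("enriching '%s'", field)

	return map[string]string{"Upper": strings.ToUpper(field)}, nil
}

// Registration would mirror Loadplugin above:
//
//	ctx.Registered["ToUpper"] = &Enricher{Name: "ToUpper", EnrichFunc: toUpper}
```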
"github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" yaml "gopkg.in/yaml.v2" @@ -23,70 +22,70 @@ import ( type Node struct { FormatVersion string `yaml:"format"` - //Enable config + runtime debug of node via config o/ + // Enable config + runtime debug of node via config o/ Debug bool `yaml:"debug,omitempty"` - //If enabled, the node (and its child) will report their own statistics + // If enabled, the node (and its child) will report their own statistics Profiling bool `yaml:"profiling,omitempty"` - //Name, author, description and reference(s) for parser pattern + // Name, author, description and reference(s) for parser pattern Name string `yaml:"name,omitempty"` Author string `yaml:"author,omitempty"` Description string `yaml:"description,omitempty"` References []string `yaml:"references,omitempty"` - //if debug is present in the node, keep its specific Logger in runtime structure + // if debug is present in the node, keep its specific Logger in runtime structure Logger *log.Entry `yaml:"-"` - //This is mostly a hack to make writing less repetitive. - //relying on stage, we know which field to parse, and we - //can also promote log to next stage on success + // This is mostly a hack to make writing less repetitive. + // relying on stage, we know which field to parse, and we + // can also promote log to next stage on success Stage string `yaml:"stage,omitempty"` - //OnSuccess allows to tag a node to be able to move log to next stage on success + // OnSuccess allows to tag a node to be able to move log to next stage on success OnSuccess string `yaml:"onsuccess,omitempty"` - rn string //this is only for us in debug, a random generated name for each node - //Filter is executed at runtime (with current log line as context) - //and must succeed or node is exited - Filter string `yaml:"filter,omitempty"` - RunTimeFilter *vm.Program `yaml:"-" json:"-"` //the actual compiled filter - ExprDebugger *exprhelpers.ExprDebugger `yaml:"-" json:"-"` //used to debug expression by printing the content of each variable of the expression - //If node has leafs, execute all of them until one asks for a 'break' + rn string // this is only for us in debug, a random generated name for each node + // Filter is executed at runtime (with current log line as context) + // and must succeed or node is exited + Filter string `yaml:"filter,omitempty"` + RunTimeFilter *vm.Program `yaml:"-" json:"-"` // the actual compiled filter + // If node has leafs, execute all of them until one asks for a 'break' LeavesNodes []Node `yaml:"nodes,omitempty"` - //Flag used to describe when to 'break' or return an 'error' + // Flag used to describe when to 'break' or return an 'error' EnrichFunctions EnricherCtx /* If the node is actually a leaf, it can have : grok, enrich, statics */ - //pattern_syntax are named grok patterns that are re-utilized over several grok patterns + // pattern_syntax are named grok patterns that are re-utilized over several grok patterns SubGroks yaml.MapSlice `yaml:"pattern_syntax,omitempty"` - //Holds a grok pattern + // Holds a grok pattern Grok GrokPattern `yaml:"grok,omitempty"` - //Statics can be present in any type of node and is executed last + // Statics can be present in any type of node and is executed last Statics []ExtraField `yaml:"statics,omitempty"` - //Stash allows to capture data from the log line and store it in an accessible cache + // Stash allows to capture data from the log line and store it in an accessible cache Stash []DataCapture `yaml:"stash,omitempty"` - 
//Whitelists + // Whitelists Whitelist Whitelist `yaml:"whitelist,omitempty"` Data []*types.DataSource `yaml:"data,omitempty"` } -func (n *Node) validate(pctx *UnixParserCtx, ectx EnricherCtx) error { - - //stage is being set automagically +func (n *Node) validate(ectx EnricherCtx) error { + // stage is being set automagically if n.Stage == "" { - return fmt.Errorf("stage needs to be an existing stage") + return errors.New("stage needs to be an existing stage") } /* "" behaves like continue */ if n.OnSuccess != "continue" && n.OnSuccess != "next_stage" && n.OnSuccess != "" { return fmt.Errorf("onsuccess '%s' not continue,next_stage", n.OnSuccess) } + if n.Filter != "" && n.RunTimeFilter == nil { return fmt.Errorf("non-empty filter '%s' was not compiled", n.Filter) } if n.Grok.RunTimeRegexp != nil || n.Grok.TargetField != "" { if n.Grok.TargetField == "" && n.Grok.ExpValue == "" { - return fmt.Errorf("grok requires 'expression' or 'apply_on'") + return errors.New("grok requires 'expression' or 'apply_on'") } + if n.Grok.RegexpName == "" && n.Grok.RegexpValue == "" { - return fmt.Errorf("grok needs 'pattern' or 'name'") + return errors.New("grok needs 'pattern' or 'name'") } } @@ -95,6 +94,7 @@ func (n *Node) validate(pctx *UnixParserCtx, ectx EnricherCtx) error { if static.ExpValue == "" { return fmt.Errorf("static %d : when method is set, expression must be present", idx) } + if _, ok := ectx.Registered[static.Method]; !ok { log.Warningf("the method '%s' doesn't exist or the plugin has not been initialized", static.Method) } @@ -102,6 +102,7 @@ func (n *Node) validate(pctx *UnixParserCtx, ectx EnricherCtx) error { if static.Meta == "" && static.Parsed == "" && static.TargetByName == "" { return fmt.Errorf("static %d : at least one of meta/event/target must be set", idx) } + if static.Value == "" && static.RunTimeValue == nil { return fmt.Errorf("static %d value or expression must be set", idx) } @@ -112,224 +113,224 @@ func (n *Node) validate(pctx *UnixParserCtx, ectx EnricherCtx) error { if stash.Name == "" { return fmt.Errorf("stash %d : name must be set", idx) } + if stash.Value == "" { return fmt.Errorf("stash %s : value expression must be set", stash.Name) } + if stash.Key == "" { return fmt.Errorf("stash %s : key expression must be set", stash.Name) } + if stash.TTL == "" { return fmt.Errorf("stash %s : ttl must be set", stash.Name) } + if stash.Strategy == "" { stash.Strategy = "LRU" } - //should be configurable + // should be configurable if stash.MaxMapSize == 0 { stash.MaxMapSize = 100 } } + return nil } -func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[string]interface{}) (bool, error) { - var NodeState bool - var NodeHasOKGrok bool +func (n *Node) processFilter(cachedExprEnv map[string]interface{}) (bool, error) { clog := n.Logger + if n.RunTimeFilter == nil { + clog.Tracef("Node has not filter, enter") + return true, nil + } - cachedExprEnv := expressionEnv + // Evaluate node's filter + output, err := exprhelpers.Run(n.RunTimeFilter, cachedExprEnv, clog, n.Debug) + if err != nil { + clog.Warningf("failed to run filter : %v", err) + clog.Debugf("Event leaving node : ko") - clog.Tracef("Event entering node") - if n.RunTimeFilter != nil { - //Evaluate node's filter - output, err := expr.Run(n.RunTimeFilter, cachedExprEnv) - if err != nil { - clog.Warningf("failed to run filter : %v", err) - clog.Debugf("Event leaving node : ko") - return false, nil - } + return false, nil + } - switch out := output.(type) { - case bool: - if n.Debug { - 
n.ExprDebugger.Run(clog, out, cachedExprEnv) - } - if !out { - clog.Debugf("Event leaving node : ko (failed filter)") - return false, nil - } - default: - clog.Warningf("Expr '%s' returned non-bool, abort : %T", n.Filter, output) - clog.Debugf("Event leaving node : ko") + switch out := output.(type) { + case bool: + if !out { + clog.Debugf("Event leaving node : ko (failed filter)") return false, nil } - NodeState = true - } else { - clog.Tracef("Node has not filter, enter") - NodeState = true + default: + clog.Warningf("Expr '%s' returned non-bool, abort : %T", n.Filter, output) + clog.Debugf("Event leaving node : ko") + + return false, nil } - if n.Name != "" { - NodesHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name}).Inc() + return true, nil +} + +func (n *Node) processWhitelist(cachedExprEnv map[string]interface{}, p *types.Event) (bool, error) { + var exprErr error + + isWhitelisted := n.CheckIPsWL(p) + if !isWhitelisted { + isWhitelisted, exprErr = n.CheckExprWL(cachedExprEnv, p) } - isWhitelisted := false - hasWhitelist := false - var srcs []net.IP - /*overflow and log don't hold the source ip in the same field, should be changed */ - /* perform whitelist checks for ips, cidr accordingly */ - /* TODO move whitelist elsewhere */ - if p.Type == types.LOG { - if _, ok := p.Meta["source_ip"]; ok { - srcs = append(srcs, net.ParseIP(p.Meta["source_ip"])) - } - } else if p.Type == types.OVFLW { - for k := range p.Overflow.Sources { - srcs = append(srcs, net.ParseIP(k)) - } - } - for _, src := range srcs { - if isWhitelisted { - break - } - for _, v := range n.Whitelist.B_Ips { - if v.Equal(src) { - clog.Debugf("Event from [%s] is whitelisted by IP (%s), reason [%s]", src, v, n.Whitelist.Reason) - isWhitelisted = true - } else { - clog.Tracef("whitelist: %s is not eq [%s]", src, v) - } - hasWhitelist = true - } - for _, v := range n.Whitelist.B_Cidrs { - if v.Contains(src) { - clog.Debugf("Event from [%s] is whitelisted by CIDR (%s), reason [%s]", src, v, n.Whitelist.Reason) - isWhitelisted = true - } else { - clog.Tracef("whitelist: %s not in [%s]", src, v) - } - hasWhitelist = true - } + + if exprErr != nil { + // Previous code returned nil if there was an error, so we keep this behavior + return false, nil //nolint:nilerr } - if isWhitelisted { + if isWhitelisted && !p.Whitelisted { p.Whitelisted = true - } - /* run whitelist expression tests anyway */ - for eidx, e := range n.Whitelist.B_Exprs { - output, err := expr.Run(e.Filter, cachedExprEnv) - if err != nil { - clog.Warningf("failed to run whitelist expr : %v", err) - clog.Debug("Event leaving node : ko") - return false, nil - } - switch out := output.(type) { - case bool: - if n.Debug { - e.ExprDebugger.Run(clog, out, cachedExprEnv) - } - if out { - clog.Debugf("Event is whitelisted by expr, reason [%s]", n.Whitelist.Reason) - p.Whitelisted = true - isWhitelisted = true - } - hasWhitelist = true - default: - log.Errorf("unexpected type %t (%v) while running '%s'", output, output, n.Whitelist.Exprs[eidx]) - } - } - if isWhitelisted { p.WhitelistReason = n.Whitelist.Reason /*huglily wipe the ban order if the event is whitelisted and it's an overflow */ if p.Type == types.OVFLW { /*don't do this at home kids */ ips := []string{} - for _, src := range srcs { - ips = append(ips, src.String()) + for k := range p.Overflow.Sources { + ips = append(ips, k) } - clog.Infof("Ban for %s whitelisted, reason [%s]", strings.Join(ips, ","), n.Whitelist.Reason) + + n.Logger.Infof("Ban for %s whitelisted, reason 
[%s]", strings.Join(ips, ","), n.Whitelist.Reason) + p.Overflow.Whitelisted = true } } - //Process grok if present, should be exclusive with nodes :) + return isWhitelisted, nil +} + +func (n *Node) processGrok(p *types.Event, cachedExprEnv map[string]any) (bool, bool, error) { + // Process grok if present, should be exclusive with nodes :) + clog := n.Logger + var NodeHasOKGrok bool gstr := "" - if n.Grok.RunTimeRegexp != nil { - clog.Tracef("Processing grok pattern : %s : %p", n.Grok.RegexpName, n.Grok.RunTimeRegexp) - //for unparsed, parsed etc. set sensible defaults to reduce user hassle - if n.Grok.TargetField != "" { - //it's a hack to avoid using real reflect - if n.Grok.TargetField == "Line.Raw" { - gstr = p.Line.Raw - } else if val, ok := p.Parsed[n.Grok.TargetField]; ok { - gstr = val - } else { - clog.Debugf("(%s) target field '%s' doesn't exist in %v", n.rn, n.Grok.TargetField, p.Parsed) - NodeState = false - } - } else if n.Grok.RunTimeValue != nil { - output, err := expr.Run(n.Grok.RunTimeValue, cachedExprEnv) - if err != nil { - clog.Warningf("failed to run RunTimeValue : %v", err) - NodeState = false - } - switch out := output.(type) { - case string: - gstr = out - default: - clog.Errorf("unexpected return type for RunTimeValue : %T", output) - } - } - var groklabel string - if n.Grok.RegexpName == "" { - groklabel = fmt.Sprintf("%5.5s...", n.Grok.RegexpValue) - } else { - groklabel = n.Grok.RegexpName - } - grok := n.Grok.RunTimeRegexp.Parse(gstr) - if len(grok) > 0 { - /*tag explicitly that the *current* node had a successful grok pattern. it's important to know success state*/ - NodeHasOKGrok = true - clog.Debugf("+ Grok '%s' returned %d entries to merge in Parsed", groklabel, len(grok)) - //We managed to grok stuff, merged into parse - for k, v := range grok { - clog.Debugf("\t.Parsed['%s'] = '%s'", k, v) - p.Parsed[k] = v - } - // if the grok succeed, process associated statics - err := n.ProcessStatics(n.Grok.Statics, p) - if err != nil { - clog.Errorf("(%s) Failed to process statics : %v", n.rn, err) - return false, err - } + if n.Grok.RunTimeRegexp == nil { + clog.Tracef("! No grok pattern : %p", n.Grok.RunTimeRegexp) + return true, false, nil + } + + clog.Tracef("Processing grok pattern : %s : %p", n.Grok.RegexpName, n.Grok.RunTimeRegexp) + // for unparsed, parsed etc. set sensible defaults to reduce user hassle + if n.Grok.TargetField != "" { + // it's a hack to avoid using real reflect + if n.Grok.TargetField == "Line.Raw" { + gstr = p.Line.Raw + } else if val, ok := p.Parsed[n.Grok.TargetField]; ok { + gstr = val } else { - //grok failed, node failed - clog.Debugf("+ Grok '%s' didn't return data on '%s'", groklabel, gstr) - NodeState = false + clog.Debugf("(%s) target field '%s' doesn't exist in %v", n.rn, n.Grok.TargetField, p.Parsed) + return false, false, nil + } + } else if n.Grok.RunTimeValue != nil { + output, err := exprhelpers.Run(n.Grok.RunTimeValue, cachedExprEnv, clog, n.Debug) + if err != nil { + clog.Warningf("failed to run RunTimeValue : %v", err) + return false, false, nil } + switch out := output.(type) { + case string: + gstr = out + case int: + gstr = fmt.Sprintf("%d", out) + case float64, float32: + gstr = fmt.Sprintf("%f", out) + default: + clog.Errorf("unexpected return type for RunTimeValue : %T", output) + } + } + + var groklabel string + if n.Grok.RegexpName == "" { + groklabel = fmt.Sprintf("%5.5s...", n.Grok.RegexpValue) } else { - clog.Tracef("! 
No grok pattern : %p", n.Grok.RunTimeRegexp) + groklabel = n.Grok.RegexpName + } + + grok := n.Grok.RunTimeRegexp.Parse(gstr) + + if len(grok) == 0 { + // grok failed, node failed + clog.Debugf("+ Grok '%s' didn't return data on '%s'", groklabel, gstr) + return false, false, nil + } + + /*tag explicitly that the *current* node had a successful grok pattern. it's important to know success state*/ + NodeHasOKGrok = true + + clog.Debugf("+ Grok '%s' returned %d entries to merge in Parsed", groklabel, len(grok)) + // We managed to grok stuff, merged into parse + for k, v := range grok { + clog.Debugf("\t.Parsed['%s'] = '%s'", k, v) + p.Parsed[k] = v + } + // if the grok succeed, process associated statics + err := n.ProcessStatics(n.Grok.Statics, p) + if err != nil { + clog.Errorf("(%s) Failed to process statics : %v", n.rn, err) + return false, false, err + } + + return true, NodeHasOKGrok, nil +} + +func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[string]interface{}) (bool, error) { + clog := n.Logger + + cachedExprEnv := expressionEnv + + clog.Tracef("Event entering node") + + NodeState, err := n.processFilter(cachedExprEnv) + if err != nil { + return false, err } - //Process the stash (data collection) if : a grok was present and succeeded, or if there is no grok + if !NodeState { + return false, nil + } + + if n.Name != "" { + NodesHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name}).Inc() + } + + isWhitelisted, err := n.processWhitelist(cachedExprEnv, p) + if err != nil { + return false, err + } + + NodeState, NodeHasOKGrok, err := n.processGrok(p, cachedExprEnv) + if err != nil { + return false, err + } + + // Process the stash (data collection) if : a grok was present and succeeded, or if there is no grok if NodeHasOKGrok || n.Grok.RunTimeRegexp == nil { for idx, stash := range n.Stash { - var value string - var key string + var ( + key string + value string + ) + if stash.ValueExpression == nil { clog.Warningf("Stash %d has no value expression, skipping", idx) continue } + if stash.KeyExpression == nil { clog.Warningf("Stash %d has no key expression, skipping", idx) continue } - //collect the data - output, err := expr.Run(stash.ValueExpression, cachedExprEnv) + // collect the data + output, err := exprhelpers.Run(stash.ValueExpression, cachedExprEnv, clog, n.Debug) if err != nil { clog.Warningf("Error while running stash val expression : %v", err) } - //can we expect anything else than a string ? + // can we expect anything else than a string ? switch output := output.(type) { case string: value = output @@ -338,12 +339,12 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri continue } - //collect the key - output, err = expr.Run(stash.KeyExpression, cachedExprEnv) + // collect the key + output, err = exprhelpers.Run(stash.KeyExpression, cachedExprEnv, clog, n.Debug) if err != nil { clog.Warningf("Error while running stash key expression : %v", err) } - //can we expect anything else than a string ? + // can we expect anything else than a string ? 
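An aside on the expression plumbing shared by this stash block and processFilter above: expressions are compiled once at load time with expr.Compile plus GetExprOptions, then evaluated per event through exprhelpers.Run, which also takes over the debug output that the removed ExprDebugger used to produce. Condensed into one hypothetical helper (a sketch reusing node.go's imports, not a function that exists in the package):

```go
// compileAndRunFilter is a hypothetical condensation of what compile() and
// processFilter() do at either end of a node's lifetime.
func compileAndRunFilter(filter string, evt *types.Event, clog *log.Entry, debug bool) (bool, error) {
	// compile once, with "evt" declared in the expression environment
	prog, err := expr.Compile(filter,
		exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...)
	if err != nil {
		return false, fmt.Errorf("compilation of '%s' failed: %w", filter, err)
	}

	// evaluate per event; the debug flag drives expr debugging inside the helper
	output, err := exprhelpers.Run(prog, map[string]interface{}{"evt": evt}, clog, debug)
	if err != nil {
		return false, err
	}

	// as in processFilter, anything but a bool counts as a failed filter
	ret, ok := output.(bool)
	if !ok {
		clog.Warningf("Expr '%s' returned non-bool, abort : %T", filter, output)
		return false, nil
	}

	return ret, nil
}
```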
switch output := output.(type) { case string: key = output @@ -355,15 +356,18 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri } } - //Iterate on leafs + // Iterate on leafs for _, leaf := range n.LeavesNodes { ret, err := leaf.process(p, ctx, cachedExprEnv) if err != nil { clog.Tracef("\tNode (%s) failed : %v", leaf.rn, err) clog.Debugf("Event leaving node : ko") + return false, err } + clog.Tracef("\tsub-node (%s) ret : %v (strategy:%s)", leaf.rn, ret, n.OnSuccess) + if ret { NodeState = true /* if child is successful, stop processing */ @@ -384,12 +388,14 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri clog.Tracef("State after nodes : %v", NodeState) - //grok or leafs failed, don't process statics + // grok or leafs failed, don't process statics if !NodeState { if n.Name != "" { NodesHitsKo.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name}).Inc() } + clog.Debugf("Event leaving node : ko") + return NodeState, nil } @@ -398,9 +404,10 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri } /* - This is to apply statics when the node *has* whitelists that successfully matched the node. + This is to apply statics when the node either was whitelisted, or is not a whitelist (it has no expr/ips wl) + It is overconvoluted and should be simplified */ - if len(n.Statics) > 0 && (isWhitelisted || !hasWhitelist) { + if len(n.Statics) > 0 && (isWhitelisted || !n.ContainsWLs()) { clog.Debugf("+ Processing %d statics", len(n.Statics)) // if all else is good in whitelist, process node's statics err := n.ProcessStatics(n.Statics, p) @@ -415,9 +422,10 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri if NodeState { clog.Debugf("Event leaving node : ok") log.Tracef("node is successful, check strategy") + if n.OnSuccess == "next_stage" { idx := stageidx(p.Stage, ctx.Stages) - //we're at the last stage + // we're at the last stage if idx+1 == len(ctx.Stages) { clog.Debugf("node reached the last stage : %s", p.Stage) } else { @@ -430,15 +438,16 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri } else { clog.Debugf("Event leaving node : ko") } + clog.Tracef("Node successful, continue") + return NodeState, nil } func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { var err error - var valid bool - valid = false + valid := false dumpr := spew.ConfigState{MaxDepth: 1, DisablePointerAddresses: true} n.rn = seed.Generate() @@ -448,20 +457,17 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { /* if the node has debugging enabled, create a specific logger with debug that will be used only for processing this node ;) */ if n.Debug { - var clog = log.New() + clog := log.New() if err = types.ConfigureLogger(clog); err != nil { - log.Fatalf("While creating bucket-specific logger : %s", err) + return fmt.Errorf("while creating bucket-specific logger: %w", err) } + clog.SetLevel(log.DebugLevel) - n.Logger = clog.WithFields(log.Fields{ - "id": n.rn, - }) + n.Logger = clog.WithField("id", n.rn) n.Logger.Infof("%s has debug enabled", n.Name) } else { /* else bind it to the default one (might find something more elegant here)*/ - n.Logger = log.WithFields(log.Fields{ - "id": n.rn, - }) + n.Logger = log.WithField("id", n.rn) } /* display info about top-level nodes, they should be the only one with explicit stage name ?*/ @@ -469,31 +475,26 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx 
EnricherCtx) error { n.Logger.Tracef("Compiling : %s", dumpr.Sdump(n)) - //compile filter if present + // compile filter if present if n.Filter != "" { n.RunTimeFilter, err = expr.Compile(n.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return fmt.Errorf("compilation of '%s' failed: %v", n.Filter, err) } - - if n.Debug { - n.ExprDebugger, err = exprhelpers.NewDebugger(n.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) - if err != nil { - log.Errorf("unable to build debug filter for '%s' : %s", n.Filter, err) - } - } - } /* handle pattern_syntax and groks */ for _, pattern := range n.SubGroks { n.Logger.Tracef("Adding subpattern '%s' : '%s'", pattern.Key, pattern.Value) + if err = pctx.Grok.Add(pattern.Key.(string), pattern.Value.(string)); err != nil { if errors.Is(err, grokky.ErrAlreadyExist) { n.Logger.Warningf("grok '%s' already registred", pattern.Key) continue } + n.Logger.Errorf("Unable to compile subpattern %s : %v", pattern.Key, err) + return err } } @@ -501,28 +502,36 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { /* load grok by name or compile in-place */ if n.Grok.RegexpName != "" { n.Logger.Tracef("+ Regexp Compilation '%s'", n.Grok.RegexpName) + n.Grok.RunTimeRegexp, err = pctx.Grok.Get(n.Grok.RegexpName) if err != nil { - return fmt.Errorf("unable to find grok '%s' : %v", n.Grok.RegexpName, err) + return fmt.Errorf("unable to find grok '%s': %v", n.Grok.RegexpName, err) } + if n.Grok.RunTimeRegexp == nil { return fmt.Errorf("empty grok '%s'", n.Grok.RegexpName) } + n.Logger.Tracef("%s regexp: %s", n.Grok.RegexpName, n.Grok.RunTimeRegexp.String()) + valid = true } else if n.Grok.RegexpValue != "" { if strings.HasSuffix(n.Grok.RegexpValue, "\n") { n.Logger.Debugf("Beware, pattern ends with \\n : '%s'", n.Grok.RegexpValue) } + n.Grok.RunTimeRegexp, err = pctx.Grok.Compile(n.Grok.RegexpValue) if err != nil { return fmt.Errorf("failed to compile grok '%s': %v", n.Grok.RegexpValue, err) } + if n.Grok.RunTimeRegexp == nil { // We shouldn't be here because compilation succeeded, so regexp shouldn't be nil return fmt.Errorf("grok compilation failure: %s", n.Grok.RegexpValue) } + n.Logger.Tracef("%s regexp : %s", n.Grok.RegexpValue, n.Grok.RunTimeRegexp.String()) + valid = true } @@ -536,7 +545,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { } /* load grok statics */ - //compile expr statics if present + // compile expr statics if present for idx := range n.Grok.Statics { if n.Grok.Statics[idx].ExpValue != "" { n.Grok.Statics[idx].RunTimeValue, err = expr.Compile(n.Grok.Statics[idx].ExpValue, @@ -545,6 +554,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { return err } } + valid = true } @@ -568,7 +578,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { } logLvl := n.Logger.Logger.GetLevel() - //init the cache, does it make sense to create it here just to be sure everything is fine ? + // init the cache, does it make sense to create it here just to be sure everything is fine ? 
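Stepping back from the stash cache for a moment: the grok side of compile() reduces to three grokky calls, Add for pattern_syntax entries, Get or Compile to obtain the runtime pattern, and Parse at process time. Roughly, as a sketch (the NewBase constructor and the base pattern names are assumptions; Add, Compile and Parse are the exact calls made on pctx.Grok above):

```go
// Hypothetical standalone walk through the grok lifecycle used by node.go.
func grokLifecycle() (map[string]string, error) {
	h := grokky.NewBase() // assumption: grokky exposes its stock pattern set this way

	// equivalent of a pattern_syntax entry
	if err := h.Add("MY_TS", `%{TIMESTAMP_ISO8601}`); err != nil {
		return nil, err
	}

	// equivalent of grok.pattern compiled in-place
	rt, err := h.Compile(`%{MY_TS:ts} %{WORD:level}`)
	if err != nil {
		return nil, err
	}

	// at process time, the returned captures are merged into evt.Parsed
	return rt.Parse("2019-10-12T07:20:50.52Z warn"), nil
}
```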
if err = cache.CacheInit(cache.CacheCfg{ Size: n.Stash[i].MaxMapSize, TTL: n.Stash[i].TTLVal, @@ -589,14 +599,18 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { if !n.LeavesNodes[idx].Debug && n.Debug { n.LeavesNodes[idx].Debug = true } + if !n.LeavesNodes[idx].Profiling && n.Profiling { n.LeavesNodes[idx].Profiling = true } + n.LeavesNodes[idx].Stage = n.Stage + err = n.LeavesNodes[idx].compile(pctx, ectx) if err != nil { return err } + valid = true } @@ -609,51 +623,25 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { return err } } - valid = true - } - /* compile whitelists if present */ - for _, v := range n.Whitelist.Ips { - n.Whitelist.B_Ips = append(n.Whitelist.B_Ips, net.ParseIP(v)) - n.Logger.Debugf("adding ip %s to whitelists", net.ParseIP(v)) valid = true } - for _, v := range n.Whitelist.Cidrs { - _, tnet, err := net.ParseCIDR(v) - if err != nil { - n.Logger.Fatalf("Unable to parse cidr whitelist '%s' : %v.", v, err) - } - n.Whitelist.B_Cidrs = append(n.Whitelist.B_Cidrs, tnet) - n.Logger.Debugf("adding cidr %s to whitelists", tnet) - valid = true + /* compile whitelists if present */ + whitelistValid, err := n.CompileWLs() + if err != nil { + return err } - for _, filter := range n.Whitelist.Exprs { - expression := &ExprWhitelist{} - expression.Filter, err = expr.Compile(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) - if err != nil { - n.Logger.Fatalf("Unable to compile whitelist expression '%s' : %v.", filter, err) - } - expression.ExprDebugger, err = exprhelpers.NewDebugger(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) - if err != nil { - log.Errorf("unable to build debug filter for '%s' : %s", filter, err) - } - n.Whitelist.B_Exprs = append(n.Whitelist.B_Exprs, expression) - n.Logger.Debugf("adding expression %s to whitelists", filter) - valid = true - } + valid = valid || whitelistValid if !valid { /* node is empty, error force return */ n.Logger.Error("Node is empty or invalid, abort") n.Stage = "" - return fmt.Errorf("Node is empty") - } - if err := n.validate(pctx, ectx); err != nil { - return err + return errors.New("Node is empty") } - return nil + return n.validate(ectx) } diff --git a/pkg/parser/node_test.go b/pkg/parser/node_test.go index d85aa82a8ae..76d35a9ffb0 100644 --- a/pkg/parser/node_test.go +++ b/pkg/parser/node_test.go @@ -49,18 +49,18 @@ func TestParserConfigs(t *testing.T) { } for idx := range CfgTests { err := CfgTests[idx].NodeCfg.compile(pctx, EnricherCtx{}) - if CfgTests[idx].Compiles == true && err != nil { + if CfgTests[idx].Compiles && err != nil { t.Fatalf("Compile: (%d/%d) expected valid, got : %s", idx+1, len(CfgTests), err) } - if CfgTests[idx].Compiles == false && err == nil { + if !CfgTests[idx].Compiles && err == nil { t.Fatalf("Compile: (%d/%d) expected error", idx+1, len(CfgTests)) } - err = CfgTests[idx].NodeCfg.validate(pctx, EnricherCtx{}) - if CfgTests[idx].Valid == true && err != nil { + err = CfgTests[idx].NodeCfg.validate(EnricherCtx{}) + if CfgTests[idx].Valid && err != nil { t.Fatalf("Valid: (%d/%d) expected valid, got : %s", idx+1, len(CfgTests), err) } - if CfgTests[idx].Valid == false && err == nil { + if !CfgTests[idx].Valid && err == nil { t.Fatalf("Valid: (%d/%d) expected error", idx+1, len(CfgTests)) } } diff --git a/pkg/parser/parsing_test.go b/pkg/parser/parsing_test.go index 04d08cc2785..269d51a1ba2 100644 --- a/pkg/parser/parsing_test.go +++ b/pkg/parser/parsing_test.go @@ -24,17 +24,21 @@ 
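Given the compile()/validate() rules above, the CfgTests table exercised in node_test.go extends naturally. For instance, a hypothetical extra case (assuming the table's NodeCfg field is a plain Node) asserting that a filter-only node is rejected as empty:

```go
// Hypothetical extra CfgTests entry: a filter alone never marks a node
// "valid", so compile() fails with "Node is empty" (clearing Stage on the
// way out), which in turn makes the standalone validate() call fail too.
{
	NodeCfg: Node{
		Stage:     "s00-raw",
		OnSuccess: "next_stage",
		Filter:    `evt.Line.Raw != ""`,
	},
	Compiles: false,
	Valid:    false,
},
```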
type TestFile struct { Results []types.Event `yaml:"results,omitempty"` } -var debug bool = false +var debug = false func TestParser(t *testing.T) { debug = true + log.SetLevel(log.InfoLevel) - var envSetting = os.Getenv("TEST_ONLY") + + envSetting := os.Getenv("TEST_ONLY") + pctx, ectx, err := prepTests() if err != nil { t.Fatalf("failed to load env : %s", err) } - //Init the enricher + + // Init the enricher if envSetting != "" { if err := testOneParser(pctx, ectx, envSetting, nil); err != nil { t.Fatalf("Test '%s' failed : %s", envSetting, err) @@ -44,12 +48,15 @@ func TestParser(t *testing.T) { if err != nil { t.Fatalf("Unable to read test directory : %s", err) } + for _, fd := range fds { if !fd.IsDir() { continue } + fname := "./tests/" + fd.Name() log.Infof("Running test on %s", fname) + if err := testOneParser(pctx, ectx, fname, nil); err != nil { t.Fatalf("Test '%s' failed : %s", fname, err) } @@ -59,13 +66,17 @@ func TestParser(t *testing.T) { func BenchmarkParser(t *testing.B) { log.Printf("start bench !!!!") + debug = false + log.SetLevel(log.ErrorLevel) + pctx, ectx, err := prepTests() if err != nil { t.Fatalf("failed to load env : %s", err) } - var envSetting = os.Getenv("TEST_ONLY") + + envSetting := os.Getenv("TEST_ONLY") if envSetting != "" { if err := testOneParser(pctx, ectx, envSetting, t); err != nil { @@ -76,12 +87,15 @@ func BenchmarkParser(t *testing.B) { if err != nil { t.Fatalf("Unable to read test directory : %s", err) } + for _, fd := range fds { if !fd.IsDir() { continue } + fname := "./tests/" + fd.Name() log.Infof("Running test on %s", fname) + if err := testOneParser(pctx, ectx, fname, t); err != nil { t.Fatalf("Test '%s' failed : %s", fname, err) } @@ -91,49 +105,58 @@ func BenchmarkParser(t *testing.B) { func testOneParser(pctx *UnixParserCtx, ectx EnricherCtx, dir string, b *testing.B) error { var ( - err error - pnodes []Node - + err error + pnodes []Node parser_configs []Stagefile ) + log.Warningf("testing %s", dir) + parser_cfg_file := fmt.Sprintf("%s/parsers.yaml", dir) + cfg, err := os.ReadFile(parser_cfg_file) if err != nil { - return fmt.Errorf("failed opening %s : %s", parser_cfg_file, err) + return fmt.Errorf("failed opening %s: %w", parser_cfg_file, err) } + tmpl, err := template.New("test").Parse(string(cfg)) if err != nil { - return fmt.Errorf("failed to parse template %s : %s", cfg, err) + return fmt.Errorf("failed to parse template %s: %w", cfg, err) } + var out bytes.Buffer + err = tmpl.Execute(&out, map[string]string{"TestDirectory": dir}) if err != nil { panic(err) } + if err = yaml.UnmarshalStrict(out.Bytes(), &parser_configs); err != nil { - return fmt.Errorf("failed unmarshaling %s : %s", parser_cfg_file, err) + return fmt.Errorf("failed to parse %s: %w", parser_cfg_file, err) } pnodes, err = LoadStages(parser_configs, pctx, ectx) if err != nil { - return fmt.Errorf("unable to load parser config : %s", err) + return fmt.Errorf("unable to load parser config: %w", err) } - //TBD: Load post overflows - //func testFile(t *testing.T, file string, pctx UnixParserCtx, nodes []Node) bool { + // TBD: Load post overflows + // func testFile(t *testing.T, file string, pctx UnixParserCtx, nodes []Node) bool { parser_test_file := fmt.Sprintf("%s/test.yaml", dir) tests := loadTestFile(parser_test_file) count := 1 + if b != nil { count = b.N b.ResetTimer() } - for n := 0; n < count; n++ { - if testFile(tests, *pctx, pnodes) != true { - return fmt.Errorf("test failed !") + + for range(count) { + if !testFile(tests, *pctx, pnodes) { + return 
errors.New("test failed") } } + return nil } @@ -147,26 +170,34 @@ func prepTests() (*UnixParserCtx, EnricherCtx, error) { err = exprhelpers.Init(nil) if err != nil { - log.Fatalf("exprhelpers init failed: %s", err) + return nil, ectx, fmt.Errorf("exprhelpers init failed: %w", err) } - //Load enrichment + // Load enrichment datadir := "./test_data/" - ectx, err = Loadplugin(datadir) + + err = exprhelpers.GeoIPInit(datadir) if err != nil { - log.Fatalf("failed to load plugin geoip : %v", err) + log.Fatalf("unable to initialize GeoIP: %s", err) } + + ectx, err = Loadplugin() + if err != nil { + return nil, ectx, fmt.Errorf("failed to load plugin geoip: %v", err) + } + log.Printf("Loaded -> %+v", ectx) - //Load the parser patterns + // Load the parser patterns cfgdir := "../../config/" /* this should be refactored to 2 lines :p */ // Init the parser pctx, err = Init(map[string]interface{}{"patterns": cfgdir + string("/patterns/"), "data": "./tests/"}) if err != nil { - return nil, ectx, fmt.Errorf("failed to initialize parser : %v", err) + return nil, ectx, fmt.Errorf("failed to initialize parser: %v", err) } + return pctx, ectx, nil } @@ -175,43 +206,54 @@ func loadTestFile(file string) []TestFile { if err != nil { log.Fatalf("yamlFile.Get err #%v ", err) } + dec := yaml.NewDecoder(yamlFile) dec.SetStrict(true) + var testSet []TestFile + for { tf := TestFile{} + err := dec.Decode(&tf) if err != nil { if errors.Is(err, io.EOF) { break } + log.Fatalf("Failed to load testfile '%s' yaml error : %v", file, err) + return nil } + testSet = append(testSet, tf) } + return testSet } func matchEvent(expected types.Event, out types.Event, debug bool) ([]string, bool) { var retInfo []string - var valid = false + + valid := false expectMaps := []map[string]string{expected.Parsed, expected.Meta, expected.Enriched} outMaps := []map[string]string{out.Parsed, out.Meta, out.Enriched} outLabels := []string{"Parsed", "Meta", "Enriched"} - //allow to check as well for stage and processed flags + // allow to check as well for stage and processed flags if expected.Stage != "" { if expected.Stage != out.Stage { if debug { retInfo = append(retInfo, fmt.Sprintf("mismatch stage %s != %s", expected.Stage, out.Stage)) } + goto checkFinished - } else { - valid = true - if debug { - retInfo = append(retInfo, fmt.Sprintf("ok stage %s == %s", expected.Stage, out.Stage)) - } + } + + valid = true + + if debug { + retInfo = append(retInfo, fmt.Sprintf("ok stage %s == %s", expected.Stage, out.Stage)) } } @@ -219,48 +261,58 @@ func matchEvent(expected types.Event, out types.Event, debug bool) ([]string, bo if debug { retInfo = append(retInfo, fmt.Sprintf("mismatch process %t != %t", expected.Process, out.Process)) } + goto checkFinished - } else { - valid = true - if debug { - retInfo = append(retInfo, fmt.Sprintf("ok process %t == %t", expected.Process, out.Process)) - } + } + + valid = true + + if debug { + retInfo = append(retInfo, fmt.Sprintf("ok process %t == %t", expected.Process, out.Process)) } if expected.Whitelisted != out.Whitelisted { if debug { retInfo = append(retInfo, fmt.Sprintf("mismatch whitelist %t != %t", expected.Whitelisted, out.Whitelisted)) } + goto checkFinished - } else { - if debug { - retInfo = append(retInfo, fmt.Sprintf("ok whitelist %t == %t", expected.Whitelisted, out.Whitelisted)) - } - valid = true } - for mapIdx := 0; mapIdx < len(expectMaps); mapIdx++ { + if debug { + retInfo = append(retInfo, fmt.Sprintf("ok whitelist %t == %t", expected.Whitelisted, out.Whitelisted)) + } + + valid = true + + 
for mapIdx := range len(expectMaps) {
+		// expectMaps, outMaps and outLabels are indexed in lockstep
+		// (Parsed, Meta, Enriched); a missing key or a mismatched value
+		// in any of them aborts the comparison via goto checkFinished.
 		for expKey, expVal := range expectMaps[mapIdx] {
-			if outVal, ok := outMaps[mapIdx][expKey]; ok {
-				if outVal == expVal { //ok entry
-					if debug {
-						retInfo = append(retInfo, fmt.Sprintf("ok %s[%s] %s == %s", outLabels[mapIdx], expKey, expVal, outVal))
-					}
-					valid = true
-				} else { //mismatch entry
-					if debug {
-						retInfo = append(retInfo, fmt.Sprintf("mismatch %s[%s] %s != %s", outLabels[mapIdx], expKey, expVal, outVal))
-					}
-					valid = false
-					goto checkFinished
-				}
-			} else { //missing entry
+			outVal, ok := outMaps[mapIdx][expKey]
+			if !ok {
 				if debug {
 					retInfo = append(retInfo, fmt.Sprintf("missing entry %s[%s]", outLabels[mapIdx], expKey))
 				}
+				valid = false
+
+				goto checkFinished
 			}
+
+			if outVal != expVal { // mismatch entry
+				if debug {
+					retInfo = append(retInfo, fmt.Sprintf("mismatch %s[%s] %s != %s", outLabels[mapIdx], expKey, expVal, outVal))
+				}
+
+				valid = false
+
+				goto checkFinished
+			}
+
+			if debug {
+				retInfo = append(retInfo, fmt.Sprintf("ok %s[%s] %s == %s", outLabels[mapIdx], expKey, expVal, outVal))
+			}
+
+			valid = true
 		}
 	}

checkFinished:
@@ -273,6 +325,7 @@ checkFinished:
 			retInfo = append(retInfo, fmt.Sprintf("KO ! \n\t%s", strings.Join(retInfo, "\n\t")))
 		}
 	}
+
 	return retInfo, valid
 }
@@ -284,9 +337,10 @@ func testSubSet(testSet TestFile, pctx UnixParserCtx, nodes []Node) (bool, error
 		if err != nil {
 			log.Errorf("Failed to process %s : %v", spew.Sdump(in), err)
 		}
-		//log.Infof("Parser output : %s", spew.Sdump(out))
+		// log.Infof("Parser output : %s", spew.Sdump(out))
 		results = append(results, out)
 	}
+
 	log.Infof("parsed %d lines", len(testSet.Lines))
 	log.Infof("got %d results", len(results))

@@ -295,21 +349,22 @@ func testSubSet(testSet TestFile, pctx UnixParserCtx, nodes []Node) (bool, error
 	   only the keys of the expected part are checked against result */
 	if len(testSet.Results) == 0 && len(results) == 0 {
-		log.Fatal("No results, no tests, abort.")
-		return false, fmt.Errorf("no tests, no results")
+		return false, errors.New("no tests, no results")
 	}

reCheck:
 	failinfo := []string{}
+
 	for ridx, result := range results {
 		for eidx, expected := range testSet.Results {
 			explain, match := matchEvent(expected, result, debug)
-			if match == true {
+			if match {
 				log.Infof("expected %d/%d matches result %d/%d", eidx, len(testSet.Results), ridx, len(results))
+
 				if len(explain) > 0 {
 					log.Printf("-> %s", explain[len(explain)-1])
 				}
-				//don't do this at home : delete current element from list and redo
+				// don't do this at home: delete the current element (swap with the last, truncate) and redo
 				results[len(results)-1], results[ridx] = results[ridx], results[len(results)-1]
 				results = results[:len(results)-1]

@@ -317,34 +372,40 @@ reCheck:
 				testSet.Results = testSet.Results[:len(testSet.Results)-1]

 				goto reCheck
-			} else {
-				failinfo = append(failinfo, explain...)
 			}
+
+			failinfo = append(failinfo, explain...)
} } + if len(results) > 0 { log.Printf("Errors : %s", strings.Join(failinfo, " / ")) return false, fmt.Errorf("leftover results : %+v", results) } + if len(testSet.Results) > 0 { log.Printf("Errors : %s", strings.Join(failinfo, " / ")) return false, fmt.Errorf("leftover expected results : %+v", testSet.Results) } + return true, nil } func testFile(testSet []TestFile, pctx UnixParserCtx, nodes []Node) bool { log.Warning("Going to process one test set") + for _, tf := range testSet { - //func testSubSet(testSet TestFile, pctx UnixParserCtx, nodes []Node) (bool, error) { + // func testSubSet(testSet TestFile, pctx UnixParserCtx, nodes []Node) (bool, error) { testOk, err := testSubSet(tf, pctx, nodes) if err != nil { log.Fatalf("test failed : %s", err) } + if !testOk { log.Fatalf("failed test : %+v", tf) } } + return true } @@ -369,48 +430,61 @@ func TestGeneratePatternsDoc(t *testing.T) { if err != nil { t.Fatalf("unable to load patterns : %s", err) } + log.Infof("-> %s", spew.Sdump(pctx)) /*don't judge me, we do it for the users*/ p := make(PairList, len(pctx.Grok.Patterns)) i := 0 + for key, val := range pctx.Grok.Patterns { p[i] = Pair{key, val} p[i].Value = strings.ReplaceAll(p[i].Value, "{%{", "\\{\\%\\{") i++ } + sort.Sort(p) - f, err := os.OpenFile("./patterns-documentation.md", os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0644) + f, err := os.OpenFile("./patterns-documentation.md", os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { t.Fatalf("failed to open : %s", err) } + if _, err := f.WriteString("# Patterns documentation\n\n"); err != nil { t.Fatal("failed to write to file") } + if _, err := f.WriteString("You will find here a generated documentation of all the patterns loaded by crowdsec.\n"); err != nil { t.Fatal("failed to write to file") } + if _, err := f.WriteString("They are sorted by pattern length, and are meant to be used in parsers, in the form %{PATTERN_NAME}.\n"); err != nil { t.Fatal("failed to write to file") } + if _, err := f.WriteString("\n\n"); err != nil { t.Fatal("failed to write to file") } + for _, k := range p { if _, err := fmt.Fprintf(f, "## %s\n\nPattern :\n```\n%s\n```\n\n", k.Key, k.Value); err != nil { t.Fatal("failed to write to file") } + fmt.Printf("%v\t%v\n", k.Key, k.Value) } + if _, err := f.WriteString("\n"); err != nil { t.Fatal("failed to write to file") } + if _, err := f.WriteString("# Documentation generation\n"); err != nil { t.Fatal("failed to write to file") } + if _, err := f.WriteString("This documentation is generated by `pkg/parser` : `GO_WANT_TEST_DOC=1 go test -run TestGeneratePatternsDoc`\n"); err != nil { t.Fatal("failed to write to file") } + f.Close() } diff --git a/pkg/parser/runtime.go b/pkg/parser/runtime.go index 6f957e7a222..8068690b68f 100644 --- a/pkg/parser/runtime.go +++ b/pkg/parser/runtime.go @@ -14,11 +14,12 @@ import ( "sync" "time" - "github.com/antonmedv/expr" "github.com/mohae/deepcopy" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/dumps" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -41,8 +42,8 @@ func SetTargetByName(target string, value string, evt *types.Event) bool { iter := reflect.ValueOf(evt).Elem() if (iter == reflect.Value{}) || iter.IsZero() { - log.Tracef("event is nill") - //event is nill + log.Tracef("event is nil") + //event is nil return false } for _, f := range strings.Split(target, ".") { @@ -117,7 +118,7 @@ func (n *Node) ProcessStatics(statics []ExtraField, 
event *types.Event) error { if static.Value != "" { value = static.Value } else if static.RunTimeValue != nil { - output, err := expr.Run(static.RunTimeValue, map[string]interface{}{"evt": event}) + output, err := exprhelpers.Run(static.RunTimeValue, map[string]interface{}{"evt": event}, clog, n.Debug) if err != nil { clog.Warningf("failed to run RunTimeValue : %v", err) continue @@ -127,6 +128,8 @@ func (n *Node) ProcessStatics(statics []ExtraField, event *types.Event) error { value = out case int: value = strconv.Itoa(out) + case float64, float32: + value = fmt.Sprintf("%f", out) case map[string]interface{}: clog.Warnf("Expression '%s' returned a map, please use ToJsonString() to convert it to string if you want to keep it as is, or refine your expression to extract a string", static.ExpValue) case []interface{}: @@ -134,7 +137,7 @@ func (n *Node) ProcessStatics(statics []ExtraField, event *types.Event) error { case nil: clog.Debugf("Expression '%s' returned nil, skipping", static.ExpValue) default: - clog.Errorf("unexpected return type for RunTimeValue : %T", output) + clog.Errorf("unexpected return type for '%s' : %T", static.ExpValue, output) return errors.New("unexpected return type for RunTimeValue") } } @@ -152,7 +155,7 @@ func (n *Node) ProcessStatics(statics []ExtraField, event *types.Event) error { /*still way too hackish, but : inject all the results in enriched, and */ if enricherPlugin, ok := n.EnrichFunctions.Registered[static.Method]; ok { clog.Tracef("Found method '%s'", static.Method) - ret, err := enricherPlugin.EnrichFunc(value, event, enricherPlugin.Ctx, n.Logger) + ret, err := enricherPlugin.EnrichFunc(value, event, n.Logger.WithField("method", static.Method)) if err != nil { clog.Errorf("method '%s' returned an error : %v", static.Method, err) } @@ -218,6 +221,24 @@ var NodesHitsKo = prometheus.NewCounterVec( []string{"source", "type", "name"}, ) +// + +var NodesWlHitsOk = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_node_wl_hits_ok_total", + Help: "Total events successfully whitelisted by node.", + }, + []string{"source", "type", "name", "reason"}, +) + +var NodesWlHits = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_node_wl_hits_total", + Help: "Total events processed by whitelist node.", + }, + []string{"source", "type", "name", "reason"}, +) + func stageidx(stage string, stages []string) int { for i, v := range stages { if stage == v { @@ -227,14 +248,10 @@ func stageidx(stage string, stages []string) int { return -1 } -type ParserResult struct { - Evt types.Event - Success bool -} - var ParseDump bool var DumpFolder string -var StageParseCache map[string]map[string][]ParserResult + +var StageParseCache dumps.ParserResults var StageParseMutex sync.Mutex func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) { @@ -269,9 +286,9 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) if ParseDump { if StageParseCache == nil { StageParseMutex.Lock() - StageParseCache = make(map[string]map[string][]ParserResult) - StageParseCache["success"] = make(map[string][]ParserResult) - StageParseCache["success"][""] = make([]ParserResult, 0) + StageParseCache = make(dumps.ParserResults) + StageParseCache["success"] = make(map[string][]dumps.ParserResult) + StageParseCache["success"][""] = make([]dumps.ParserResult, 0) StageParseMutex.Unlock() } } @@ -280,7 +297,7 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) if ParseDump { StageParseMutex.Lock() if _, 
ok := StageParseCache[stage]; !ok { - StageParseCache[stage] = make(map[string][]ParserResult) + StageParseCache[stage] = make(map[string][]dumps.ParserResult) } StageParseMutex.Unlock() } @@ -320,13 +337,18 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) } clog.Tracef("node (%s) ret : %v", node.rn, ret) if ParseDump { + var parserIdxInStage int StageParseMutex.Lock() if len(StageParseCache[stage][node.Name]) == 0 { - StageParseCache[stage][node.Name] = make([]ParserResult, 0) + StageParseCache[stage][node.Name] = make([]dumps.ParserResult, 0) + parserIdxInStage = len(StageParseCache[stage]) + } else { + parserIdxInStage = StageParseCache[stage][node.Name][0].Idx } StageParseMutex.Unlock() + evtcopy := deepcopy.Copy(event) - parserInfo := ParserResult{Evt: evtcopy.(types.Event), Success: ret} + parserInfo := dumps.ParserResult{Evt: evtcopy.(types.Event), Success: ret, Idx: parserIdxInStage} StageParseMutex.Lock() StageParseCache[stage][node.Name] = append(StageParseCache[stage][node.Name], parserInfo) StageParseMutex.Unlock() diff --git a/pkg/parser/stage.go b/pkg/parser/stage.go index 37d43fbfe96..b98db350254 100644 --- a/pkg/parser/stage.go +++ b/pkg/parser/stage.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + // enable profiling _ "net/http/pprof" "os" "sort" @@ -20,7 +21,7 @@ import ( log "github.com/sirupsen/logrus" yaml "gopkg.in/yaml.v2" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/constraint" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" ) @@ -57,6 +58,7 @@ func LoadStages(stageFiles []Stagefile, pctx *UnixParserCtx, ectx EnricherCtx) ( if err != nil { return nil, fmt.Errorf("can't access parsing configuration file %s : %s", stageFile.Filename, err) } + defer yamlFile.Close() //process the yaml dec := yaml.NewDecoder(yamlFile) dec.SetStrict(true) @@ -70,12 +72,12 @@ func LoadStages(stageFiles []Stagefile, pctx *UnixParserCtx, ectx EnricherCtx) ( log.Tracef("End of yaml file") break } - log.Fatalf("Error decoding parsing configuration file '%s': %v", stageFile.Filename, err) + return nil, fmt.Errorf("error decoding parsing configuration file '%s': %v", stageFile.Filename, err) } //check for empty bucket if node.Name == "" && node.Description == "" && node.Author == "" { - log.Infof("Node in %s has no name,author or description. Skipping.", stageFile.Filename) + log.Infof("Node in %s has no name, author or description. 
Skipping.", stageFile.Filename) continue } //check compat @@ -83,12 +85,12 @@ func LoadStages(stageFiles []Stagefile, pctx *UnixParserCtx, ectx EnricherCtx) ( log.Tracef("no version in %s, assuming '1.0'", node.Name) node.FormatVersion = "1.0" } - ok, err := cwversion.Satisfies(node.FormatVersion, cwversion.Constraint_parser) + ok, err := constraint.Satisfies(node.FormatVersion, constraint.Parser) if err != nil { - log.Fatalf("Failed to check version : %s", err) + return nil, fmt.Errorf("failed to check version : %s", err) } if !ok { - log.Errorf("%s : %s doesn't satisfy parser format %s, skip", node.Name, node.FormatVersion, cwversion.Constraint_parser) + log.Errorf("%s : %s doesn't satisfy parser format %s, skip", node.Name, node.FormatVersion, constraint.Parser) continue } diff --git a/pkg/parser/tests/reverse-dns-enrich/test.yaml b/pkg/parser/tests/reverse-dns-enrich/test.yaml index 1495d3f86f2..a492669c5d3 100644 --- a/pkg/parser/tests/reverse-dns-enrich/test.yaml +++ b/pkg/parser/tests/reverse-dns-enrich/test.yaml @@ -1,14 +1,14 @@ #these are the events we input into parser lines: - Enriched: - IpToResolve: 8.8.8.8 + IpToResolve: 1.1.1.1 - Enriched: IpToResolve: 1.2.3.4 #these are the results we expect from the parser results: - Enriched: - reverse_dns: dns.google. - IpToResolve: 8.8.8.8 + reverse_dns: one.one.one.one. + IpToResolve: 1.1.1.1 Meta: did_dns_succeeded: yes Process: true diff --git a/pkg/parser/tests/whitelist-base/base-grok.yaml b/pkg/parser/tests/whitelist-base/base-grok.yaml index 44cbd103546..7a8f6d8d8e7 100644 --- a/pkg/parser/tests/whitelist-base/base-grok.yaml +++ b/pkg/parser/tests/whitelist-base/base-grok.yaml @@ -4,7 +4,7 @@ debug: true whitelist: reason: "Whitelist tests" ip: - - 8.8.8.8 + - 1.1.1.1 cidr: - "1.2.3.0/24" expression: diff --git a/pkg/parser/tests/whitelist-base/test.yaml b/pkg/parser/tests/whitelist-base/test.yaml index 4524e957ed2..1ad2b2773de 100644 --- a/pkg/parser/tests/whitelist-base/test.yaml +++ b/pkg/parser/tests/whitelist-base/test.yaml @@ -2,7 +2,7 @@ lines: - Meta: test: test1 - source_ip: 8.8.8.8 + source_ip: 1.1.1.1 statics: toto - Meta: test: test2 diff --git a/pkg/parser/unix_parser.go b/pkg/parser/unix_parser.go index d5d91f9320d..351de8ade56 100644 --- a/pkg/parser/unix_parser.go +++ b/pkg/parser/unix_parser.go @@ -3,7 +3,7 @@ package parser import ( "fmt" "os" - "path" + "path/filepath" "sort" "strings" @@ -46,7 +46,7 @@ func Init(c map[string]interface{}) (*UnixParserCtx, error) { if strings.Contains(f.Name(), ".") { continue } - if err := r.Grok.AddFromFile(path.Join(c["patterns"].(string), f.Name())); err != nil { + if err := r.Grok.AddFromFile(filepath.Join(c["patterns"].(string), f.Name())); err != nil { log.Errorf("failed to load pattern %s : %v", f.Name(), err) return nil, err } @@ -57,29 +57,29 @@ func Init(c map[string]interface{}) (*UnixParserCtx, error) { // Return new parsers // nodes and povfwnodes are already initialized in parser.LoadStages -func NewParsers() *Parsers { +func NewParsers(hub *cwhub.Hub) *Parsers { parsers := &Parsers{ Ctx: &UnixParserCtx{}, Povfwctx: &UnixParserCtx{}, StageFiles: make([]Stagefile, 0), PovfwStageFiles: make([]Stagefile, 0), } - for _, itemType := range []string{cwhub.PARSERS, cwhub.PARSERS_OVFLW} { - for _, hubParserItem := range cwhub.GetItemMap(itemType) { - if hubParserItem.Installed { - stagefile := Stagefile{ - Filename: hubParserItem.LocalPath, - Stage: hubParserItem.Stage, - } - if itemType == cwhub.PARSERS { - parsers.StageFiles = append(parsers.StageFiles, stagefile) - } - 
if itemType == cwhub.PARSERS_OVFLW { - parsers.PovfwStageFiles = append(parsers.PovfwStageFiles, stagefile) - } + + for _, itemType := range []string{cwhub.PARSERS, cwhub.POSTOVERFLOWS} { + for _, hubParserItem := range hub.GetInstalledByType(itemType, false) { + stagefile := Stagefile{ + Filename: hubParserItem.State.LocalPath, + Stage: hubParserItem.Stage, + } + if itemType == cwhub.PARSERS { + parsers.StageFiles = append(parsers.StageFiles, stagefile) + } + if itemType == cwhub.POSTOVERFLOWS { + parsers.PovfwStageFiles = append(parsers.PovfwStageFiles, stagefile) } } } + if parsers.StageFiles != nil { sort.Slice(parsers.StageFiles, func(i, j int) bool { return parsers.StageFiles[i].Filename < parsers.StageFiles[j].Filename @@ -97,16 +97,20 @@ func NewParsers() *Parsers { func LoadParsers(cConfig *csconfig.Config, parsers *Parsers) (*Parsers, error) { var err error - patternsDir := path.Join(cConfig.Crowdsec.ConfigDir, "patterns/") + patternsDir := cConfig.ConfigPaths.PatternDir log.Infof("Loading grok library %s", patternsDir) /* load base regexps for two grok parsers */ - parsers.Ctx, err = Init(map[string]interface{}{"patterns": patternsDir, - "data": cConfig.Crowdsec.DataDir}) + parsers.Ctx, err = Init(map[string]interface{}{ + "patterns": patternsDir, + "data": cConfig.ConfigPaths.DataDir, + }) if err != nil { return parsers, fmt.Errorf("failed to load parser patterns : %v", err) } - parsers.Povfwctx, err = Init(map[string]interface{}{"patterns": patternsDir, - "data": cConfig.Crowdsec.DataDir}) + parsers.Povfwctx, err = Init(map[string]interface{}{ + "patterns": patternsDir, + "data": cConfig.ConfigPaths.DataDir, + }) if err != nil { return parsers, fmt.Errorf("failed to load postovflw parser patterns : %v", err) } @@ -116,7 +120,7 @@ func LoadParsers(cConfig *csconfig.Config, parsers *Parsers) (*Parsers, error) { */ log.Infof("Loading enrich plugins") - parsers.EnricherCtx, err = Loadplugin(cConfig.Crowdsec.DataDir) + parsers.EnricherCtx, err = Loadplugin() if err != nil { return parsers, fmt.Errorf("failed to load enrich plugin : %v", err) } @@ -148,6 +152,12 @@ func LoadParsers(cConfig *csconfig.Config, parsers *Parsers) (*Parsers, error) { parsers.Ctx.Profiling = true parsers.Povfwctx.Profiling = true } - + /* + Reset CTX grok to reduce memory footprint after we compile all the patterns + */ + parsers.Ctx.Grok = grokky.Host{} + parsers.Povfwctx.Grok = grokky.Host{} + parsers.StageFiles = []Stagefile{} + parsers.PovfwStageFiles = []Stagefile{} return parsers, nil } diff --git a/pkg/parser/whitelist.go b/pkg/parser/whitelist.go index e2f179fb3a1..e7b93a8d7da 100644 --- a/pkg/parser/whitelist.go +++ b/pkg/parser/whitelist.go @@ -1,11 +1,15 @@ package parser import ( + "fmt" "net" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" + "github.com/prometheus/client_golang/prometheus" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" + "github.com/crowdsecurity/crowdsec/pkg/types" ) type Whitelist struct { @@ -19,6 +23,115 @@ type Whitelist struct { } type ExprWhitelist struct { - Filter *vm.Program - ExprDebugger *exprhelpers.ExprDebugger // used to debug expression by printing the content of each variable of the expression + Filter *vm.Program +} + +func (n *Node) ContainsWLs() bool { + return n.ContainsIPLists() || n.ContainsExprLists() +} + +func (n *Node) ContainsExprLists() bool { + return len(n.Whitelist.B_Exprs) > 0 +} + +func (n *Node) ContainsIPLists() bool { + return len(n.Whitelist.B_Ips) > 0 || len(n.Whitelist.B_Cidrs) > 0 +} 
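+
+// A minimal sketch of how the helpers above fit together (hypothetical
+// caller; the real call sites are the parser's node processing and
+// whitelist_test.go below). CompileWLs must run once, at load time, before
+// either check; evt is assumed to be a *types.Event:
+//
+//	if _, err := n.CompileWLs(); err != nil {
+//		return err
+//	}
+//
+//	whitelisted := n.CheckIPsWL(evt)
+//	if !whitelisted {
+//		var err error
+//		whitelisted, err = n.CheckExprWL(map[string]interface{}{"evt": evt}, evt)
+//		if err != nil {
+//			return err
+//		}
+//	}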
+
+func (n *Node) CheckIPsWL(p *types.Event) bool {
+	srcs := p.ParseIPSources()
+	isWhitelisted := false
+	if !n.ContainsIPLists() {
+		return isWhitelisted
+	}
+	NodesWlHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc()
+	for _, src := range srcs {
+		if isWhitelisted {
+			break
+		}
+		for _, v := range n.Whitelist.B_Ips {
+			if v.Equal(src) {
+				n.Logger.Debugf("Event from [%s] is whitelisted by IP (%s), reason [%s]", src, v, n.Whitelist.Reason)
+				isWhitelisted = true
+				break
+			}
+			n.Logger.Tracef("whitelist: %s is not eq [%s]", src, v)
+		}
+		for _, v := range n.Whitelist.B_Cidrs {
+			if v.Contains(src) {
+				n.Logger.Debugf("Event from [%s] is whitelisted by CIDR (%s), reason [%s]", src, v, n.Whitelist.Reason)
+				isWhitelisted = true
+				break
+			}
+			n.Logger.Tracef("whitelist: %s not in [%s]", src, v)
+		}
+	}
+	if isWhitelisted {
+		NodesWlHitsOk.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc()
+	}
+	return isWhitelisted
+}
+
+func (n *Node) CheckExprWL(cachedExprEnv map[string]interface{}, p *types.Event) (bool, error) {
+	isWhitelisted := false
+
+	if !n.ContainsExprLists() {
+		return false, nil
+	}
+	NodesWlHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc()
+	/* run whitelist expression tests anyway */
+	for eidx, e := range n.Whitelist.B_Exprs {
+		// if we already know the event is whitelisted, skip the rest of the expressions
+		if isWhitelisted {
+			break
+		}
+
+		output, err := exprhelpers.Run(e.Filter, cachedExprEnv, n.Logger, n.Debug)
+		if err != nil {
+			n.Logger.Warningf("failed to run whitelist expr : %v", err)
+			n.Logger.Debug("Event leaving node : ko")
+			return isWhitelisted, err
+		}
+		switch out := output.(type) {
+		case bool:
+			if out {
+				n.Logger.Debugf("Event is whitelisted by expr, reason [%s]", n.Whitelist.Reason)
+				isWhitelisted = true
+			}
+		default:
+			n.Logger.Errorf("unexpected type %T (%v) while running '%s'", output, output, n.Whitelist.Exprs[eidx])
+		}
+	}
+	if isWhitelisted {
+		NodesWlHitsOk.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc()
+	}
+	return isWhitelisted, nil
+}
+
+func (n *Node) CompileWLs() (bool, error) {
+	for _, v := range n.Whitelist.Ips {
+		n.Whitelist.B_Ips = append(n.Whitelist.B_Ips, net.ParseIP(v))
+		n.Logger.Debugf("adding ip %s to whitelists", net.ParseIP(v))
+	}
+
+	for _, v := range n.Whitelist.Cidrs {
+		_, tnet, err := net.ParseCIDR(v)
+		if err != nil {
+			return false, fmt.Errorf("unable to parse cidr whitelist '%s' : %v", v, err)
+		}
+		n.Whitelist.B_Cidrs = append(n.Whitelist.B_Cidrs, tnet)
+		n.Logger.Debugf("adding cidr %s to whitelists", tnet)
+	}
+
+	for _, filter := range n.Whitelist.Exprs {
+		var err error
+		expression := &ExprWhitelist{}
+		expression.Filter, err = expr.Compile(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...)
+ if err != nil { + return false, fmt.Errorf("unable to compile whitelist expression '%s' : %v", filter, err) + } + n.Whitelist.B_Exprs = append(n.Whitelist.B_Exprs, expression) + n.Logger.Debugf("adding expression %s to whitelists", filter) + } + return n.ContainsWLs(), nil } diff --git a/pkg/parser/whitelist_test.go b/pkg/parser/whitelist_test.go new file mode 100644 index 00000000000..02846f17fc1 --- /dev/null +++ b/pkg/parser/whitelist_test.go @@ -0,0 +1,298 @@ +package parser + +import ( + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func TestWhitelistCompile(t *testing.T) { + node := &Node{ + Logger: log.NewEntry(log.New()), + } + tests := []struct { + name string + whitelist Whitelist + expectedErr string + }{ + { + name: "Valid CIDR whitelist", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "127.0.0.1/24", + }, + }, + }, + { + name: "Invalid CIDR whitelist", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "127.0.0.1/1000", + }, + }, + expectedErr: "invalid CIDR address", + }, + { + name: "Valid EXPR whitelist", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "1==1", + }, + }, + }, + { + name: "Invalid EXPR whitelist", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "evt.THISPROPERTYSHOULDERROR == true", + }, + }, + expectedErr: "types.Event has no field", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node.Whitelist = tt.whitelist + _, err := node.CompileWLs() + cstest.RequireErrorContains(t, err, tt.expectedErr) + }) + } +} + +func TestWhitelistCheck(t *testing.T) { + node := &Node{ + Logger: log.NewEntry(log.New()), + } + tests := []struct { + name string + whitelist Whitelist + event *types.Event + expected bool + }{ + { + name: "IP Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Ips: []string{ + "127.0.0.1", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.1", + }, + }, + expected: true, + }, + { + name: "IP Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Ips: []string{ + "127.0.0.1", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.2", + }, + }, + }, + { + name: "CIDR Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "127.0.0.1/32", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.1", + }, + }, + expected: true, + }, + { + name: "CIDR Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "127.0.0.1/32", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.2", + }, + }, + }, + { + name: "EXPR Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "evt.Meta.source_ip == '127.0.0.1'", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.1", + }, + }, + expected: true, + }, + { + name: "EXPR Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "evt.Meta.source_ip == '127.0.0.1'", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.2", + }, + }, + }, + { + name: "Postoverflow IP Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Ips: []string{ + "192.168.1.1", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + Sources: 
map[string]models.Source{ + "192.168.1.1": {}, + }, + }, + }, + expected: true, + }, + { + name: "Postoverflow IP Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Ips: []string{ + "192.168.1.2", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + Sources: map[string]models.Source{ + "192.168.1.1": {}, + }, + }, + }, + }, + { + name: "Postoverflow CIDR Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "192.168.1.1/32", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + Sources: map[string]models.Source{ + "192.168.1.1": {}, + }, + }, + }, + expected: true, + }, + { + name: "Postoverflow CIDR Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "192.168.1.2/32", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + Sources: map[string]models.Source{ + "192.168.1.1": {}, + }, + }, + }, + }, + { + name: "Postoverflow EXPR Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "evt.Overflow.APIAlerts[0].Source.Cn == 'test'", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + APIAlerts: []models.Alert{ + { + Source: &models.Source{ + Cn: "test", + }, + }, + }, + }, + }, + expected: true, + }, + { + name: "Postoverflow EXPR Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "evt.Overflow.APIAlerts[0].Source.Cn == 'test2'", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + APIAlerts: []models.Alert{ + { + Source: &models.Source{ + Cn: "test", + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var err error + node.Whitelist = tt.whitelist + node.CompileWLs() + isWhitelisted := node.CheckIPsWL(tt.event) + if !isWhitelisted { + isWhitelisted, err = node.CheckExprWL(map[string]interface{}{"evt": tt.event}, tt.event) + } + require.NoError(t, err) + require.Equal(t, tt.expected, isWhitelisted) + }) + } +} diff --git a/pkg/protobufs/generate.go b/pkg/protobufs/generate.go new file mode 100644 index 00000000000..0e90d65b643 --- /dev/null +++ b/pkg/protobufs/generate.go @@ -0,0 +1,14 @@ +package protobufs + +// Dependencies: +// +// apt install protobuf-compiler +// +// keep this in sync with go.mod +// go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 +// +// Not the same versions as google.golang.org/grpc +// go list -m -versions google.golang.org/grpc/cmd/protoc-gen-go-grpc +// go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.5.1 + +//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative notifier.proto diff --git a/pkg/protobufs/notifier.pb.go b/pkg/protobufs/notifier.pb.go index b5dc8113568..8c4754da773 100644 --- a/pkg/protobufs/notifier.pb.go +++ b/pkg/protobufs/notifier.pb.go @@ -1,16 +1,12 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions:
-// 	protoc-gen-go v1.27.1
-// 	protoc        v3.12.4
+// 	protoc-gen-go v1.34.2
+// 	protoc        v3.21.12
// source: notifier.proto

package protobufs

import (
-	context "context"
-	grpc "google.golang.org/grpc"
-	codes "google.golang.org/grpc/codes"
-	status "google.golang.org/grpc/status"
 	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
 	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
 	reflect "reflect"
@@ -198,7 +194,7 @@ func file_notifier_proto_rawDescGZIP() []byte {
}

var file_notifier_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
-var file_notifier_proto_goTypes = []interface{}{
+var file_notifier_proto_goTypes = []any{
 	(*Notification)(nil), // 0: proto.Notification
 	(*Config)(nil),       // 1: proto.Config
 	(*Empty)(nil),        // 2: proto.Empty
@@ -221,7 +217,7 @@ func file_notifier_proto_init() {
 		return
 	}
 	if !protoimpl.UnsafeEnabled {
-		file_notifier_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+		file_notifier_proto_msgTypes[0].Exporter = func(v any, i int) any {
 			switch v := v.(*Notification); i {
 			case 0:
 				return &v.state
@@ -233,7 +229,7 @@ func file_notifier_proto_init() {
 				return nil
 			}
 		}
-		file_notifier_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
+		file_notifier_proto_msgTypes[1].Exporter = func(v any, i int) any {
 			switch v := v.(*Config); i {
 			case 0:
 				return &v.state
@@ -245,7 +241,7 @@ func file_notifier_proto_init() {
 				return nil
 			}
 		}
-		file_notifier_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
+		file_notifier_proto_msgTypes[2].Exporter = func(v any, i int) any {
 			switch v := v.(*Empty); i {
 			case 0:
 				return &v.state
@@ -277,119 +273,3 @@ func file_notifier_proto_init() {
 	file_notifier_proto_goTypes = nil
 	file_notifier_proto_depIdxs = nil
}
-
-// Reference imports to suppress errors if they are not otherwise used.
-var _ context.Context
-var _ grpc.ClientConnInterface
-
-// This is a compile-time assertion to ensure that this generated file
-// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion6
-
-// NotifierClient is the client API for Notifier service.
-//
-// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
-type NotifierClient interface {
-	Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error)
-	Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error)
-}
-
-type notifierClient struct {
-	cc grpc.ClientConnInterface
-}
-
-func NewNotifierClient(cc grpc.ClientConnInterface) NotifierClient {
-	return &notifierClient{cc}
-}
-
-func (c *notifierClient) Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error) {
-	out := new(Empty)
-	err := c.cc.Invoke(ctx, "/proto.Notifier/Notify", in, out, opts...)
-	if err != nil {
-		return nil, err
-	}
-	return out, nil
-}
-
-func (c *notifierClient) Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error) {
-	out := new(Empty)
-	err := c.cc.Invoke(ctx, "/proto.Notifier/Configure", in, out, opts...)
-	if err != nil {
-		return nil, err
-	}
-	return out, nil
-}
-
-// NotifierServer is the server API for Notifier service.
-type NotifierServer interface {
-	Notify(context.Context, *Notification) (*Empty, error)
-	Configure(context.Context, *Config) (*Empty, error)
-}
-
-// UnimplementedNotifierServer can be embedded to have forward compatible implementations.
-type UnimplementedNotifierServer struct { -} - -func (*UnimplementedNotifierServer) Notify(context.Context, *Notification) (*Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "method Notify not implemented") -} -func (*UnimplementedNotifierServer) Configure(context.Context, *Config) (*Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "method Configure not implemented") -} - -func RegisterNotifierServer(s *grpc.Server, srv NotifierServer) { - s.RegisterService(&_Notifier_serviceDesc, srv) -} - -func _Notifier_Notify_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(Notification) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(NotifierServer).Notify(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/proto.Notifier/Notify", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(NotifierServer).Notify(ctx, req.(*Notification)) - } - return interceptor(ctx, in, info, handler) -} - -func _Notifier_Configure_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(Config) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(NotifierServer).Configure(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/proto.Notifier/Configure", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(NotifierServer).Configure(ctx, req.(*Config)) - } - return interceptor(ctx, in, info, handler) -} - -var _Notifier_serviceDesc = grpc.ServiceDesc{ - ServiceName: "proto.Notifier", - HandlerType: (*NotifierServer)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "Notify", - Handler: _Notifier_Notify_Handler, - }, - { - MethodName: "Configure", - Handler: _Notifier_Configure_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "notifier.proto", -} diff --git a/pkg/protobufs/notifier_grpc.pb.go b/pkg/protobufs/notifier_grpc.pb.go new file mode 100644 index 00000000000..5141e83f98b --- /dev/null +++ b/pkg/protobufs/notifier_grpc.pb.go @@ -0,0 +1,159 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v3.21.12 +// source: notifier.proto + +package protobufs + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + Notifier_Notify_FullMethodName = "/proto.Notifier/Notify" + Notifier_Configure_FullMethodName = "/proto.Notifier/Configure" +) + +// NotifierClient is the client API for Notifier service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 
+type NotifierClient interface {
+	Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error)
+	Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error)
+}
+
+type notifierClient struct {
+	cc grpc.ClientConnInterface
+}
+
+func NewNotifierClient(cc grpc.ClientConnInterface) NotifierClient {
+	return &notifierClient{cc}
+}
+
+func (c *notifierClient) Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error) {
+	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
+	out := new(Empty)
+	err := c.cc.Invoke(ctx, Notifier_Notify_FullMethodName, in, out, cOpts...)
+	if err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+func (c *notifierClient) Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error) {
+	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
+	out := new(Empty)
+	err := c.cc.Invoke(ctx, Notifier_Configure_FullMethodName, in, out, cOpts...)
+	if err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+// NotifierServer is the server API for Notifier service.
+// All implementations must embed UnimplementedNotifierServer
+// for forward compatibility.
+type NotifierServer interface {
+	Notify(context.Context, *Notification) (*Empty, error)
+	Configure(context.Context, *Config) (*Empty, error)
+	mustEmbedUnimplementedNotifierServer()
+}
+
+// UnimplementedNotifierServer must be embedded to have
+// forward compatible implementations.
+//
+// NOTE: this should be embedded by value instead of pointer to avoid a nil
+// pointer dereference when methods are called.
+type UnimplementedNotifierServer struct{}
+
+func (UnimplementedNotifierServer) Notify(context.Context, *Notification) (*Empty, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method Notify not implemented")
+}
+func (UnimplementedNotifierServer) Configure(context.Context, *Config) (*Empty, error) {
+	return nil, status.Errorf(codes.Unimplemented, "method Configure not implemented")
+}
+func (UnimplementedNotifierServer) mustEmbedUnimplementedNotifierServer() {}
+func (UnimplementedNotifierServer) testEmbeddedByValue()                  {}
+
+// UnsafeNotifierServer may be embedded to opt out of forward compatibility for this service.
+// Use of this interface is not recommended, as added methods to NotifierServer will
+// result in compilation errors.
+type UnsafeNotifierServer interface {
+	mustEmbedUnimplementedNotifierServer()
+}
+
+func RegisterNotifierServer(s grpc.ServiceRegistrar, srv NotifierServer) {
+	// If the following call panics, it indicates UnimplementedNotifierServer was
+	// embedded by pointer and is nil. This will cause panics if an
+	// unimplemented method is ever invoked, so we test this at initialization
+	// time to prevent it from happening at runtime later due to I/O.
+ if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&Notifier_ServiceDesc, srv) +} + +func _Notifier_Notify_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Notification) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(NotifierServer).Notify(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Notifier_Notify_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(NotifierServer).Notify(ctx, req.(*Notification)) + } + return interceptor(ctx, in, info, handler) +} + +func _Notifier_Configure_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Config) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(NotifierServer).Configure(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Notifier_Configure_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(NotifierServer).Configure(ctx, req.(*Config)) + } + return interceptor(ctx, in, info, handler) +} + +// Notifier_ServiceDesc is the grpc.ServiceDesc for Notifier service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Notifier_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "proto.Notifier", + HandlerType: (*NotifierServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Notify", + Handler: _Notifier_Notify_Handler, + }, + { + MethodName: "Configure", + Handler: _Notifier_Configure_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "notifier.proto", +} diff --git a/pkg/protobufs/plugin_interface.go b/pkg/protobufs/plugin_interface.go deleted file mode 100644 index fc89b2fa009..00000000000 --- a/pkg/protobufs/plugin_interface.go +++ /dev/null @@ -1,47 +0,0 @@ -package protobufs - -import ( - "context" - - plugin "github.com/hashicorp/go-plugin" - "google.golang.org/grpc" -) - -type Notifier interface { - Notify(ctx context.Context, notification *Notification) (*Empty, error) - Configure(ctx context.Context, config *Config) (*Empty, error) -} - -// This is the implementation of plugin.NotifierPlugin so we can serve/consume this. -type NotifierPlugin struct { - // GRPCPlugin must still implement the Plugin interface - plugin.Plugin - // Concrete implementation, written in Go. This is only used for plugins - // that are written in Go. 
- Impl Notifier -} - -type GRPCClient struct{ client NotifierClient } - -func (m *GRPCClient) Notify(ctx context.Context, notification *Notification) (*Empty, error) { - _, err := m.client.Notify(context.Background(), notification) - return &Empty{}, err -} - -func (m *GRPCClient) Configure(ctx context.Context, config *Config) (*Empty, error) { - _, err := m.client.Configure(context.Background(), config) - return &Empty{}, err -} - -type GRPCServer struct { - Impl Notifier -} - -func (p *NotifierPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { - RegisterNotifierServer(s, p.Impl) - return nil -} - -func (p *NotifierPlugin) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { - return &GRPCClient{client: NewNotifierClient(c)}, nil -} diff --git a/pkg/setup/README.md b/pkg/setup/README.md index 3585ee8b141..9cdc7243975 100644 --- a/pkg/setup/README.md +++ b/pkg/setup/README.md @@ -129,7 +129,7 @@ services: and must all return true for a service to be detected (implied *and* clause, no short-circuit). A missing or empty `when:` section is evaluated as true. The [expression -engine](https://github.com/antonmedv/expr/blob/master/docs/Language-Definition.md) +engine](https://github.com/antonmedv/expr/blob/master/docs/language-definition.md) is the same one used by CrowdSec parser filters. You can force the detection of a process by using the `cscli setup detect... --force-process ` flag. It will always behave as if `` was running. diff --git a/pkg/setup/detect.go b/pkg/setup/detect.go index b345c0d6f63..073b221b10c 100644 --- a/pkg/setup/detect.go +++ b/pkg/setup/detect.go @@ -2,14 +2,16 @@ package setup import ( "bytes" + "errors" "fmt" + "io" "os" "os/exec" "sort" "github.com/Masterminds/semver/v3" - "github.com/antonmedv/expr" "github.com/blackfireio/osinfo" + "github.com/expr-lang/expr" "github.com/shirou/gopsutil/v3/process" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" @@ -52,6 +54,7 @@ func validateDataSource(opaqueDS DataSourceItem) error { // formally validate YAML commonDS := configuration.DataSourceCommonCfg{} + body, err := yaml.Marshal(opaqueDS) if err != nil { return err @@ -65,14 +68,14 @@ func validateDataSource(opaqueDS DataSourceItem) error { // source is mandatory // XXX unless it's not? 
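+	// "source" must name a module registered with the acquisition package;
+	// GetDataSourceIface now returns (DataSource, error), and that error is
+	// passed through untouched (e.g. "unknown data source wombat" in the
+	// expectations of detect_test.go).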
if commonDS.Source == "" { - return fmt.Errorf("source is empty") + return errors.New("source is empty") } // source must be known - ds := acquisition.GetDataSourceIface(commonDS.Source) - if ds == nil { - return fmt.Errorf("unknown source '%s'", commonDS.Source) + ds, err := acquisition.GetDataSourceIface(commonDS.Source) + if err != nil { + return err } // unmarshal and validate the rest with the specific implementation @@ -86,28 +89,28 @@ func validateDataSource(opaqueDS DataSourceItem) error { return nil } -func readDetectConfig(file string) (DetectConfig, error) { +func readDetectConfig(fin io.Reader) (DetectConfig, error) { var dc DetectConfig - yamlBytes, err := os.ReadFile(file) + yamlBytes, err := io.ReadAll(fin) if err != nil { - return DetectConfig{}, fmt.Errorf("while reading file: %w", err) + return DetectConfig{}, err } dec := yaml.NewDecoder(bytes.NewBuffer(yamlBytes)) dec.KnownFields(true) if err = dec.Decode(&dc); err != nil { - return DetectConfig{}, fmt.Errorf("while parsing %s: %w", file, err) + return DetectConfig{}, err } switch dc.Version { case "": - return DetectConfig{}, fmt.Errorf("missing version tag (must be 1.0)") + return DetectConfig{}, errors.New("missing version tag (must be 1.0)") case "1.0": // all is well default: - return DetectConfig{}, fmt.Errorf("unsupported version tag '%s' (must be 1.0)", dc.Version) + return DetectConfig{}, fmt.Errorf("invalid version tag '%s' (must be 1.0)", dc.Version) } for name, svc := range dc.Detect { @@ -457,15 +460,13 @@ type DetectOptions struct { // Detect performs the service detection from a given configuration. // It outputs a setup file that can be used as input to "cscli setup install-hub" // or "cscli setup datasources". -func Detect(serviceDetectionFile string, opts DetectOptions) (Setup, error) { +func Detect(detectReader io.Reader, opts DetectOptions) (Setup, error) { ret := Setup{} // explicitly initialize to avoid json mashaling an empty slice as "null" ret.Setup = make([]ServiceSetup, 0) - log.Tracef("Reading detection rules: %s", serviceDetectionFile) - - sc, err := readDetectConfig(serviceDetectionFile) + sc, err := readDetectConfig(detectReader) if err != nil { return ret, err } @@ -544,7 +545,7 @@ func Detect(serviceDetectionFile string, opts DetectOptions) (Setup, error) { // } // err = yaml.Unmarshal(svc.AcquisYAML, svc.DataSource) // if err != nil { - // return Setup{}, fmt.Errorf("while unmarshaling datasource for service %s: %w", name, err) + // return Setup{}, fmt.Errorf("while parsing datasource for service %s: %w", name, err) // } // } @@ -559,8 +560,8 @@ func Detect(serviceDetectionFile string, opts DetectOptions) (Setup, error) { } // ListSupported parses the configuration file and outputs a list of the supported services. 
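+// Like Detect() above, ListSupported now takes an io.Reader instead of a
+// file path, so callers and tests can hand it an open file, a bytes.Buffer,
+// or any other reader rather than going through the filesystem.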
-func ListSupported(serviceDetectionFile string) ([]string, error) { - dc, err := readDetectConfig(serviceDetectionFile) +func ListSupported(detectConfig io.Reader) ([]string, error) { + dc, err := readDetectConfig(detectConfig) if err != nil { return nil, err } diff --git a/pkg/setup/detect_test.go b/pkg/setup/detect_test.go index adb5f1d436c..588e74dab54 100644 --- a/pkg/setup/detect_test.go +++ b/pkg/setup/detect_test.go @@ -10,8 +10,7 @@ import ( "github.com/lithammer/dedent" "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/pkg/csstring" - "github.com/crowdsecurity/go-cs-lib/pkg/cstest" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/crowdsec/pkg/setup" ) @@ -58,7 +57,7 @@ func TestSetupHelperProcess(t *testing.T) { os.Exit(0) } -func tempYAML(t *testing.T, content string) string { +func tempYAML(t *testing.T, content string) os.File { t.Helper() require := require.New(t) file, err := os.CreateTemp("", "") @@ -70,7 +69,10 @@ func tempYAML(t *testing.T, content string) string { err = file.Close() require.NoError(err) - return file.Name() + file, err = os.Open(file.Name()) + require.NoError(err) + + return *file } func TestPathExists(t *testing.T) { @@ -92,11 +94,11 @@ func TestPathExists(t *testing.T) { } for _, tc := range tests { - tc := tc env := setup.NewExprEnvironment(setup.DetectOptions{}, setup.ExprOS{}) t.Run(tc.path, func(t *testing.T) { t.Parallel() + actual := env.PathExists(tc.path) require.Equal(t, tc.expected, actual) }) @@ -145,11 +147,11 @@ func TestVersionCheck(t *testing.T) { } for _, tc := range tests { - tc := tc e := setup.ExprOS{RawVersion: tc.version} t.Run(fmt.Sprintf("Check(%s,%s)", tc.version, tc.constraint), func(t *testing.T) { t.Parallel() + actual, err := e.VersionCheck(tc.constraint) cstest.RequireErrorContains(t, err, tc.expectedErr) require.Equal(t, tc.expected, actual) @@ -182,7 +184,6 @@ func TestNormalizeVersion(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.version, func(t *testing.T) { t.Parallel() actual := setup.NormalizeVersion(tc.version) @@ -239,17 +240,18 @@ func TestListSupported(t *testing.T) { "invalid yaml: bad version", "version: 2.0", nil, - "unsupported version tag '2.0' (must be 1.0)", + "invalid version tag '2.0' (must be 1.0)", }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() + f := tempYAML(t, tc.yml) - defer os.Remove(f) - supported, err := setup.ListSupported(f) + defer os.Remove(f.Name()) + + supported, err := setup.ListSupported(&f) cstest.RequireErrorContains(t, err, tc.expectedErr) require.ElementsMatch(t, tc.expected, supported) }) @@ -327,9 +329,9 @@ func TestApplyRules(t *testing.T) { env := setup.ExprEnvironment{} for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() + svc := setup.Service{When: tc.rules} _, actualOk, err := setup.ApplyRules(svc, env) //nolint:typecheck,nolintlint // exported only for tests cstest.RequireErrorContains(t, err, tc.expectedErr) @@ -351,7 +353,7 @@ func TestUnitFound(t *testing.T) { installed, err := env.UnitFound("crowdsec-setup-detect.service") require.NoError(err) - require.Equal(true, installed) + require.True(installed) } // TODO apply rules to filter a list of Service structs @@ -373,9 +375,9 @@ func TestDetectSimpleRule(t *testing.T) { - false ugly: `) - defer os.Remove(f) + defer os.Remove(f.Name()) - detected, err := setup.Detect(f, setup.DetectOptions{}) + detected, err := setup.Detect(&f, setup.DetectOptions{}) require.NoError(err) 
expected := []setup.ServiceSetup{ @@ -417,12 +419,11 @@ detect: } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { f := tempYAML(t, tc.config) - defer os.Remove(f) + defer os.Remove(f.Name()) - detected, err := setup.Detect(f, setup.DetectOptions{}) + detected, err := setup.Detect(&f, setup.DetectOptions{}) cstest.RequireErrorContains(t, err, tc.expectedErr) require.Equal(tc.expected, detected) }) @@ -511,12 +512,11 @@ detect: } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { f := tempYAML(t, tc.config) - defer os.Remove(f) + defer os.Remove(f.Name()) - detected, err := setup.Detect(f, setup.DetectOptions{}) + detected, err := setup.Detect(&f, setup.DetectOptions{}) cstest.RequireErrorContains(t, err, tc.expectedErr) require.Equal(tc.expected, detected) }) @@ -542,9 +542,9 @@ func TestDetectForcedUnit(t *testing.T) { journalctl_filter: - _SYSTEMD_UNIT=crowdsec-setup-forced.service `) - defer os.Remove(f) + defer os.Remove(f.Name()) - detected, err := setup.Detect(f, setup.DetectOptions{ForcedUnits: []string{"crowdsec-setup-forced.service"}}) + detected, err := setup.Detect(&f, setup.DetectOptions{ForcedUnits: []string{"crowdsec-setup-forced.service"}}) require.NoError(err) expected := setup.Setup{ @@ -564,8 +564,8 @@ func TestDetectForcedUnit(t *testing.T) { func TestDetectForcedProcess(t *testing.T) { if runtime.GOOS == "windows" { - t.Skip("skipping on windows") // while looking for service wizard: rule 'ProcessRunning("foobar")': while looking up running processes: could not get Name: A device attached to the system is not functioning. + t.Skip("skipping on windows") } require := require.New(t) @@ -580,9 +580,9 @@ func TestDetectForcedProcess(t *testing.T) { when: - ProcessRunning("foobar") `) - defer os.Remove(f) + defer os.Remove(f.Name()) - detected, err := setup.Detect(f, setup.DetectOptions{ForcedProcesses: []string{"foobar"}}) + detected, err := setup.Detect(&f, setup.DetectOptions{ForcedProcesses: []string{"foobar"}}) require.NoError(err) expected := setup.Setup{ @@ -610,9 +610,9 @@ func TestDetectSkipService(t *testing.T) { when: - ProcessRunning("foobar") `) - defer os.Remove(f) + defer os.Remove(f.Name()) - detected, err := setup.Detect(f, setup.DetectOptions{ForcedProcesses: []string{"foobar"}, SkipServices: []string{"wizard"}}) + detected, err := setup.Detect(&f, setup.DetectOptions{ForcedProcesses: []string{"foobar"}, SkipServices: []string{"wizard"}}) require.NoError(err) expected := setup.Setup{[]setup.ServiceSetup{}} @@ -823,12 +823,11 @@ func TestDetectForcedOS(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { f := tempYAML(t, tc.config) - defer os.Remove(f) + defer os.Remove(f.Name()) - detected, err := setup.Detect(f, setup.DetectOptions{ForcedOS: tc.forced}) + detected, err := setup.Detect(&f, setup.DetectOptions{ForcedOS: tc.forced}) cstest.RequireErrorContains(t, err, tc.expectedErr) require.Equal(tc.expected, detected) }) @@ -838,7 +837,6 @@ func TestDetectForcedOS(t *testing.T) { func TestDetectDatasourceValidation(t *testing.T) { // It could be a good idea to test UnmarshalConfig() separately in addition // to Configure(), in each datasource. For now, we test these here. 
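+	// The expectedErr strings below are matched with cstest.RequireErrorContains
+	// against the errors returned by the acquisition modules themselves (file,
+	// kafka, loki, ...), so they must track the exact wording of each module.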
- require := require.New(t) setup.ExecCommand = fakeExecCommand @@ -872,7 +870,7 @@ func TestDetectDatasourceValidation(t *testing.T) { datasource: source: wombat`, expected: setup.Setup{Setup: []setup.ServiceSetup{}}, - expectedErr: "invalid datasource for foobar: unknown source 'wombat'", + expectedErr: "invalid datasource for foobar: unknown data source wombat", }, { name: "source is misplaced", config: ` @@ -882,7 +880,7 @@ func TestDetectDatasourceValidation(t *testing.T) { datasource: source: file`, expected: setup.Setup{Setup: []setup.ServiceSetup{}}, - expectedErr: "while parsing {{.DetectYaml}}: yaml: unmarshal errors:\n line 6: field source not found in type setup.Service", + expectedErr: "yaml: unmarshal errors:\n line 6: field source not found in type setup.Service", }, { name: "source is mismatched", config: ` @@ -981,6 +979,16 @@ func TestDetectDatasourceValidation(t *testing.T) { source: kafka`, expected: setup.Setup{Setup: []setup.ServiceSetup{}}, expectedErr: "invalid datasource for foobar: cannot create a kafka reader with an empty list of broker addresses", + }, { + name: "source loki: required fields", + config: ` + version: 1.0 + detect: + foobar: + datasource: + source: loki`, + expected: setup.Setup{Setup: []setup.ServiceSetup{}}, + expectedErr: "invalid datasource for foobar: loki query is mandatory", }, } @@ -999,20 +1007,11 @@ func TestDetectDatasourceValidation(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - detectYaml := tempYAML(t, tc.config) - defer os.Remove(detectYaml) - - data := map[string]string{ - "DetectYaml": detectYaml, - } - - expectedErr, err := csstring.Interpolate(tc.expectedErr, data) - require.NoError(err) - - detected, err := setup.Detect(detectYaml, setup.DetectOptions{}) - cstest.RequireErrorContains(t, err, expectedErr) + f := tempYAML(t, tc.config) + defer os.Remove(f.Name()) + detected, err := setup.Detect(&f, setup.DetectOptions{}) + cstest.RequireErrorContains(t, err, tc.expectedErr) require.Equal(tc.expected, detected) }) } diff --git a/pkg/setup/install.go b/pkg/setup/install.go index 5d3bfdbc995..d63a1ee1775 100644 --- a/pkg/setup/install.go +++ b/pkg/setup/install.go @@ -2,6 +2,8 @@ package setup import ( "bytes" + "context" + "errors" "fmt" "os" "path/filepath" @@ -10,7 +12,6 @@ import ( goccyyaml "github.com/goccy/go-yaml" "gopkg.in/yaml.v3" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) @@ -39,31 +40,19 @@ func decodeSetup(input []byte, fancyErrors bool) (Setup, error) { dec2.KnownFields(true) if err := dec2.Decode(&ret); err != nil { - return ret, fmt.Errorf("while unmarshaling setup file: %w", err) + return ret, fmt.Errorf("while parsing setup file: %w", err) } return ret, nil } // InstallHubItems installs the objects recommended in a setup file. 
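+// A minimal sketch of the new calling convention (hypothetical caller; ctx
+// and hub are assumed to be a context.Context and an already-loaded
+// *cwhub.Hub, which the caller now provides instead of a csconfig.Config):
+//
+//	input, err := os.ReadFile("setup.yaml")
+//	if err != nil {
+//		return err
+//	}
+//
+//	if err := InstallHubItems(ctx, hub, input, false); err != nil { // dryRun=false
+//		return err
+//	}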
-func InstallHubItems(csConfig *csconfig.Config, input []byte, dryRun bool) error { +func InstallHubItems(ctx context.Context, hub *cwhub.Hub, input []byte, dryRun bool) error { setupEnvelope, err := decodeSetup(input, false) if err != nil { return err } - if err := csConfig.LoadHub(); err != nil { - return fmt.Errorf("loading hub: %w", err) - } - - if err := cwhub.SetHubBranch(); err != nil { - return fmt.Errorf("setting hub branch: %w", err) - } - - if err := cwhub.GetHubIdx(csConfig.Hub); err != nil { - return fmt.Errorf("getting hub index: %w", err) - } - for _, setupItem := range setupEnvelope.Setup { forceAction := false downloadOnly := false @@ -73,59 +62,71 @@ func InstallHubItems(csConfig *csconfig.Config, input []byte, dryRun bool) error continue } - if len(install.Collections) > 0 { - for _, collection := range setupItem.Install.Collections { - if dryRun { - fmt.Println("dry-run: would install collection", collection) + for _, collection := range setupItem.Install.Collections { + item := hub.GetItem(cwhub.COLLECTIONS, collection) + if item == nil { + return fmt.Errorf("collection %s not found", collection) + } + + if dryRun { + fmt.Println("dry-run: would install collection", collection) - continue - } + continue + } - if err := cwhub.InstallItem(csConfig, collection, cwhub.COLLECTIONS, forceAction, downloadOnly); err != nil { - return fmt.Errorf("while installing collection %s: %w", collection, err) - } + if err := item.Install(ctx, forceAction, downloadOnly); err != nil { + return fmt.Errorf("while installing collection %s: %w", item.Name, err) } } - if len(install.Parsers) > 0 { - for _, parser := range setupItem.Install.Parsers { - if dryRun { - fmt.Println("dry-run: would install parser", parser) + for _, parser := range setupItem.Install.Parsers { + if dryRun { + fmt.Println("dry-run: would install parser", parser) - continue - } + continue + } + + item := hub.GetItem(cwhub.PARSERS, parser) + if item == nil { + return fmt.Errorf("parser %s not found", parser) + } - if err := cwhub.InstallItem(csConfig, parser, cwhub.PARSERS, forceAction, downloadOnly); err != nil { - return fmt.Errorf("while installing parser %s: %w", parser, err) - } + if err := item.Install(ctx, forceAction, downloadOnly); err != nil { + return fmt.Errorf("while installing parser %s: %w", item.Name, err) } } - if len(install.Scenarios) > 0 { - for _, scenario := range setupItem.Install.Scenarios { - if dryRun { - fmt.Println("dry-run: would install scenario", scenario) + for _, scenario := range setupItem.Install.Scenarios { + if dryRun { + fmt.Println("dry-run: would install scenario", scenario) - continue - } + continue + } - if err := cwhub.InstallItem(csConfig, scenario, cwhub.SCENARIOS, forceAction, downloadOnly); err != nil { - return fmt.Errorf("while installing scenario %s: %w", scenario, err) - } + item := hub.GetItem(cwhub.SCENARIOS, scenario) + if item == nil { + return fmt.Errorf("scenario %s not found", scenario) + } + + if err := item.Install(ctx, forceAction, downloadOnly); err != nil { + return fmt.Errorf("while installing scenario %s: %w", item.Name, err) } } - if len(install.PostOverflows) > 0 { - for _, postoverflow := range setupItem.Install.PostOverflows { - if dryRun { - fmt.Println("dry-run: would install postoverflow", postoverflow) + for _, postoverflow := range setupItem.Install.PostOverflows { + if dryRun { + fmt.Println("dry-run: would install postoverflow", postoverflow) + + continue + } - continue - } + item := hub.GetItem(cwhub.POSTOVERFLOWS, postoverflow) + if item == nil 
{ + return fmt.Errorf("postoverflow %s not found", postoverflow) + } - if err := cwhub.InstallItem(csConfig, postoverflow, cwhub.PARSERS_OVFLW, forceAction, downloadOnly); err != nil { - return fmt.Errorf("while installing postoverflow %s: %w", postoverflow, err) - } + if err := item.Install(ctx, forceAction, downloadOnly); err != nil { + return fmt.Errorf("while installing postoverflow %s: %w", item.Name, err) } } } @@ -166,7 +167,7 @@ func marshalAcquisDocuments(ads []AcquisDocument, toDir string) (string, error) if toDir != "" { if ad.AcquisFilename == "" { - return "", fmt.Errorf("empty acquis filename") + return "", errors.New("empty acquis filename") } fname := filepath.Join(toDir, ad.AcquisFilename) diff --git a/pkg/setup/units.go b/pkg/setup/units.go index a0bccba4aac..861513d3f1d 100644 --- a/pkg/setup/units.go +++ b/pkg/setup/units.go @@ -2,6 +2,7 @@ package setup import ( "bufio" + "errors" "fmt" "strings" @@ -34,14 +35,14 @@ func systemdUnitList() ([]string, error) { for scanner.Scan() { line := scanner.Text() - if len(line) == 0 { + if line == "" { break // the rest of the output is footer } if !header { spaceIdx := strings.IndexRune(line, ' ') if spaceIdx == -1 { - return ret, fmt.Errorf("can't parse systemctl output") + return ret, errors.New("can't parse systemctl output") } line = line[:spaceIdx] diff --git a/pkg/types/appsec_event.go b/pkg/types/appsec_event.go new file mode 100644 index 00000000000..dc81c63b344 --- /dev/null +++ b/pkg/types/appsec_event.go @@ -0,0 +1,245 @@ +package types + +import ( + "regexp" + "slices" + + log "github.com/sirupsen/logrus" +) + +/* + 1. If user triggered a rule that is for a CVE, that has high confidence and that is blocking, ban + 2. If user triggered 3 distinct rules with medium confidence across 3 different requests, ban + + +any(evt.Waf.ByTag("CVE"), {.confidence == "high" && .action == "block"}) + +len(evt.Waf.ByTagRx("*CVE*").ByConfidence("high").ByAction("block")) > 1 + +*/ + +type MatchedRules []map[string]interface{} + +type AppsecEvent struct { + HasInBandMatches, HasOutBandMatches bool + MatchedRules + Vars map[string]string +} +type Field string + +func (f Field) String() string { + return string(f) +} + +const ( + ID Field = "id" + RuleType Field = "rule_type" + Tags Field = "tags" + File Field = "file" + Confidence Field = "confidence" + Revision Field = "revision" + SecMark Field = "secmark" + Accuracy Field = "accuracy" + Msg Field = "msg" + Severity Field = "severity" + Kind Field = "kind" +) + +func (w AppsecEvent) GetVar(varName string) string { + if w.Vars == nil { + return "" + } + if val, ok := w.Vars[varName]; ok { + return val + } + log.Infof("var %s not found. 
Available variables: %+v", varName, w.Vars) + return "" + +} + +// getters +func (w MatchedRules) GetField(field Field) []interface{} { + ret := make([]interface{}, 0) + for _, rule := range w { + ret = append(ret, rule[field.String()]) + } + return ret +} + +func (w MatchedRules) GetURI() string { + for _, rule := range w { + return rule["uri"].(string) + } + return "" +} + +func (w MatchedRules) GetHash() string { + for _, rule := range w { + //@sbl : let's fix this + return rule["hash"].(string) + } + return "" +} + +func (w MatchedRules) GetVersion() string { + for _, rule := range w { + //@sbl : let's fix this + return rule["version"].(string) + } + return "" +} + +func (w MatchedRules) GetName() string { + for _, rule := range w { + //@sbl : let's fix this + return rule["name"].(string) + } + return "" +} + +func (w MatchedRules) GetMethod() string { + for _, rule := range w { + return rule["method"].(string) + } + return "" +} + +func (w MatchedRules) GetRuleIDs() []int { + ret := make([]int, 0) + for _, rule := range w { + ret = append(ret, rule["id"].(int)) + } + return ret +} + +func (w MatchedRules) Kinds() []string { + ret := make([]string, 0) + for _, rule := range w { + exists := false + for _, val := range ret { + if val == rule["kind"] { + exists = true + break + } + } + if !exists { + ret = append(ret, rule["kind"].(string)) + } + } + return ret +} + +func (w MatchedRules) GetMatchedZones() []string { + ret := make([]string, 0) + + for _, rule := range w { + for _, zone := range rule["matched_zones"].([]string) { + if !slices.Contains(ret, zone) { + ret = append(ret, zone) + } + } + } + return ret +} + +// filters +func (w MatchedRules) ByID(id int) MatchedRules { + ret := MatchedRules{} + + for _, rule := range w { + if rule["id"] == id { + ret = append(ret, rule) + } + } + return ret +} + +func (w MatchedRules) ByKind(kind string) MatchedRules { + ret := MatchedRules{} + for _, rule := range w { + if rule["kind"] == kind { + ret = append(ret, rule) + } + } + return ret +} + +func (w MatchedRules) ByTags(match []string) MatchedRules { + ret := MatchedRules{} + for _, rule := range w { + for _, tag := range rule["tags"].([]string) { + for _, match_tag := range match { + if tag == match_tag { + ret = append(ret, rule) + break + } + } + } + } + return ret +} + +func (w MatchedRules) ByTag(match string) MatchedRules { + ret := MatchedRules{} + for _, rule := range w { + for _, tag := range rule["tags"].([]string) { + if tag == match { + ret = append(ret, rule) + break + } + } + } + return ret +} + +func (w MatchedRules) ByTagRx(rx string) MatchedRules { + ret := MatchedRules{} + re := regexp.MustCompile(rx) + if re == nil { + return ret + } + for _, rule := range w { + for _, tag := range rule["tags"].([]string) { + log.Debugf("ByTagRx: %s = %s -> %t", rx, tag, re.MatchString(tag)) + if re.MatchString(tag) { + ret = append(ret, rule) + break + } + } + } + return ret +} + +func (w MatchedRules) ByDisruptiveness(is bool) MatchedRules { + ret := MatchedRules{} + for _, rule := range w { + if rule["disruptive"] == is { + ret = append(ret, rule) + } + } + log.Debugf("ByDisruptiveness(%t) -> %d", is, len(ret)) + + return ret +} + +func (w MatchedRules) BySeverity(severity string) MatchedRules { + ret := MatchedRules{} + for _, rule := range w { + if rule["severity"] == severity { + ret = append(ret, rule) + } + } + log.Debugf("BySeverity(%s) -> %d", severity, len(ret)) + return ret +} + +func (w MatchedRules) ByAccuracy(accuracy string) MatchedRules { + ret := MatchedRules{} + for 
_, rule := range w {
+		if rule["accuracy"] == accuracy {
+			ret = append(ret, rule)
+		}
+	}
+	log.Debugf("ByAccuracy(%s) -> %d", accuracy, len(ret))
+	return ret
+}
diff --git a/pkg/types/constants.go b/pkg/types/constants.go
index fa50b64f367..acb5b5bfacf 100644
--- a/pkg/types/constants.go
+++ b/pkg/types/constants.go
@@ -17,6 +17,7 @@ const ConsoleOrigin = "console"
 const CscliImportOrigin = "cscli-import"
 const ListOrigin = "lists"
 const CAPIOrigin = "CAPI"
+const CommunityBlocklistPullSourceScope = "crowdsecurity/community-blocklist"
 
 const DecisionTypeBan = "ban"
diff --git a/pkg/types/event.go b/pkg/types/event.go
index fc8d966abc7..e016d0294c4 100644
--- a/pkg/types/event.go
+++ b/pkg/types/event.go
@@ -1,27 +1,30 @@
 package types
 
 import (
+	"net"
+	"strings"
 	"time"
 
+	"github.com/expr-lang/expr/vm"
 	log "github.com/sirupsen/logrus"
 
-	"github.com/antonmedv/expr/vm"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 )
 
 const (
 	LOG = iota
 	OVFLW
+	APPSEC
 )
 
 // Event is the structure representing a runtime event (log or overflow)
 type Event struct {
 	/* is it a log or an overflow */
-	Type            int    `yaml:"Type,omitempty" json:"Type,omitempty"`             //Can be types.LOG (0) or types.OVFLOW (1)
-	ExpectMode      int    `yaml:"ExpectMode,omitempty" json:"ExpectMode,omitempty"` //how to buckets should handle event : types.TIMEMACHINE or types.LIVE
+	Type            int    `yaml:"Type,omitempty" json:"Type,omitempty"`             // Can be types.LOG (0), types.OVFLW (1) or types.APPSEC (2)
+	ExpectMode      int    `yaml:"ExpectMode,omitempty" json:"ExpectMode,omitempty"` // how buckets should handle the event: types.TIMEMACHINE or types.LIVE
 	Whitelisted     bool   `yaml:"Whitelisted,omitempty" json:"Whitelisted,omitempty"`
 	WhitelistReason string `yaml:"WhitelistReason,omitempty" json:"whitelist_reason,omitempty"`
-	//should add whitelist reason ?
+	// should add whitelist reason ?
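
The InstallHubItems refactor earlier in this diff replaces the package-level cwhub.InstallItem calls with an explicit lookup (hub.GetItem) followed by item.Install(ctx, forceAction, downloadOnly), so a missing item now fails fast with a "not found" error before anything is downloaded. A minimal caller sketch under stated assumptions: the hub construction is elided because it depends on local configuration, and the setup file path is only an example.

package main

import (
	"context"
	"os"

	"github.com/crowdsecurity/crowdsec/pkg/cwhub"
	"github.com/crowdsecurity/crowdsec/pkg/setup"
)

// installFromSetupFile reads a setup file and installs the hub items it
// references; with dryRun=true, InstallHubItems only prints
// "dry-run: would install ..." lines and changes nothing on disk.
func installFromSetupFile(ctx context.Context, hub *cwhub.Hub, path string, dryRun bool) error {
	input, err := os.ReadFile(path) // e.g. output previously saved from service detection
	if err != nil {
		return err
	}

	return setup.InstallHubItems(ctx, hub, input, dryRun)
}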
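
The MatchedRules helpers in the new appsec_event.go are written to chain, which is what the expr snippets in that file's header comment rely on: each filter returns a fresh MatchedRules, so selections compose left to right. A small self-contained sketch (the rule maps are hand-built for illustration; real entries are produced by the appsec engine):

package main

import (
	"fmt"

	"github.com/crowdsecurity/crowdsec/pkg/types"
)

func main() {
	rules := types.MatchedRules{
		{"id": 1, "kind": "crowdsec", "tags": []string{"CVE-2021-44228"}, "severity": "critical", "accuracy": "high"},
		{"id": 2, "kind": "crowdsec", "tags": []string{"generic"}, "severity": "low", "accuracy": "medium"},
	}

	// keep rules whose tags match .*CVE.*, then keep the critical ones
	hits := rules.ByTagRx(".*CVE.*").BySeverity("critical")

	fmt.Println(hits.GetRuleIDs()) // [1]
}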
/* the current stage of the line being parsed */ Stage string `yaml:"Stage,omitempty" json:"Stage,omitempty"` /* original line (produced by acquisition) */ @@ -34,21 +37,43 @@ type Event struct { Unmarshaled map[string]interface{} `yaml:"Unmarshaled,omitempty" json:"Unmarshaled,omitempty"` /* Overflow */ Overflow RuntimeAlert `yaml:"Overflow,omitempty" json:"Alert,omitempty"` - Time time.Time `yaml:"Time,omitempty" json:"Time,omitempty"` //parsed time `json:"-"` `` + Time time.Time `yaml:"Time,omitempty" json:"Time,omitempty"` // parsed time `json:"-"` `` StrTime string `yaml:"StrTime,omitempty" json:"StrTime,omitempty"` StrTimeFormat string `yaml:"StrTimeFormat,omitempty" json:"StrTimeFormat,omitempty"` MarshaledTime string `yaml:"MarshaledTime,omitempty" json:"MarshaledTime,omitempty"` - Process bool `yaml:"Process,omitempty" json:"Process,omitempty"` //can be set to false to avoid processing line + Process bool `yaml:"Process,omitempty" json:"Process,omitempty"` // can be set to false to avoid processing line + Appsec AppsecEvent `yaml:"Appsec,omitempty" json:"Appsec,omitempty"` /* Meta is the only part that will make it to the API - it should be normalized */ Meta map[string]string `yaml:"Meta,omitempty" json:"Meta,omitempty"` } +func (e *Event) SetMeta(key string, value string) bool { + if e.Meta == nil { + e.Meta = make(map[string]string) + } + + e.Meta[key] = value + + return true +} + +func (e *Event) SetParsed(key string, value string) bool { + if e.Parsed == nil { + e.Parsed = make(map[string]string) + } + + e.Parsed[key] = value + + return true +} + func (e *Event) GetType() string { - if e.Type == OVFLW { + switch e.Type { + case OVFLW: return "overflow" - } else if e.Type == LOG { + case LOG: return "log" - } else { + default: log.Warningf("unknown event type for %+v", e) return "unknown" } @@ -70,9 +95,27 @@ func (e *Event) GetMeta(key string) string { } } } + return "" } +func (e *Event) ParseIPSources() []net.IP { + var srcs []net.IP + + switch e.Type { + case LOG: + if _, ok := e.Meta["source_ip"]; ok { + srcs = append(srcs, net.ParseIP(e.Meta["source_ip"])) + } + case OVFLW: + for k := range e.Overflow.Sources { + srcs = append(srcs, net.ParseIP(k)) + } + } + + return srcs +} + // Move in leakybuckets const ( Undefined = "" @@ -96,8 +139,8 @@ type RuntimeAlert struct { Whitelisted bool `yaml:"Whitelisted,omitempty" json:"Whitelisted,omitempty"` Reprocess bool `yaml:"Reprocess,omitempty" json:"Reprocess,omitempty"` Sources map[string]models.Source `yaml:"Sources,omitempty" json:"Sources,omitempty"` - Alert *models.Alert `yaml:"Alert,omitempty" json:"Alert,omitempty"` //this one is a pointer to APIAlerts[0] for convenience. - //APIAlerts will be populated at the end when there is more than one source + Alert *models.Alert `yaml:"Alert,omitempty" json:"Alert,omitempty"` // this one is a pointer to APIAlerts[0] for convenience. 
+ // APIAlerts will be populated at the end when there is more than one source APIAlerts []models.Alert `yaml:"APIAlerts,omitempty" json:"APIAlerts,omitempty"` } @@ -106,5 +149,21 @@ func (r RuntimeAlert) GetSources() []string { for key := range r.Sources { ret = append(ret, key) } + return ret } + +func NormalizeScope(scope string) string { + switch strings.ToLower(scope) { + case "ip": + return Ip + case "range": + return Range + case "as": + return AS + case "country": + return Country + default: + return scope + } +} diff --git a/pkg/types/event_test.go b/pkg/types/event_test.go new file mode 100644 index 00000000000..97b13f96d9a --- /dev/null +++ b/pkg/types/event_test.go @@ -0,0 +1,158 @@ +package types + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +func TestSetParsed(t *testing.T) { + tests := []struct { + name string + evt *Event + key string + value string + expected bool + }{ + { + name: "SetParsed: Valid", + evt: &Event{}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetParsed: Existing map", + evt: &Event{Parsed: map[string]string{}}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetParsed: Existing map+key", + evt: &Event{Parsed: map[string]string{"test": "foobar"}}, + key: "test", + value: "test", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.evt.SetParsed(tt.key, tt.value) + assert.Equal(t, tt.value, tt.evt.Parsed[tt.key]) + }) + } + +} + +func TestSetMeta(t *testing.T) { + tests := []struct { + name string + evt *Event + key string + value string + expected bool + }{ + { + name: "SetMeta: Valid", + evt: &Event{}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetMeta: Existing map", + evt: &Event{Meta: map[string]string{}}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetMeta: Existing map+key", + evt: &Event{Meta: map[string]string{"test": "foobar"}}, + key: "test", + value: "test", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.evt.SetMeta(tt.key, tt.value) + assert.Equal(t, tt.value, tt.evt.GetMeta(tt.key)) + }) + } + +} + +func TestParseIPSources(t *testing.T) { + tests := []struct { + name string + evt Event + expected []net.IP + }{ + { + name: "ParseIPSources: Valid Log Sources", + evt: Event{ + Type: LOG, + Meta: map[string]string{ + "source_ip": "127.0.0.1", + }, + }, + expected: []net.IP{ + net.ParseIP("127.0.0.1"), + }, + }, + { + name: "ParseIPSources: Valid Overflow Sources", + evt: Event{ + Type: OVFLW, + Overflow: RuntimeAlert{ + Sources: map[string]models.Source{ + "127.0.0.1": {}, + }, + }, + }, + expected: []net.IP{ + net.ParseIP("127.0.0.1"), + }, + }, + { + name: "ParseIPSources: Invalid Log Sources", + evt: Event{ + Type: LOG, + Meta: map[string]string{ + "source_ip": "IAMNOTANIP", + }, + }, + expected: []net.IP{ + nil, + }, + }, + { + name: "ParseIPSources: Invalid Overflow Sources", + evt: Event{ + Type: OVFLW, + Overflow: RuntimeAlert{ + Sources: map[string]models.Source{ + "IAMNOTANIP": {}, + }, + }, + }, + expected: []net.IP{ + nil, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ips := tt.evt.ParseIPSources() + assert.Equal(t, tt.expected, ips) + }) + } +} diff --git a/pkg/types/getfstype.go b/pkg/types/getfstype.go new file mode 100644 index 00000000000..728e986bed0 --- /dev/null +++ b/pkg/types/getfstype.go @@ -0,0 +1,115 @@ 
+//go:build !windows && !freebsd && !openbsd + +package types + +import ( + "fmt" + + "golang.org/x/sys/unix" +) + +// Generated with `man statfs | grep _MAGIC | awk '{split(tolower($1),a,"_"); print $2 ": \"" a[1] "\","}'` +// ext2/3/4 duplicates removed to just have ext4 +// XIAFS removed as well +var fsTypeMapping = map[int64]string{ + 0xadf5: "adfs", + 0xadff: "affs", + 0x5346414f: "afs", + 0x09041934: "anon", + 0x0187: "autofs", + 0x62646576: "bdevfs", + 0x42465331: "befs", + 0x1badface: "bfs", + 0x42494e4d: "binfmtfs", + 0xcafe4a11: "bpf", + 0x9123683e: "btrfs", + 0x73727279: "btrfs", + 0x27e0eb: "cgroup", + 0x63677270: "cgroup2", + 0xff534d42: "cifs", + 0x73757245: "coda", + 0x012ff7b7: "coh", + 0x28cd3d45: "cramfs", + 0x64626720: "debugfs", + 0x1373: "devfs", + 0x1cd1: "devpts", + 0xf15f: "ecryptfs", + 0xde5e81e4: "efivarfs", + 0x00414a53: "efs", + 0x137d: "ext", + 0xef51: "ext2", + 0xef53: "ext4", + 0xf2f52010: "f2fs", + 0x65735546: "fuse", + 0xbad1dea: "futexfs", + 0x4244: "hfs", + 0x00c0ffee: "hostfs", + 0xf995e849: "hpfs", + 0x958458f6: "hugetlbfs", + 0x9660: "isofs", + 0x72b6: "jffs2", + 0x3153464a: "jfs", + 0x137f: "minix", + 0x138f: "minix", + 0x2468: "minix2", + 0x2478: "minix2", + 0x4d5a: "minix3", + 0x19800202: "mqueue", + 0x4d44: "msdos", + 0x11307854: "mtd", + 0x564c: "ncp", + 0x6969: "nfs", + 0x3434: "nilfs", + 0x6e736673: "nsfs", + 0x5346544e: "ntfs", + 0x7461636f: "ocfs2", + 0x9fa1: "openprom", + 0x794c7630: "overlayfs", + 0x50495045: "pipefs", + 0x9fa0: "proc", + 0x6165676c: "pstorefs", + 0x002f: "qnx4", + 0x68191122: "qnx6", + 0x858458f6: "ramfs", + 0x52654973: "reiserfs", + 0x7275: "romfs", + 0x73636673: "securityfs", + 0xf97cff8c: "selinux", + 0x43415d53: "smack", + 0x517b: "smb", + 0xfe534d42: "smb2", + 0x534f434b: "sockfs", + 0x73717368: "squashfs", + 0x62656572: "sysfs", + 0x012ff7b6: "sysv2", + 0x012ff7b5: "sysv4", + 0x01021994: "tmpfs", + 0x74726163: "tracefs", + 0x15013346: "udf", + 0x00011954: "ufs", + 0x9fa2: "usbdevice", + 0x01021997: "v9fs", + 0xa501fcf5: "vxfs", + 0xabba1974: "xenfs", + 0x012ff7b4: "xenix", + 0x58465342: "xfs", + 0x2fc12fc1: "zfs", +} + +func GetFSType(path string) (string, error) { + var buf unix.Statfs_t + + err := unix.Statfs(path, &buf) + + if err != nil { + return "", err + } + + fsType, ok := fsTypeMapping[int64(buf.Type)] //nolint:unconvert + + if !ok { + return "", fmt.Errorf("unknown fstype %d", buf.Type) + } + + return fsType, nil +} diff --git a/pkg/types/getfstype_freebsd.go b/pkg/types/getfstype_freebsd.go new file mode 100644 index 00000000000..8fbe3dd7cc4 --- /dev/null +++ b/pkg/types/getfstype_freebsd.go @@ -0,0 +1,25 @@ +//go:build freebsd + +package types + +import ( + "fmt" + "syscall" +) + +func GetFSType(path string) (string, error) { + var fsStat syscall.Statfs_t + + if err := syscall.Statfs(path, &fsStat); err != nil { + return "", fmt.Errorf("failed to get filesystem type: %w", err) + } + + bs := fsStat.Fstypename + + b := make([]byte, len(bs)) + for i, v := range bs { + b[i] = byte(v) + } + + return string(b), nil +} diff --git a/pkg/types/getfstype_openbsd.go b/pkg/types/getfstype_openbsd.go new file mode 100644 index 00000000000..9ec254b7bec --- /dev/null +++ b/pkg/types/getfstype_openbsd.go @@ -0,0 +1,25 @@ +//go:build openbsd + +package types + +import ( + "fmt" + "syscall" +) + +func GetFSType(path string) (string, error) { + var fsStat syscall.Statfs_t + + if err := syscall.Statfs(path, &fsStat); err != nil { + return "", fmt.Errorf("failed to get filesystem type: %w", err) + } + + bs := 
fsStat.F_fstypename + + b := make([]byte, len(bs)) + for i, v := range bs { + b[i] = byte(v) + } + + return string(b), nil +} diff --git a/pkg/types/getfstype_windows.go b/pkg/types/getfstype_windows.go new file mode 100644 index 00000000000..03d8fffd48d --- /dev/null +++ b/pkg/types/getfstype_windows.go @@ -0,0 +1,53 @@ +package types + +import ( + "path/filepath" + "syscall" + "unsafe" +) + +func GetFSType(path string) (string, error) { + kernel32, err := syscall.LoadLibrary("kernel32.dll") + if err != nil { + return "", err + } + defer syscall.FreeLibrary(kernel32) + + getVolumeInformation, err := syscall.GetProcAddress(kernel32, "GetVolumeInformationW") + if err != nil { + return "", err + } + + // Convert relative path to absolute path + absPath, err := filepath.Abs(path) + if err != nil { + return "", err + } + + // Get the root path of the volume + volumeRoot := filepath.VolumeName(absPath) + "\\" + + volumeRootPtr, _ := syscall.UTF16PtrFromString(volumeRoot) + + var ( + fileSystemNameBuffer = make([]uint16, 260) + nFileSystemNameSize = uint32(len(fileSystemNameBuffer)) + ) + + ret, _, err := syscall.SyscallN(getVolumeInformation, + uintptr(unsafe.Pointer(volumeRootPtr)), + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&fileSystemNameBuffer[0])), + uintptr(nFileSystemNameSize), + 0) + + if ret == 0 { + return "", err + } + + return syscall.UTF16ToString(fileSystemNameBuffer), nil +} diff --git a/pkg/types/ip.go b/pkg/types/ip.go index 5e4d7734f2d..9d08afd8809 100644 --- a/pkg/types/ip.go +++ b/pkg/types/ip.go @@ -2,6 +2,7 @@ package types import ( "encoding/binary" + "errors" "fmt" "math" "net" @@ -15,6 +16,7 @@ func LastAddress(n net.IPNet) net.IP { if ip == nil { // IPv6 ip = n.IP + return net.IP{ ip[0] | ^n.Mask[0], ip[1] | ^n.Mask[1], ip[2] | ^n.Mask[2], ip[3] | ^n.Mask[3], ip[4] | ^n.Mask[4], ip[5] | ^n.Mask[5], @@ -38,12 +40,13 @@ func Addr2Ints(anyIP string) (int, int64, int64, int64, int64, error) { if err != nil { return -1, 0, 0, 0, 0, fmt.Errorf("while parsing range %s: %w", anyIP, err) } + return Range2Ints(*net) } ip := net.ParseIP(anyIP) if ip == nil { - return -1, 0, 0, 0, 0, fmt.Errorf("invalid address") + return -1, 0, 0, 0, 0, errors.New("invalid address") } sz, start, end, err := IP2Ints(ip) @@ -56,19 +59,22 @@ func Addr2Ints(anyIP string) (int, int64, int64, int64, int64, error) { /*size (16|4), nw_start, suffix_start, nw_end, suffix_end, error*/ func Range2Ints(network net.IPNet) (int, int64, int64, int64, int64, error) { - szStart, nwStart, sfxStart, err := IP2Ints(network.IP) if err != nil { return -1, 0, 0, 0, 0, fmt.Errorf("converting first ip in range: %w", err) } + lastAddr := LastAddress(network) + szEnd, nwEnd, sfxEnd, err := IP2Ints(lastAddr) if err != nil { return -1, 0, 0, 0, 0, fmt.Errorf("transforming last address of range: %w", err) } + if szEnd != szStart { return -1, 0, 0, 0, 0, fmt.Errorf("inconsistent size for range first(%d) and last(%d) ip", szStart, szEnd) } + return szStart, nwStart, sfxStart, nwEnd, sfxEnd, nil } @@ -85,6 +91,7 @@ func uint2int(u uint64) int64 { ret = int64(u) ret -= math.MaxInt64 } + return ret } @@ -97,13 +104,15 @@ func IP2Ints(pip net.IP) (int, int64, int64, error) { if pip4 != nil { ip_nw32 := binary.BigEndian.Uint32(pip4) - return 4, uint2int(uint64(ip_nw32)), uint2int(ip_sfx), nil - } else if pip16 != nil { + } + + if pip16 != nil { ip_nw = binary.BigEndian.Uint64(pip16[0:8]) ip_sfx = binary.BigEndian.Uint64(pip16[8:16]) + return 16, uint2int(ip_nw), uint2int(ip_sfx), nil - } else { - return -1, 0, 0, 
fmt.Errorf("unexpected len %d for %s", len(pip), pip) } + + return -1, 0, 0, fmt.Errorf("unexpected len %d for %s", len(pip), pip) } diff --git a/pkg/types/profile.go b/pkg/types/profile.go deleted file mode 100644 index e8034210cc3..00000000000 --- a/pkg/types/profile.go +++ /dev/null @@ -1,25 +0,0 @@ -package types - -import ( - "time" - - "github.com/antonmedv/expr/vm" -) - -/*Action profiles*/ -type RemediationProfile struct { - Apply bool - Ban bool - Slow bool - Captcha bool - Duration string - TimeDuration time.Duration -} -type Profile struct { - Profile string `yaml:"profile"` - Filter string `yaml:"filter"` - Remediation RemediationProfile `yaml:"remediation"` - RunTimeFilter *vm.Program - ApiPush *bool `yaml:"api"` - OutputConfigs []map[string]string `yaml:"outputs,omitempty"` -} diff --git a/pkg/leakybucket/queue.go b/pkg/types/queue.go similarity index 69% rename from pkg/leakybucket/queue.go rename to pkg/types/queue.go index 03130b71f73..12a3ab37074 100644 --- a/pkg/leakybucket/queue.go +++ b/pkg/types/queue.go @@ -1,13 +1,12 @@ -package leakybucket +package types import ( - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" ) // Queue holds a limited size queue type Queue struct { - Queue []types.Event + Queue []Event L int //capacity } @@ -15,21 +14,21 @@ type Queue struct { func NewQueue(l int) *Queue { if l == -1 { return &Queue{ - Queue: make([]types.Event, 0), + Queue: make([]Event, 0), L: int(^uint(0) >> 1), // max integer value, architecture independent } } q := &Queue{ - Queue: make([]types.Event, 0, l), + Queue: make([]Event, 0, l), L: l, } - log.WithFields(log.Fields{"Capacity": q.L}).Debugf("Creating queue") + log.WithField("Capacity", q.L).Debugf("Creating queue") return q } // Add an event in the queue. 
If it has already l elements, the first // element is dropped before adding the new m element -func (q *Queue) Add(m types.Event) { +func (q *Queue) Add(m Event) { for len(q.Queue) > q.L { //we allow to add one element more than the true capacity q.Queue = q.Queue[1:] } @@ -37,6 +36,6 @@ func (q *Queue) Add(m types.Event) { } // GetQueue returns the entire queue -func (q *Queue) GetQueue() []types.Event { +func (q *Queue) GetQueue() []Event { return q.Queue } diff --git a/pkg/types/utils.go b/pkg/types/utils.go index 0485db59eaf..712d44ba12d 100644 --- a/pkg/types/utils.go +++ b/pkg/types/utils.go @@ -1,11 +1,8 @@ package types import ( - "bufio" "fmt" - "os" "path/filepath" - "strconv" "strings" "time" @@ -50,7 +47,7 @@ func SetDefaultLoggerConfig(cfgMode string, cfgFolder string, cfgLevel log.Level } logLevel = cfgLevel log.SetLevel(logLevel) - logFormatter = &log.TextFormatter{TimestampFormat: "02-01-2006 15:04:05", FullTimestamp: true, ForceColors: forceColors} + logFormatter = &log.TextFormatter{TimestampFormat: time.RFC3339, FullTimestamp: true, ForceColors: forceColors} log.SetFormatter(logFormatter) return nil } @@ -68,40 +65,15 @@ func ConfigureLogger(clog *log.Logger) error { return nil } -func ParseDuration(d string) (time.Duration, error) { - durationStr := d - if strings.HasSuffix(d, "d") { - days := strings.Split(d, "d")[0] - if len(days) == 0 { - return 0, fmt.Errorf("'%s' can't be parsed as duration", d) - } - daysInt, err := strconv.Atoi(days) - if err != nil { - return 0, err - } - durationStr = strconv.Itoa(daysInt*24) + "h" - } - duration, err := time.ParseDuration(durationStr) - if err != nil { - return 0, err - } - return duration, nil -} - func UtcNow() time.Time { return time.Now().UTC() } -func GetLineCountForFile(filepath string) int { - f, err := os.Open(filepath) +func IsNetworkFS(path string) (bool, string, error) { + fsType, err := GetFSType(path) if err != nil { - log.Fatalf("unable to open log file %s : %s", filepath, err) - } - defer f.Close() - lc := 0 - fs := bufio.NewScanner(f) - for fs.Scan() { - lc++ + return false, "", err } - return lc + fsType = strings.ToLower(fsType) + return fsType == "nfs" || fsType == "cifs" || fsType == "smb" || fsType == "smb2", fsType, nil } diff --git a/plugins/notifications/dummy/LICENSE b/plugins/notifications/dummy/LICENSE deleted file mode 100644 index 912563863fc..00000000000 --- a/plugins/notifications/dummy/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 Crowdsec - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
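
The new Event helpers above allocate their maps lazily: SetMeta and SetParsed create Meta/Parsed on first write, and ParseIPSources reads Meta["source_ip"] for LOG events but walks Overflow.Sources for OVFLW ones. As the invalid cases in event_test.go show, unparseable sources come back as nil entries rather than errors, so callers should expect them. A short usage sketch:

package main

import (
	"fmt"

	"github.com/crowdsecurity/crowdsec/pkg/types"
)

func main() {
	evt := types.Event{Type: types.LOG}

	evt.SetMeta("source_ip", "192.0.2.10") // Meta is nil here; allocated on first use
	evt.SetParsed("program", "sshd")

	fmt.Println(evt.GetType()) // "log"

	for _, ip := range evt.ParseIPSources() {
		fmt.Println(ip) // 192.0.2.10; would be <nil> for an unparseable value
	}
}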
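
IsNetworkFS in utils.go is a thin layer over the new platform-specific GetFSType implementations: the Linux build maps statfs magic numbers through fsTypeMapping, the FreeBSD/OpenBSD builds read Fstypename/F_fstypename from Statfs_t, and the Windows build calls GetVolumeInformationW. Whatever the platform, the caller sees the same three-value result. A sketch (the path is only an example):

package main

import (
	"fmt"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/types"
)

func main() {
	network, fsType, err := types.IsNetworkFS("/var/log")
	if err != nil {
		log.Fatalf("detecting filesystem: %s", err)
	}

	// network is true for nfs, cifs, smb and smb2 mounts
	fmt.Printf("/var/log: network=%t fstype=%s\n", network, fsType)
}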
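
The ip.go changes above are cosmetic (errors.New, early returns, blank lines), so the encoding contract is unchanged: per its own doc comment, Addr2Ints accepts either a bare address or a CIDR range and returns (size, nw_start, sfx_start, nw_end, sfx_end), where size is 4 or 16 and each int64 holds the network or suffix half of the address, shifted through uint2int so that unsigned ordering survives storage as signed integers. A sketch:

package main

import (
	"fmt"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/types"
)

func main() {
	sz, nwStart, sfxStart, nwEnd, sfxEnd, err := types.Addr2Ints("192.168.0.0/24")
	if err != nil {
		log.Fatal(err)
	}

	// sz == 4 for IPv4 and 16 for IPv6
	fmt.Println(sz, nwStart, sfxStart, nwEnd, sfxEnd)
}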
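
The Queue type moves from pkg/leakybucket to pkg/types unchanged apart from its imports, so other packages can queue events without pulling in the bucket logic. Note the capacity semantics preserved by the move: as the inline comment says, Add only starts dropping once the queue already holds more than L events, so a Queue effectively retains up to L+1 elements. A sketch:

package main

import (
	"fmt"

	"github.com/crowdsecurity/crowdsec/pkg/types"
)

func main() {
	q := types.NewQueue(2)

	q.Add(types.Event{Type: types.LOG})
	q.Add(types.Event{Type: types.LOG})
	q.Add(types.Event{Type: types.OVFLW}) // still kept: the queue holds L+1 events
	q.Add(types.Event{Type: types.OVFLW}) // now the oldest event is dropped first

	fmt.Println(len(q.GetQueue())) // 3
}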
diff --git a/plugins/notifications/email/LICENSE b/plugins/notifications/email/LICENSE deleted file mode 100644 index 912563863fc..00000000000 --- a/plugins/notifications/email/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 Crowdsec - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/plugins/notifications/email/Makefile b/plugins/notifications/email/Makefile deleted file mode 100644 index ae548af0a38..00000000000 --- a/plugins/notifications/email/Makefile +++ /dev/null @@ -1,18 +0,0 @@ -ifeq ($(OS), Windows_NT) - SHELL := pwsh.exe - .SHELLFLAGS := -NoProfile -Command - EXT = .exe -endif - -PLUGIN = email -BINARY_NAME = notification-$(PLUGIN)$(EXT) - -GOCMD = go -GOBUILD = $(GOCMD) build - -build: clean - $(GOBUILD) $(LD_OPTS) $(BUILD_VENDOR_FLAGS) -o $(BINARY_NAME) - -.PHONY: clean -clean: - @$(RM) $(BINARY_NAME) $(WIN_IGNORE_ERR) diff --git a/plugins/notifications/email/go.mod b/plugins/notifications/email/go.mod deleted file mode 100644 index 6e3f55d3648..00000000000 --- a/plugins/notifications/email/go.mod +++ /dev/null @@ -1,29 +0,0 @@ -module github.com/crowdsecurity/email-plugin - -go 1.20 - -replace github.com/crowdsecurity/crowdsec => ../../../ - -require ( - github.com/crowdsecurity/crowdsec v1.5.2 - github.com/hashicorp/go-hclog v1.5.0 - github.com/hashicorp/go-plugin v1.4.10 - github.com/xhit/go-simple-mail/v2 v2.10.0 - gopkg.in/yaml.v2 v2.4.0 -) - -require ( - github.com/fatih/color v1.15.0 // indirect - github.com/golang/protobuf v1.5.3 // indirect - github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect - github.com/mitchellh/go-testing-interface v1.0.0 // indirect - github.com/oklog/run v1.0.0 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.9.0 // indirect - golang.org/x/text v0.9.0 // indirect - google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect - google.golang.org/grpc v1.56.1 // indirect - google.golang.org/protobuf v1.30.0 // indirect -) diff --git a/plugins/notifications/email/go.sum b/plugins/notifications/email/go.sum deleted file mode 100644 index f4cad7b1b12..00000000000 --- a/plugins/notifications/email/go.sum +++ /dev/null @@ -1,66 +0,0 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= -github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+13c= -github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= -github.com/hashicorp/go-plugin v1.4.10 h1:xUbmA4jC6Dq163/fWcp8P3JuHilrHHMLNRxzGQJ9hNk= -github.com/hashicorp/go-plugin v1.4.10/go.mod h1:6/1TEzT0eQznvI/gV2CM29DLSkAK/e58mUWKVsPaph0= -github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M= -github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= -github.com/jhump/protoreflect v1.6.0 h1:h5jfMVslIg6l29nsMs0D8Wj17RDVdNYti0vDN/PZZoE= -github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mitchellh/go-testing-interface v1.0.0 h1:fzU/JVNcaqHQEcVFAKeR41fkiLdIPrefOvVG1VZ96U0= -github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= -github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw= -github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= -github.com/xhit/go-simple-mail/v2 v2.10.0 h1:nib6RaJ4qVh5HD9UE9QJqnUZyWp3upv+Z6CFxaMj0V8= -github.com/xhit/go-simple-mail/v2 v2.10.0/go.mod h1:kA1XbQfCI4JxQ9ccSN6VFyIEkkugOm7YiPkA5hKiQn4= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= -golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= -google.golang.org/grpc v1.56.1 h1:z0dNfjIl0VpaZ9iSVjA6daGatAYwPGstTjt5vkRMFkQ= -google.golang.org/grpc v1.56.1/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/notifications/http/LICENSE b/plugins/notifications/http/LICENSE deleted file mode 100644 index 912563863fc..00000000000 --- a/plugins/notifications/http/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 Crowdsec - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
diff --git a/plugins/notifications/http/go.mod b/plugins/notifications/http/go.mod deleted file mode 100644 index b78efec7d42..00000000000 --- a/plugins/notifications/http/go.mod +++ /dev/null @@ -1,28 +0,0 @@ -module github.com/crowdsecurity/http-plugin - -go 1.20 - -replace github.com/crowdsecurity/crowdsec => ../../../ - -require ( - github.com/crowdsecurity/crowdsec v1.5.2 - github.com/hashicorp/go-hclog v1.5.0 - github.com/hashicorp/go-plugin v1.4.10 - gopkg.in/yaml.v2 v2.4.0 -) - -require ( - github.com/fatih/color v1.15.0 // indirect - github.com/golang/protobuf v1.5.3 // indirect - github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect - github.com/mitchellh/go-testing-interface v1.0.0 // indirect - github.com/oklog/run v1.0.0 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.9.0 // indirect - golang.org/x/text v0.9.0 // indirect - google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect - google.golang.org/grpc v1.56.1 // indirect - google.golang.org/protobuf v1.30.0 // indirect -) diff --git a/plugins/notifications/http/go.sum b/plugins/notifications/http/go.sum deleted file mode 100644 index 0c5a2fcc01c..00000000000 --- a/plugins/notifications/http/go.sum +++ /dev/null @@ -1,64 +0,0 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= -github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+13c= -github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= -github.com/hashicorp/go-plugin v1.4.10 h1:xUbmA4jC6Dq163/fWcp8P3JuHilrHHMLNRxzGQJ9hNk= -github.com/hashicorp/go-plugin v1.4.10/go.mod h1:6/1TEzT0eQznvI/gV2CM29DLSkAK/e58mUWKVsPaph0= -github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M= -github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= -github.com/jhump/protoreflect v1.6.0 h1:h5jfMVslIg6l29nsMs0D8Wj17RDVdNYti0vDN/PZZoE= -github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= 
-github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mitchellh/go-testing-interface v1.0.0 h1:fzU/JVNcaqHQEcVFAKeR41fkiLdIPrefOvVG1VZ96U0= -github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= -github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw= -github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= -golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= -google.golang.org/grpc v1.56.1 h1:z0dNfjIl0VpaZ9iSVjA6daGatAYwPGstTjt5vkRMFkQ= -google.golang.org/grpc v1.56.1/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.1 
h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/notifications/slack/LICENSE b/plugins/notifications/slack/LICENSE deleted file mode 100644 index 912563863fc..00000000000 --- a/plugins/notifications/slack/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 Crowdsec - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/plugins/notifications/slack/go.mod b/plugins/notifications/slack/go.mod deleted file mode 100644 index 03d1dc93429..00000000000 --- a/plugins/notifications/slack/go.mod +++ /dev/null @@ -1,31 +0,0 @@ -module github.com/crowdsecurity/slack-plugin - -go 1.20 - -replace github.com/crowdsecurity/crowdsec => ../../../ - -require ( - github.com/crowdsecurity/crowdsec v1.5.2 - github.com/hashicorp/go-hclog v1.5.0 - github.com/hashicorp/go-plugin v1.4.10 - github.com/slack-go/slack v0.9.2 - gopkg.in/yaml.v2 v2.4.0 -) - -require ( - github.com/fatih/color v1.15.0 // indirect - github.com/golang/protobuf v1.5.3 // indirect - github.com/gorilla/websocket v1.4.2 // indirect - github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect - github.com/mitchellh/go-testing-interface v1.0.0 // indirect - github.com/oklog/run v1.0.0 // indirect - github.com/pkg/errors v0.9.1 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.9.0 // indirect - golang.org/x/text v0.9.0 // indirect - google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect - google.golang.org/grpc v1.56.1 // indirect - google.golang.org/protobuf v1.30.0 // indirect -) diff --git a/plugins/notifications/slack/go.sum b/plugins/notifications/slack/go.sum deleted file mode 100644 index e1eb7c4e280..00000000000 --- a/plugins/notifications/slack/go.sum +++ /dev/null @@ -1,74 +0,0 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= -github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= -github.com/go-test/deep v1.0.4 h1:u2CU3YKy9I2pmu9pX0eq50wCgjfGIt539SqR7FbHiho= 
-github.com/go-test/deep v1.0.4/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= -github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+13c= -github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= -github.com/hashicorp/go-plugin v1.4.10 h1:xUbmA4jC6Dq163/fWcp8P3JuHilrHHMLNRxzGQJ9hNk= -github.com/hashicorp/go-plugin v1.4.10/go.mod h1:6/1TEzT0eQznvI/gV2CM29DLSkAK/e58mUWKVsPaph0= -github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M= -github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= -github.com/jhump/protoreflect v1.6.0 h1:h5jfMVslIg6l29nsMs0D8Wj17RDVdNYti0vDN/PZZoE= -github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mitchellh/go-testing-interface v1.0.0 h1:fzU/JVNcaqHQEcVFAKeR41fkiLdIPrefOvVG1VZ96U0= -github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= -github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw= -github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/slack-go/slack v0.9.2 h1:tjIrKKYUCOmWeEAktWShKW+3UjLTH/wmgmCkAGAf8wM= -github.com/slack-go/slack v0.9.2/go.mod h1:wWL//kk0ho+FcQXcBTmEafUI5dz4qz5f4mMk8oIkioQ= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= -golang.org/x/net v0.10.0 
h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= -golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= -google.golang.org/grpc v1.56.1 h1:z0dNfjIl0VpaZ9iSVjA6daGatAYwPGstTjt5vkRMFkQ= -google.golang.org/grpc v1.56.1/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/notifications/splunk/LICENSE b/plugins/notifications/splunk/LICENSE deleted file mode 100644 index 912563863fc..00000000000 --- a/plugins/notifications/splunk/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 Crowdsec - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. 
- -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/plugins/notifications/splunk/go.mod b/plugins/notifications/splunk/go.mod deleted file mode 100644 index 1daf43e3452..00000000000 --- a/plugins/notifications/splunk/go.mod +++ /dev/null @@ -1,28 +0,0 @@ -module github.com/crowdsecurity/splunk-plugin - -go 1.20 - -replace github.com/crowdsecurity/crowdsec => ../../../ - -require ( - github.com/crowdsecurity/crowdsec v1.5.2 - github.com/hashicorp/go-hclog v1.5.0 - github.com/hashicorp/go-plugin v1.4.10 - gopkg.in/yaml.v2 v2.4.0 -) - -require ( - github.com/fatih/color v1.15.0 // indirect - github.com/golang/protobuf v1.5.3 // indirect - github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect - github.com/mitchellh/go-testing-interface v1.0.0 // indirect - github.com/oklog/run v1.0.0 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.9.0 // indirect - golang.org/x/text v0.9.0 // indirect - google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect - google.golang.org/grpc v1.56.1 // indirect - google.golang.org/protobuf v1.30.0 // indirect -) diff --git a/plugins/notifications/splunk/go.sum b/plugins/notifications/splunk/go.sum deleted file mode 100644 index 0c5a2fcc01c..00000000000 --- a/plugins/notifications/splunk/go.sum +++ /dev/null @@ -1,64 +0,0 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= -github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+13c= -github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= -github.com/hashicorp/go-plugin v1.4.10 h1:xUbmA4jC6Dq163/fWcp8P3JuHilrHHMLNRxzGQJ9hNk= -github.com/hashicorp/go-plugin v1.4.10/go.mod h1:6/1TEzT0eQznvI/gV2CM29DLSkAK/e58mUWKVsPaph0= -github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M= -github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= -github.com/jhump/protoreflect v1.6.0 h1:h5jfMVslIg6l29nsMs0D8Wj17RDVdNYti0vDN/PZZoE= -github.com/mattn/go-colorable v0.1.9/go.mod 
h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mitchellh/go-testing-interface v1.0.0 h1:fzU/JVNcaqHQEcVFAKeR41fkiLdIPrefOvVG1VZ96U0= -github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= -github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw= -github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= -golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= -google.golang.org/grpc v1.56.1 h1:z0dNfjIl0VpaZ9iSVjA6daGatAYwPGstTjt5vkRMFkQ= -google.golang.org/grpc v1.56.1/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= 
-google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/rpm/SOURCES/crowdsec.unit.patch b/rpm/SOURCES/crowdsec.unit.patch deleted file mode 100644 index b3c9b35fd27..00000000000 --- a/rpm/SOURCES/crowdsec.unit.patch +++ /dev/null @@ -1,13 +0,0 @@ ---- config/crowdsec.service-orig 2022-03-24 09:46:16.581681532 +0000 -+++ config/crowdsec.service 2022-03-24 09:46:28.761681532 +0000 -@@ -5,8 +5,8 @@ - [Service] - Type=notify - Environment=LC_ALL=C LANG=C --ExecStartPre=/usr/local/bin/crowdsec -c /etc/crowdsec/config.yaml -t --ExecStart=/usr/local/bin/crowdsec -c /etc/crowdsec/config.yaml -+ExecStartPre=/usr/bin/crowdsec -c /etc/crowdsec/config.yaml -t -+ExecStart=/usr/bin/crowdsec -c /etc/crowdsec/config.yaml - #ExecStartPost=/bin/sleep 0.1 - ExecReload=/bin/kill -HUP $MAINPID - Restart=always diff --git a/rpm/SPECS/crowdsec.spec b/rpm/SPECS/crowdsec.spec index a57492eea81..ab71b650d11 100644 --- a/rpm/SPECS/crowdsec.spec +++ b/rpm/SPECS/crowdsec.spec @@ -8,8 +8,7 @@ License: MIT URL: https://crowdsec.net Source0: https://github.com/crowdsecurity/%{name}/archive/v%(echo $VERSION).tar.gz Source1: 80-%{name}.preset -Patch0: crowdsec.unit.patch -Patch1: user.patch +Patch0: user.patch BuildRoot: %{_tmppath}/%{name}-%{version}-%{release}-root-%(%{__id_u} -n) BuildRequires: systemd @@ -32,13 +31,13 @@ Requires: crontabs %setup -q -T -b 0 %patch0 -%patch1 %build sed -i "s#/usr/local/lib/crowdsec/plugins/#%{_libdir}/%{name}/plugins/#g" config/config.yaml %install rm -rf %{buildroot} +mkdir -p %{buildroot}/etc/crowdsec/acquis.d mkdir -p %{buildroot}/etc/crowdsec/hub mkdir -p %{buildroot}/etc/crowdsec/patterns mkdir -p %{buildroot}/etc/crowdsec/console/ @@ -53,7 +52,7 @@ mkdir -p %{buildroot}%{_libdir}/%{name}/plugins/ install -m 755 -D cmd/crowdsec/crowdsec %{buildroot}%{_bindir}/%{name} install -m 755 -D cmd/crowdsec-cli/cscli %{buildroot}%{_bindir}/cscli install -m 755 -D wizard.sh %{buildroot}/usr/share/crowdsec/wizard.sh -install -m 644 -D config/crowdsec.service %{buildroot}%{_unitdir}/%{name}.service +install -m 644 -D debian/crowdsec.service %{buildroot}%{_unitdir}/%{name}.service install -m 644 -D config/patterns/* -t %{buildroot}%{_sysconfdir}/crowdsec/patterns install -m 600 -D config/config.yaml %{buildroot}%{_sysconfdir}/crowdsec install -m 644 -D config/simulation.yaml %{buildroot}%{_sysconfdir}/crowdsec @@ -63,15 +62,19 @@ install -m 644 -D config/context.yaml %{buildroot}%{_sysconfdir}/crowdsec/consol install -m 750 -D config/%{name}.cron.daily %{buildroot}%{_sysconfdir}/cron.daily/%{name} install -m 644 -D %{SOURCE1} %{buildroot}%{_presetdir} -install -m 551 plugins/notifications/slack/notification-slack %{buildroot}%{_libdir}/%{name}/plugins/ -install -m 551 plugins/notifications/http/notification-http %{buildroot}%{_libdir}/%{name}/plugins/ -install -m 551 plugins/notifications/splunk/notification-splunk %{buildroot}%{_libdir}/%{name}/plugins/ -install -m 551 plugins/notifications/email/notification-email 
%{buildroot}%{_libdir}/%{name}/plugins/ +install -m 551 cmd/notification-slack/notification-slack %{buildroot}%{_libdir}/%{name}/plugins/ +install -m 551 cmd/notification-http/notification-http %{buildroot}%{_libdir}/%{name}/plugins/ +install -m 551 cmd/notification-splunk/notification-splunk %{buildroot}%{_libdir}/%{name}/plugins/ +install -m 551 cmd/notification-email/notification-email %{buildroot}%{_libdir}/%{name}/plugins/ +install -m 551 cmd/notification-sentinel/notification-sentinel %{buildroot}%{_libdir}/%{name}/plugins/ +install -m 551 cmd/notification-file/notification-file %{buildroot}%{_libdir}/%{name}/plugins/ -install -m 600 plugins/notifications/slack/slack.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ -install -m 600 plugins/notifications/http/http.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ -install -m 600 plugins/notifications/splunk/splunk.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ -install -m 600 plugins/notifications/email/email.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ +install -m 600 cmd/notification-slack/slack.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ +install -m 600 cmd/notification-http/http.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ +install -m 600 cmd/notification-splunk/splunk.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ +install -m 600 cmd/notification-email/email.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ +install -m 600 cmd/notification-sentinel/sentinel.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ +install -m 600 cmd/notification-file/file.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ %clean rm -rf %{buildroot} @@ -85,6 +88,8 @@ rm -rf %{buildroot} %{_libdir}/%{name}/plugins/notification-http %{_libdir}/%{name}/plugins/notification-splunk %{_libdir}/%{name}/plugins/notification-email +%{_libdir}/%{name}/plugins/notification-sentinel +%{_libdir}/%{name}/plugins/notification-file %{_sysconfdir}/%{name}/patterns/linux-syslog %{_sysconfdir}/%{name}/patterns/ruby %{_sysconfdir}/%{name}/patterns/nginx @@ -119,6 +124,8 @@ rm -rf %{buildroot} %config(noreplace) %{_sysconfdir}/%{name}/notifications/slack.yaml %config(noreplace) %{_sysconfdir}/%{name}/notifications/splunk.yaml %config(noreplace) %{_sysconfdir}/%{name}/notifications/email.yaml +%config(noreplace) %{_sysconfdir}/%{name}/notifications/sentinel.yaml +%config(noreplace) %{_sysconfdir}/%{name}/notifications/file.yaml %config(noreplace) %{_sysconfdir}/cron.daily/%{name} %{_unitdir}/%{name}.service @@ -126,6 +133,7 @@ rm -rf %{buildroot} %ghost %{_sysconfdir}/%{name}/hub/.index.json %ghost %{_localstatedir}/log/%{name}.log %dir /var/lib/%{name}/data/ +%dir %{_sysconfdir}/%{name}/hub %ghost %{_sysconfdir}/crowdsec/local_api_credentials.yaml %ghost %{_sysconfdir}/crowdsec/online_api_credentials.yaml @@ -161,24 +169,23 @@ if [ $1 == 1 ]; then SILENT=true TMP_ACQUIS_FILE_SKIP=skip genacquisition set +e fi - if [ ! -f "%{_sysconfdir}/crowdsec/online_api_credentials.yaml" ] && [ ! -f "%{_sysconfdir}/crowdsec/local_api_credentials.yaml" ] ; then - install -m 600 /dev/null %{_sysconfdir}/crowdsec/online_api_credentials.yaml - install -m 600 /dev/null %{_sysconfdir}/crowdsec/local_api_credentials.yaml - cscli capi register - cscli machines add -a - fi if [ ! 
-f "%{_sysconfdir}/crowdsec/online_api_credentials.yaml" ] ; then - touch %{_sysconfdir}/crowdsec/online_api_credentials.yaml - cscli capi register + install -m 600 /dev/null /etc/crowdsec/online_api_credentials.yaml + cscli capi register --error fi if [ ! -f "%{_sysconfdir}/crowdsec/local_api_credentials.yaml" ] ; then - touch %{_sysconfdir}/crowdsec/local_api_credentials.yaml - cscli machines add -a + install -m 600 /dev/null /etc/crowdsec/local_api_credentials.yaml + cscli machines add -a --force --error fi cscli hub update CSCLI_BIN_INSTALLED="/usr/bin/cscli" SILENT=true install_collection + echo "Get started with CrowdSec:" + echo " * Detailed guides are available in our documentation: https://docs.crowdsec.net" + echo " * Configuration items created by the community can be found at the Hub: https://hub.crowdsec.net" + echo " * Gain insights into your use of CrowdSec with the help of the console https://app.crowdsec.net" + #upgrade elif [ $1 == 2 ] && [ -d /var/lib/crowdsec/backup ]; then cscli config restore /var/lib/crowdsec/backup @@ -199,7 +206,7 @@ fi if [ $1 == 1 ]; then API=$(cscli config show --key "Config.API.Server") - if [ "$API" = "" ] ; then + if [ "$API" = "nil" ] ; then LAPI=false else PORT=$(cscli config show --key "Config.API.Server.ListenURI"|cut -d ":" -f2) diff --git a/scripts/test_env.ps1 b/scripts/test_env.ps1 index 3d8e18ac296..f81b61d5a46 100644 --- a/scripts/test_env.ps1 +++ b/scripts/test_env.ps1 @@ -9,7 +9,7 @@ function show_help() { Write-Output ".\test_env.ps1 -d tests #creates test env in .\tests" } -function create_arbo() { +function create_tree() { $null = New-Item -ItemType Directory $data_dir $null = New-Item -ItemType Directory $log_dir $null = New-Item -ItemType Directory $config_dir @@ -37,8 +37,8 @@ function copy_file() { #envsubst < "./config/dev.yaml" > $BASE/dev.yaml Copy-Item .\config\dev.yaml $base\dev.yaml $plugins | ForEach-Object { - Copy-Item $plugins_dir\$notif_dir\$_\notification-$_.exe $base\$plugins_dir\notification-$_.exe - Copy-Item $plugins_dir\$notif_dir\$_\$_.yaml $config_dir\$notif_dir\$_.yaml + Copy-Item .\cmd\notification-$_\notification-$_.exe $base\$plugins_dir\notification-$_.exe + Copy-Item .\cmd\notification-$_\$_.yaml $config_dir\$notif_dir\$_.yaml } } @@ -71,14 +71,14 @@ $parser_s02="$parser_dir\s02-enrich" $scenarios_dir="$config_dir\scenarios" $postoverflows_dir="$config_dir\postoverflows" $hub_dir="$config_dir\hub" -$plugins=@("http", "slack", "splunk") +$plugins=@("http", "slack", "splunk", "email", "sentinel") $plugins_dir="plugins" $notif_dir="notifications" -Write-Output "Creating test arbo in $base" -create_arbo -Write-Output "Arbo created" +Write-Output "Creating test tree in $base" +create_tree +Write-Output "Tree created" Write-Output "Copying files" copy_file Write-Output "Files copied" diff --git a/scripts/test_env.sh b/scripts/test_env.sh index b203e7f3ef2..2e089ead073 100755 --- a/scripts/test_env.sh +++ b/scripts/test_env.sh @@ -3,10 +3,10 @@ BASE="./tests" usage() { - echo "Usage:" - echo " ./wizard.sh -h Display this help message." - echo " ./test_env.sh -d ./tests Create test environment in './tests' folder" - exit 0 + echo "Usage:" + echo " $0 -h Display this help message." + echo " $0 -d ./tests Create test environment in './tests' folder" + exit 0 } @@ -24,7 +24,7 @@ do exit 0 ;; *) # unknown option - log_err "Unknown argument ${key}." + echo "Unknown argument ${key}." 
>&2 usage exit 1 ;; @@ -47,7 +47,7 @@ PARSER_S02="$PARSER_DIR/s02-enrich" SCENARIOS_DIR="$CONFIG_DIR/scenarios" POSTOVERFLOWS_DIR="$CONFIG_DIR/postoverflows" HUB_DIR="$CONFIG_DIR/hub" -PLUGINS="http slack splunk email" +PLUGINS="http slack splunk email sentinel" PLUGINS_DIR="plugins" NOTIF_DIR="notifications" @@ -57,7 +57,7 @@ log_info() { echo -e "[$date][INFO] $msg" } -create_arbo() { +create_tree() { mkdir -p "$BASE" mkdir -p "$DATA_DIR" mkdir -p "$LOG_DIR" @@ -86,8 +86,8 @@ copy_files() { envsubst < "./config/dev.yaml" > $BASE/dev.yaml for plugin in $PLUGINS do - cp $PLUGINS_DIR/$NOTIF_DIR/$plugin/notification-$plugin $BASE/$PLUGINS_DIR/notification-$plugin - cp $PLUGINS_DIR/$NOTIF_DIR/$plugin/$plugin.yaml $CONFIG_DIR/$NOTIF_DIR/$plugin.yaml + cp cmd/notification-$plugin/notification-$plugin $BASE/$PLUGINS_DIR/notification-$plugin + cp cmd/notification-$plugin/$plugin.yaml $CONFIG_DIR/$NOTIF_DIR/$plugin.yaml done } @@ -103,9 +103,9 @@ setup_api() { main() { - log_info "Creating test arboresence in $BASE" - create_arbo - log_info "Arboresence created" + log_info "Creating test tree in $BASE" + create_tree + log_info "Tree created" log_info "Copying needed files for tests environment" copy_files log_info "Files copied" diff --git a/test/README.md b/test/README.md index 7f34bd3dbef..f7b036e7905 100644 --- a/test/README.md +++ b/test/README.md @@ -61,9 +61,6 @@ architectures. - `curl` - `daemonize` - `jq` - - `nc` - - `openssl` - - `openbsd-netcat` - `python3` ## Running all tests @@ -242,6 +239,11 @@ according to the specific needs of the group of tests in the file. crowdsec instance. Crowdsec must not be running while this operation is performed. + - `instance-data [ lock | unlock ]` + +When experimenting with a local crowdsec installation, you can run `instance-data lock` +to prevent the bats suite from running, so it won't overwrite your configuration or data (see the sketch below). + - `instance-crowdsec [ start | stop ]` Runs (or stops) crowdsec as a background process. PID and lockfiles are @@ -413,10 +415,3 @@ different syntax. Check the heredocs (the <=4.4). It takes an enviroment file, and optionally a list of directories with +>=4.4). It takes an environment file, and optionally a list of directories with vagrant configurations. With a single parameter, it loops over all the directories in alphabetical order, excluding those in the `experimental` directory. Watch out for running VMs if you break the loop by hand.
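The `instance-data lock` paragraph added to test/README.md above describes a workflow without showing it. Here is a minimal sketch of the intended session, assuming the commands are run from the `test/` directory and that a held lock makes later `instance-data` and test runs fail early (the exact refusal behavior is an assumption, not taken from the suite):

```sh
cd test
./instance-data lock      # protect the local crowdsec configuration and data
./instance-data make      # assumed to fail fast while the lock is held
./instance-data unlock    # allow the bats suite to manage the instance again
```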
diff --git a/test/ansible/debug_tools.yml b/test/ansible/debug_tools.yml index 769a973fe95..d2e493f8698 100644 --- a/test/ansible/debug_tools.yml +++ b/test/ansible/debug_tools.yml @@ -14,5 +14,6 @@ - zsh-autosuggestions - zsh-syntax-highlighting - zsh-theme-powerlevel9k + - silversearcher-ag when: - ansible_facts.os_family == "Debian" diff --git a/test/ansible/provision_dependencies.yml b/test/ansible/provision_dependencies.yml index bcfe8fccafb..144adf8ca36 100644 --- a/test/ansible/provision_dependencies.yml +++ b/test/ansible/provision_dependencies.yml @@ -1,6 +1,40 @@ # vim: set ft=yaml.ansible: --- +- name: "Fix EOL'd CentOS Stream 8" + hosts: all + tasks: + - name: "Find yum repository files" + ansible.builtin.find: + paths: /etc/yum.repos.d + patterns: "*.repo" + register: "repo_files" + when: + - ansible_facts.distribution == "CentOS" + - ansible_facts.distribution_major_version == '8' + - name: Comment out mirrorlist entries in repo files + become: true + ansible.builtin.replace: + path: "{{ item.path }}" + regexp: 'mirrorlist' + replace: '#mirrorlist' + loop: "{{ repo_files.files }}" + when: + - ansible_facts.distribution == "CentOS" + - ansible_facts.distribution_major_version == '8' + - repo_files.matched > 0 + - name: Point baseurl at vault.centos.org + become: true + ansible.builtin.replace: + path: "{{ item.path }}" + regexp: '#baseurl=http://mirror.centos.org' + replace: 'baseurl=https://vault.centos.org' + loop: "{{ repo_files.files }}" + when: + - ansible_facts.distribution == "CentOS" + - ansible_facts.distribution_major_version == '8' + - repo_files.matched > 0 + - name: "Install required packages" hosts: all vars_files: @@ -17,6 +51,19 @@ - crowdsecurity.testing.re2 - crowdsecurity.testing.bats_requirements +- name: "Install recent python" + hosts: all + vars_files: + - vars/python.yml + tasks: + - name: role "crowdsecurity.testing.python3" + ansible.builtin.include_role: + name: crowdsecurity.testing.python3 + when: + - ansible_facts.distribution in ['CentOS', 'OracleLinux'] + - ansible_facts.distribution_major_version == '8' or ansible_facts.distribution_major_version == '7' + + - name: "Install Postgres" hosts: all become: true diff --git a/test/ansible/requirements.yml b/test/ansible/requirements.yml index a780e827f85..d5a9b80f659 100644 --- a/test/ansible/requirements.yml +++ b/test/ansible/requirements.yml @@ -14,7 +14,7 @@ collections: - name: ansible.posix - name: https://github.com/crowdsecurity/ansible-collection-crowdsecurity.testing.git type: git - version: v0.0.5 + version: v0.0.7 # - name: crowdsecurity.testing # source: ../../../crowdsecurity.testing diff --git a/test/ansible/roles/make_fixture/tasks/main.yml b/test/ansible/roles/make_fixture/tasks/main.yml index 305cec3a697..908bcf4f14c 100644 --- a/test/ansible/roles/make_fixture/tasks/main.yml +++ b/test/ansible/roles/make_fixture/tasks/main.yml @@ -52,7 +52,7 @@ # daemonize -> /usr/bin or /usr/local/sbin # pidof -> /usr/sbin # bash -> /opt/bash/bin - PATH: "/opt/bash/bin:{{ ansible_env.PATH }}:/usr/sbin:/usr/local/sbin" + PATH: "/opt/bash/bin:{{ ansible_env.PATH }}:{{ golang_install_dir }}/bin/:/usr/sbin:/usr/local/sbin" rescue: - name: "Read crowdsec.log" ansible.builtin.slurp: diff --git a/test/ansible/vagrant/common b/test/ansible/vagrant/common index 4bc237a7e5b..83d7706756e 100644 --- a/test/ansible/vagrant/common +++ b/test/ansible/vagrant/common @@ -14,14 +14,14 @@ end Vagrant.configure('2') do |config| config.vm.define 'crowdsec' - if ARGV.any? { |arg| arg == 'up' || arg == 'provision' } + if ARGV.any?
{ |arg| arg == 'up' || arg == 'provision' } && !ARGV.include?('--no-provision') unless ENV['DB_BACKEND'] $stderr.puts "\e[31mThe DB_BACKEND environment variable is not defined. Please set up the environment and try again.\e[0m" exit 1 end end - config.vm.provision 'shell', path: 'bootstrap' if File.exists?('bootstrap') + config.vm.provision 'shell', path: 'bootstrap' if File.exist?('bootstrap') config.vm.synced_folder '.', '/vagrant', disabled: true config.vm.provider :libvirt do |libvirt| diff --git a/test/ansible/vagrant/experimental/alpine-3.16/bootstrap b/test/ansible/vagrant/experimental/alpine-3.16/bootstrap index 7fb806bb562..e4ac615f6d4 100755 --- a/test/ansible/vagrant/experimental/alpine-3.16/bootstrap +++ b/test/ansible/vagrant/experimental/alpine-3.16/bootstrap @@ -3,5 +3,5 @@ unset IFS set -euf # coreutils -> for timeout (busybox is not enough) -sudo apk add python3 go tar procps netcat-openbsd coreutils +sudo apk add python3 go tar procps coreutils diff --git a/test/ansible/vagrant/experimental/gentoo/bootstrap b/test/ansible/vagrant/experimental/gentoo/bootstrap index 513af50f97a..52f622983f1 100755 --- a/test/ansible/vagrant/experimental/gentoo/bootstrap +++ b/test/ansible/vagrant/experimental/gentoo/bootstrap @@ -1,3 +1,3 @@ #!/bin/sh -sudo emerge --quiet app-portage/gentoolkit dev-vcs/git net-misc/curl app-misc/jq net-analyzer/openbsd-netcat +sudo emerge --quiet app-portage/gentoolkit dev-vcs/git net-misc/curl app-misc/jq diff --git a/test/ansible/vagrant/experimental/opensuse-15.4/Vagrantfile b/test/ansible/vagrant/experimental/opensuse-15.6/Vagrantfile similarity index 84% rename from test/ansible/vagrant/experimental/opensuse-15.4/Vagrantfile rename to test/ansible/vagrant/experimental/opensuse-15.6/Vagrantfile index 4a3ec307c4f..f2dc70816c9 100644 --- a/test/ansible/vagrant/experimental/opensuse-15.4/Vagrantfile +++ b/test/ansible/vagrant/experimental/opensuse-15.6/Vagrantfile @@ -1,7 +1,8 @@ # frozen_string_literal: true Vagrant.configure('2') do |config| - config.vm.box = 'opensuse/Leap-15.4.x86_64' + config.vm.box = 'opensuse/Leap-15.6.x86_64' + config.vm.box_version = "15.6.13.280" config.vm.define 'crowdsec' config.vm.provision 'shell', path: 'bootstrap' diff --git a/test/ansible/vagrant/experimental/opensuse-15.6/bootstrap b/test/ansible/vagrant/experimental/opensuse-15.6/bootstrap new file mode 100644 index 00000000000..a43165d1828 --- /dev/null +++ b/test/ansible/vagrant/experimental/opensuse-15.6/bootstrap @@ -0,0 +1,3 @@ +#!/bin/sh + +zypper install -y kitty-terminfo diff --git a/test/ansible/vagrant/fedora-33/skip b/test/ansible/vagrant/fedora-37/skip old mode 100755 new mode 100644 similarity index 100% rename from test/ansible/vagrant/fedora-33/skip rename to test/ansible/vagrant/fedora-37/skip diff --git a/test/ansible/vagrant/fedora-34/skip b/test/ansible/vagrant/fedora-38/skip old mode 100755 new mode 100644 similarity index 100% rename from test/ansible/vagrant/fedora-34/skip rename to test/ansible/vagrant/fedora-38/skip diff --git a/test/ansible/vagrant/fedora-34/Vagrantfile b/test/ansible/vagrant/fedora-39/Vagrantfile similarity index 69% rename from test/ansible/vagrant/fedora-34/Vagrantfile rename to test/ansible/vagrant/fedora-39/Vagrantfile index db2db8d0879..ec03661fe39 100644 --- a/test/ansible/vagrant/fedora-34/Vagrantfile +++ b/test/ansible/vagrant/fedora-39/Vagrantfile @@ -1,8 +1,7 @@ # frozen_string_literal: true Vagrant.configure('2') do |config| - # config.vm.box = "fedora/34-cloud-base" - config.vm.box = 'generic/fedora34' + 
config.vm.box = "fedora/39-cloud-base" config.vm.provision "shell", inline: <<-SHELL SHELL end diff --git a/test/ansible/vagrant/fedora-39/skip b/test/ansible/vagrant/fedora-39/skip new file mode 100644 index 00000000000..4f1a9063d2b --- /dev/null +++ b/test/ansible/vagrant/fedora-39/skip @@ -0,0 +1,9 @@ +#!/bin/sh + +die() { + echo "$@" >&2 + exit 1 +} + +[ "${DB_BACKEND}" = "mysql" ] && die "mysql role does not support this distribution" +exit 0 diff --git a/test/ansible/vagrant/fedora-33/Vagrantfile b/test/ansible/vagrant/fedora-40/Vagrantfile similarity index 69% rename from test/ansible/vagrant/fedora-33/Vagrantfile rename to test/ansible/vagrant/fedora-40/Vagrantfile index df6f06944ae..ec03661fe39 100644 --- a/test/ansible/vagrant/fedora-33/Vagrantfile +++ b/test/ansible/vagrant/fedora-40/Vagrantfile @@ -1,8 +1,7 @@ # frozen_string_literal: true Vagrant.configure('2') do |config| - # config.vm.box = "fedora/33-cloud-base" - config.vm.box = 'generic/fedora33' + config.vm.box = "fedora/39-cloud-base" config.vm.provision "shell", inline: <<-SHELL SHELL end diff --git a/test/ansible/vagrant/fedora-40/skip b/test/ansible/vagrant/fedora-40/skip new file mode 100644 index 00000000000..4f1a9063d2b --- /dev/null +++ b/test/ansible/vagrant/fedora-40/skip @@ -0,0 +1,9 @@ +#!/bin/sh + +die() { + echo "$@" >&2 + exit 1 +} + +[ "${DB_BACKEND}" = "mysql" ] && die "mysql role does not support this distribution" +exit 0 diff --git a/test/ansible/vagrant/ubuntu-22.04-jammy/Vagrantfile b/test/ansible/vagrant/ubuntu-22.04-jammy/Vagrantfile index 9e17f71fb6d..9b399cae4f8 100644 --- a/test/ansible/vagrant/ubuntu-22.04-jammy/Vagrantfile +++ b/test/ansible/vagrant/ubuntu-22.04-jammy/Vagrantfile @@ -3,6 +3,7 @@ Vagrant.configure('2') do |config| config.vm.box = 'generic/ubuntu2204' config.vm.provision "shell", inline: <<-SHELL + sudo apt install -y kitty-terminfo SHELL end diff --git a/test/ansible/vagrant/ubuntu-22.10-kinetic/Vagrantfile b/test/ansible/vagrant/ubuntu-22.10-kinetic/Vagrantfile index 6c15b0a1e30..e08b595684a 100644 --- a/test/ansible/vagrant/ubuntu-22.10-kinetic/Vagrantfile +++ b/test/ansible/vagrant/ubuntu-22.10-kinetic/Vagrantfile @@ -3,6 +3,7 @@ Vagrant.configure('2') do |config| config.vm.box = 'generic/ubuntu2210' config.vm.provision "shell", inline: <<-SHELL + sudo apt install -y kitty-terminfo SHELL end diff --git a/test/ansible/vagrant/ubuntu-23.04-lunar/Vagrantfile b/test/ansible/vagrant/ubuntu-23.04-lunar/Vagrantfile index f40fb7bd59d..367cf5279d5 100644 --- a/test/ansible/vagrant/ubuntu-23.04-lunar/Vagrantfile +++ b/test/ansible/vagrant/ubuntu-23.04-lunar/Vagrantfile @@ -3,6 +3,7 @@ Vagrant.configure('2') do |config| config.vm.box = 'bento/ubuntu-23.04' config.vm.provision "shell", inline: <<-SHELL + sudo apt install -y kitty-terminfo SHELL end diff --git a/test/ansible/vagrant/ubuntu-24-04-noble/Vagrantfile b/test/ansible/vagrant/ubuntu-24-04-noble/Vagrantfile new file mode 100644 index 00000000000..52490900fd8 --- /dev/null +++ b/test/ansible/vagrant/ubuntu-24-04-noble/Vagrantfile @@ -0,0 +1,10 @@ +# frozen_string_literal: true + +Vagrant.configure('2') do |config| + config.vm.box = 'alvistack/ubuntu-24.04' + config.vm.provision "shell", inline: <<-SHELL + SHELL +end + +common = '../common' +load common if File.exist?(common) diff --git a/test/ansible/vagrant/wizard/centos-8/Vagrantfile b/test/ansible/vagrant/wizard/centos-8/Vagrantfile index 9db09a4ce01..4b469ad65dc 100644 --- a/test/ansible/vagrant/wizard/centos-8/Vagrantfile +++ 
b/test/ansible/vagrant/wizard/centos-8/Vagrantfile @@ -10,4 +10,4 @@ Vagrant.configure('2') do |config| end common = '../common' -load common if File.exists?(common) +load common if File.exist?(common) diff --git a/test/ansible/vagrant/wizard/common b/test/ansible/vagrant/wizard/common index be1820914c2..fda6d50f4fc 100644 --- a/test/ansible/vagrant/wizard/common +++ b/test/ansible/vagrant/wizard/common @@ -21,7 +21,7 @@ Vagrant.configure('2') do |config| end end - config.vm.provision 'shell', path: 'bootstrap' if File.exists?('bootstrap') + config.vm.provision 'shell', path: 'bootstrap' if File.exist?('bootstrap') config.vm.synced_folder '.', '/vagrant', disabled: true config.vm.provider :libvirt do |libvirt| diff --git a/test/ansible/vagrant/wizard/debian-10-buster/Vagrantfile b/test/ansible/vagrant/wizard/debian-10-buster/Vagrantfile index 3b10b312d0d..9602acb698c 100644 --- a/test/ansible/vagrant/wizard/debian-10-buster/Vagrantfile +++ b/test/ansible/vagrant/wizard/debian-10-buster/Vagrantfile @@ -9,4 +9,4 @@ Vagrant.configure('2') do |config| end common = '../common' -load common if File.exists?(common) +load common if File.exist?(common) diff --git a/test/ansible/vagrant/wizard/debian-11-bullseye/Vagrantfile b/test/ansible/vagrant/wizard/debian-11-bullseye/Vagrantfile index 6dd7bb2fc9c..9184fb67639 100644 --- a/test/ansible/vagrant/wizard/debian-11-bullseye/Vagrantfile +++ b/test/ansible/vagrant/wizard/debian-11-bullseye/Vagrantfile @@ -9,4 +9,4 @@ Vagrant.configure('2') do |config| end common = '../common' -load common if File.exists?(common) +load common if File.exist?(common) diff --git a/test/ansible/vagrant/wizard/debian-12-bookworm/Vagrantfile b/test/ansible/vagrant/wizard/debian-12-bookworm/Vagrantfile index 5ccf234eb3e..1a0a43eb26f 100644 --- a/test/ansible/vagrant/wizard/debian-12-bookworm/Vagrantfile +++ b/test/ansible/vagrant/wizard/debian-12-bookworm/Vagrantfile @@ -9,4 +9,4 @@ Vagrant.configure('2') do |config| end common = '../common' -load common if File.exists?(common) +load common if File.exist?(common) diff --git a/test/ansible/vagrant/wizard/fedora-36/Vagrantfile b/test/ansible/vagrant/wizard/fedora-36/Vagrantfile index 969a8e70c87..ac9a0319e4b 100644 --- a/test/ansible/vagrant/wizard/fedora-36/Vagrantfile +++ b/test/ansible/vagrant/wizard/fedora-36/Vagrantfile @@ -8,4 +8,4 @@ Vagrant.configure('2') do |config| end common = '../common' -load common if File.exists?(common) +load common if File.exist?(common) diff --git a/test/ansible/vagrant/wizard/ubuntu-22.04-jammy/Vagrantfile b/test/ansible/vagrant/wizard/ubuntu-22.04-jammy/Vagrantfile index c13d2f9468e..f1ebf43a025 100644 --- a/test/ansible/vagrant/wizard/ubuntu-22.04-jammy/Vagrantfile +++ b/test/ansible/vagrant/wizard/ubuntu-22.04-jammy/Vagrantfile @@ -3,9 +3,9 @@ Vagrant.configure('2') do |config| config.vm.box = 'generic/ubuntu2204' config.vm.provision "shell", inline: <<-SHELL - sudo apt install -y aptitude kitty-terminfo + sudo env DEBIAN_FRONTEND=noninteractive apt install -y aptitude kitty-terminfo SHELL end common = '../common' -load common if File.exists?(common) +load common if File.exist?(common) diff --git a/test/ansible/vagrant/wizard/ubuntu-22.10-kinetic/Vagrantfile b/test/ansible/vagrant/wizard/ubuntu-22.10-kinetic/Vagrantfile index d0e2e3cdaa8..5875587eeb4 100644 --- a/test/ansible/vagrant/wizard/ubuntu-22.10-kinetic/Vagrantfile +++ b/test/ansible/vagrant/wizard/ubuntu-22.10-kinetic/Vagrantfile @@ -8,4 +8,4 @@ Vagrant.configure('2') do |config| end common = '../common' -load common if 
File.exists?(common) +load common if File.exist?(common) diff --git a/test/ansible/vars/go.yml b/test/ansible/vars/go.yml index 0f60356b167..a735b954349 100644 --- a/test/ansible/vars/go.yml +++ b/test/ansible/vars/go.yml @@ -1,5 +1,5 @@ # vim: set ft=yaml.ansible: --- -golang_version: "1.20.6" +golang_version: "1.21.4" golang_install_dir: "/opt/go/{{ golang_version }}" diff --git a/test/ansible/vars/python.yml b/test/ansible/vars/python.yml new file mode 100644 index 00000000000..0cafdcc3d4c --- /dev/null +++ b/test/ansible/vars/python.yml @@ -0,0 +1 @@ +python_version: "3.12.3" diff --git a/test/bats-detect/proftpd-deb.bats b/test/bats-detect/proftpd-deb.bats index b21ea466d8d..fce556cafee 100644 --- a/test/bats-detect/proftpd-deb.bats +++ b/test/bats-detect/proftpd-deb.bats @@ -10,7 +10,8 @@ setup_file() { teardown_file() { load "../lib/teardown_file.sh" - deb-remove proftpd + systemctl stop proftpd.service || : + deb-remove proftpd proftpd-core } setup() { @@ -32,6 +33,7 @@ setup() { @test "proftpd: install" { run -0 deb-install proftpd + run -0 sudo systemctl unmask proftpd.service run -0 sudo systemctl enable proftpd.service } diff --git a/test/bats.mk b/test/bats.mk index 259bef379a9..72ac8863f72 100644 --- a/test/bats.mk +++ b/test/bats.mk @@ -38,6 +38,7 @@ define ENV := export TEST_DIR="$(TEST_DIR)" export LOCAL_DIR="$(LOCAL_DIR)" export BIN_DIR="$(BIN_DIR)" +# append .min to the binary names to use the minimal profile export CROWDSEC="$(CROWDSEC)" export CSCLI="$(CSCLI)" export CONFIG_YAML="$(CONFIG_DIR)/config.yaml" @@ -62,55 +63,55 @@ bats-environment: export ENV:=$(ENV) bats-environment: @echo "$${ENV}" > $(TEST_DIR)/.environment.sh -# Verify dependencies and submodules -bats-check-requirements: +bats-check-requirements: ## Check dependencies for functional tests @$(TEST_DIR)/bin/check-requirements -# Install/update some of the tools required to run the tests -bats-update-tools: - # yq v4.34.1 - GOBIN=$(TEST_DIR)/tools go install github.com/mikefarah/yq/v4@5ef537f3fd1a9437aa3ee44c32c6459a126efdc4 - # cfssl v1.6.4 - GOBIN=$(TEST_DIR)/tools go install github.com/cloudflare/cfssl/cmd/cfssl@b4d0d877cac528f63db39dfb62d5c96cd3a32a0b - GOBIN=$(TEST_DIR)/tools go install github.com/cloudflare/cfssl/cmd/cfssljson@b4d0d877cac528f63db39dfb62d5c96cd3a32a0b +bats-update-tools: ## Install/update tools required for functional tests + # yq v4.44.3 + GOBIN=$(TEST_DIR)/tools go install github.com/mikefarah/yq/v4@bbdd97482f2d439126582a59689eb1c855944955 + # cfssl v1.6.5 + GOBIN=$(TEST_DIR)/tools go install github.com/cloudflare/cfssl/cmd/cfssl@96259aa29c9cc9b2f4e04bad7d4bc152e5405dda + GOBIN=$(TEST_DIR)/tools go install github.com/cloudflare/cfssl/cmd/cfssljson@96259aa29c9cc9b2f4e04bad7d4bc152e5405dda # Build and installs crowdsec in a local directory. Rebuilds if already exists. 
-bats-build: bats-environment +bats-build: bats-environment ## Build binaries for functional tests @$(MKDIR) $(BIN_DIR) $(LOG_DIR) $(PID_DIR) $(BATS_PLUGIN_DIR) - @TEST_COVERAGE=$(TEST_COVERAGE) DEFAULT_CONFIGDIR=$(CONFIG_DIR) DEFAULT_DATADIR=$(DATA_DIR) $(MAKE) build + # minimal profile + @$(MAKE) build DEBUG=1 TEST_COVERAGE=$(TEST_COVERAGE) DEFAULT_CONFIGDIR=$(CONFIG_DIR) DEFAULT_DATADIR=$(DATA_DIR) BUILD_PROFILE=minimal + @install -m 0755 cmd/crowdsec/crowdsec $(BIN_DIR)/crowdsec.min + @install -m 0755 cmd/crowdsec-cli/cscli $(BIN_DIR)/cscli.min + # default profile + @$(MAKE) build DEBUG=1 TEST_COVERAGE=$(TEST_COVERAGE) DEFAULT_CONFIGDIR=$(CONFIG_DIR) DEFAULT_DATADIR=$(DATA_DIR) @install -m 0755 cmd/crowdsec/crowdsec cmd/crowdsec-cli/cscli $(BIN_DIR)/ - @install -m 0755 plugins/notifications/*/notification-* $(BATS_PLUGIN_DIR)/ + @install -m 0755 cmd/notification-*/notification-* $(BATS_PLUGIN_DIR)/ # Create a reusable package with initial configuration + data -bats-fixture: bats-check-requirements bats-update-tools - @echo "Creating functional test fixture..." +bats-fixture: bats-check-requirements bats-update-tools ## Build fixture for functional tests + @echo "Creating functional test fixture." @$(TEST_DIR)/instance-data make # Remove the local crowdsec installation and the fixture config + data # Don't remove LOCAL_DIR directly because it could be / or anything else outside the repo -bats-clean: +bats-clean: ## Remove functional test environment @$(RM) $(TEST_DIR)/local $(WIN_IGNORE_ERR) @$(RM) $(LOCAL_INIT_DIR) $(WIN_IGNORE_ERR) @$(RM) $(TEST_DIR)/dyn-bats/*.bats $(WIN_IGNORE_ERR) @$(RM) test/.environment.sh $(WIN_IGNORE_ERR) @$(RM) test/coverage/* $(WIN_IGNORE_ERR) -# Run the test suite -bats-test: bats-environment +bats-test: bats-environment ## Run functional tests $(TEST_DIR)/run-tests $(TEST_DIR)/bats -# Generate dynamic tests -bats-test-hub: bats-environment bats-check-requirements +bats-test-hub: bats-environment bats-check-requirements ## Run all hub tests @$(TEST_DIR)/bin/generate-hub-tests $(TEST_DIR)/run-tests $(TEST_DIR)/dyn-bats -# Static checks for the test scripts. # Not failproof but they can catch bugs and improve learning of sh/bash -bats-lint: +bats-lint: ## Static checks for the test scripts. @shellcheck --version >/dev/null 2>&1 || (echo "ERROR: shellcheck is required."; exit 1) @shellcheck -x $(TEST_DIR)/bats/*.bats -bats-test-package: bats-environment +bats-test-package: bats-environment ## CI only - test a binary package (deb, rpm, ...) 
$(TEST_DIR)/instance-data make $(TEST_DIR)/run-tests $(TEST_DIR)/bats $(TEST_DIR)/run-tests $(TEST_DIR)/dyn-bats diff --git a/test/bats/00_wait_for.bats b/test/bats/00_wait_for.bats new file mode 100644 index 00000000000..94c65033bb4 --- /dev/null +++ b/test/bats/00_wait_for.bats @@ -0,0 +1,70 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +setup() { + load "../lib/setup.sh" +} + +@test "run a command and capture its stdout" { + run -0 wait-for seq 1 3 + assert_output - <<-EOT + 1 + 2 + 3 + EOT +} + +@test "run a command and capture its stderr" { + rune -0 wait-for sh -c 'seq 1 3 >&2' + assert_stderr - <<-EOT + 1 + 2 + 3 + EOT +} + +@test "run a command until a pattern is found in stdout" { + run -0 wait-for --out "1[12]0" seq 1 200 + assert_line --index 0 "1" + assert_line --index -1 "110" + refute_line "111" +} + +@test "run a command until a pattern is found in stderr" { + rune -0 wait-for --err "10" sh -c 'seq 1 20 >&2' + assert_stderr - <<-EOT + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + EOT +} + +@test "run a command with timeout (no match)" { + # when the process is terminated without a match, it returns + # 256 - 15 (SIGTERM) = 241 + rune -241 wait-for --timeout 0.1 --out "10" sh -c 'echo 1; sleep 3; echo 2' + assert_line 1 + # there may be more, but we don't care +} + +@test "run a command with timeout (match)" { + # when the process is terminated with a match, return code is 128 + rune -128 wait-for --timeout .4 --out "2" sh -c 'echo 1; sleep .1; echo 2; echo 3; echo 4; sleep 10' + assert_output - <<-EOT + 1 + 2 + EOT +} diff --git a/test/bats/01_crowdsec.bats b/test/bats/01_crowdsec.bats index a1a2861f6df..aa5830a6bae 100644 --- a/test/bats/01_crowdsec.bats +++ b/test/bats/01_crowdsec.bats @@ -24,71 +24,75 @@ teardown() { #---------- @test "crowdsec (usage)" { - rune -0 timeout 2s "${CROWDSEC}" -h - assert_stderr_line --regexp "Usage of .*:" - - rune -0 timeout 2s "${CROWDSEC}" --help - assert_stderr_line --regexp "Usage of .*:" + rune -0 wait-for --out "Usage of " "$CROWDSEC" -h + rune -0 wait-for --out "Usage of " "$CROWDSEC" --help } @test "crowdsec (unknown flag)" { - rune -2 timeout 2s "${CROWDSEC}" --foobar - assert_stderr_line "flag provided but not defined: -foobar" - assert_stderr_line --regexp "Usage of .*" + rune -0 wait-for --err "flag provided but not defined: -foobar" "$CROWDSEC" --foobar } @test "crowdsec (unknown argument)" { - rune -2 timeout 2s "${CROWDSEC}" trololo - assert_stderr_line "argument provided but not defined: trololo" - assert_stderr_line --regexp "Usage of .*" + rune -0 wait-for --err "argument provided but not defined: trololo" "$CROWDSEC" trololo +} + +@test "crowdsec -version" { + rune -0 "$CROWDSEC" -version + assert_output --partial "version:" } @test "crowdsec (no api and no agent)" { - rune -1 timeout 2s "${CROWDSEC}" -no-api -no-cs - assert_stderr_line --partial "You must run at least the API Server or crowdsec" + rune -0 wait-for \ + --err "you must run at least the API Server or crowdsec" \ + "$CROWDSEC" -no-api -no-cs } @test "crowdsec - print error on exit" { # errors that cause program termination are printed to stderr, not only logs config_set '.db_config.type="meh"' - rune -1 "${CROWDSEC}" + rune -1 "$CROWDSEC" assert_stderr --partial "unable to create database client: unknown database type 'meh'" } -@test "crowdsec - bad configuration (empty/missing common section)" { +@test "crowdsec - default logging configuration (empty/missing common section)" 
{ config_set '.common={}' - rune -1 "${CROWDSEC}" + rune -0 wait-for \ + --err "Starting processing data" \ + "$CROWDSEC" refute_output - assert_stderr --partial "unable to load configuration: common section is empty" config_set 'del(.common)' - rune -1 "${CROWDSEC}" + rune -0 wait-for \ + --err "Starting processing data" \ + "$CROWDSEC" refute_output - assert_stderr --partial "unable to load configuration: common section is empty" } @test "CS_LAPI_SECRET not strong enough" { - CS_LAPI_SECRET=foo rune -1 timeout 2s "${CROWDSEC}" + CS_LAPI_SECRET=foo rune -1 wait-for "$CROWDSEC" assert_stderr --partial "api server init: unable to run local API: controller init: CS_LAPI_SECRET not strong enough" } @test "crowdsec - reload (change of logfile, disabled agent)" { - logdir1=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) + logdir1=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) log_old="${logdir1}/crowdsec.log" config_set ".common.log_dir=\"${logdir1}\"" rune -0 ./instance-crowdsec start-pid PID="$output" - assert_file_exist "$log_old" + + sleep .5 + + assert_file_exists "$log_old" assert_file_contains "$log_old" "Starting processing data" - logdir2=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) + logdir2=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) log_new="${logdir2}/crowdsec.log" config_set ".common.log_dir=\"${logdir2}\"" config_disable_agent - sleep 5 + sleep 2 rune -0 kill -HUP "$PID" @@ -111,13 +115,13 @@ teardown() { assert_file_contains "$log_old" "Bucket routine exiting" assert_file_contains "$log_old" "serve: shutting down api server" - sleep 5 + sleep 2 - assert_file_exist "$log_new" + assert_file_exists "$log_new" for ((i=0; i<10; i++)); do sleep 1 - grep -q "Reload is finished" <"$log_old" && break + grep -q "Reload is finished" <"$log_new" && break done echo "waited $i seconds" @@ -138,8 +142,8 @@ teardown() { ACQUIS_YAML=$(config_get '.crowdsec_service.acquisition_path') rm -f "$ACQUIS_YAML" - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr_line --partial "acquis.yaml: no such file or directory" + rune -1 wait-for "$CROWDSEC" + assert_stderr --partial "acquis.yaml: no such file or directory" } @test "crowdsec (error if acquisition_path is not defined and acquisition_dir is empty)" { @@ -148,10 +152,10 @@ teardown() { config_set '.crowdsec_service.acquisition_path=""' ACQUIS_DIR=$(config_get '.crowdsec_service.acquisition_dir') - rm -f "$ACQUIS_DIR" + rm -rf "$ACQUIS_DIR" config_set '.common.log_media="stdout"' - rune -1 timeout 2s "${CROWDSEC}" + rune -1 wait-for "$CROWDSEC" # check warning assert_stderr --partial "no acquisition file found" assert_stderr --partial "crowdsec init: while loading acquisition config: no datasource enabled" @@ -163,17 +167,19 @@ teardown() { config_set '.crowdsec_service.acquisition_path=""' ACQUIS_DIR=$(config_get '.crowdsec_service.acquisition_dir') - rm -f "$ACQUIS_DIR" + rm -rf "$ACQUIS_DIR" config_set '.crowdsec_service.acquisition_dir=""' config_set '.common.log_media="stdout"' - rune -1 timeout 2s "${CROWDSEC}" + rune -1 wait-for "$CROWDSEC" # check warning assert_stderr --partial "no acquisition_path or acquisition_dir specified" assert_stderr --partial "crowdsec init: while loading acquisition config: no datasource enabled" } @test "crowdsec (no error if acquisition_path is empty string but acquisition_dir is not empty)" { + config_set '.common.log_media="stdout"' + ACQUIS_YAML=$(config_get '.crowdsec_service.acquisition_path') config_set '.crowdsec_service.acquisition_path=""' @@ -181,17 +187,60 @@ teardown() { mkdir -p "$ACQUIS_DIR" mv "$ACQUIS_YAML" 
"$ACQUIS_DIR"/foo.yaml - rune -124 timeout 2s "${CROWDSEC}" + rune -0 wait-for \ + --err "Starting processing data" \ + "$CROWDSEC" # now, if foo.yaml is empty instead, there won't be valid datasources. cat /dev/null >"$ACQUIS_DIR"/foo.yaml - rune -1 timeout 2s "${CROWDSEC}" + rune -1 wait-for "$CROWDSEC" assert_stderr --partial "crowdsec init: while loading acquisition config: no datasource enabled" } -@test "crowdsec (disabled datasources)" { +@test "crowdsec (datasource not built)" { + config_set '.common.log_media="stdout"' + + # a datasource cannot run - it's not built in the log processor executable + + ACQUIS_DIR=$(config_get '.crowdsec_service.acquisition_dir') + mkdir -p "$ACQUIS_DIR" + cat >"$ACQUIS_DIR"/foo.yaml <<-EOT + source: journalctl + journalctl_filter: + - "_SYSTEMD_UNIT=ssh.service" + labels: + type: syslog + EOT + + #shellcheck disable=SC2016 + rune -1 wait-for \ + --err "crowdsec init: while loading acquisition config: in file $ACQUIS_DIR/foo.yaml (position: 0) - data source journalctl is not built in this version of crowdsec" \ + env PATH='' "$CROWDSEC".min + + # auto-detection of journalctl_filter still works + cat >"$ACQUIS_DIR"/foo.yaml <<-EOT + source: whatever + journalctl_filter: + - "_SYSTEMD_UNIT=ssh.service" + labels: + type: syslog + EOT + + #shellcheck disable=SC2016 + rune -1 wait-for \ + --err "crowdsec init: while loading acquisition config: in file $ACQUIS_DIR/foo.yaml (position: 0) - data source journalctl is not built in this version of crowdsec" \ + env PATH='' "$CROWDSEC".min +} + +@test "crowdsec (disabled datasource)" { + if is_package_testing; then + # we can't hide journalctl in package testing + # because crowdsec is run from systemd + skip "n/a for package testing" + fi + config_set '.common.log_media="stdout"' # a datasource cannot run - missing journalctl command @@ -206,9 +255,10 @@ teardown() { type: syslog EOT - rune -124 timeout 2s env PATH='' "${CROWDSEC}" #shellcheck disable=SC2016 - assert_stderr --partial 'datasource '\''journalctl'\'' is not available: exec: "journalctl": executable file not found in $PATH' + rune -0 wait-for \ + --err 'datasource '\''journalctl'\'' is not available: exec: \\"journalctl\\": executable file not found in ' \ + env PATH='' "$CROWDSEC" # if all datasources are disabled, crowdsec should exit @@ -216,7 +266,22 @@ teardown() { rm -f "$ACQUIS_YAML" config_set '.crowdsec_service.acquisition_path=""' - rune -1 timeout 2s env PATH='' "${CROWDSEC}" + rune -1 wait-for env PATH='' "$CROWDSEC" assert_stderr --partial "crowdsec init: while loading acquisition config: no datasource enabled" } +@test "crowdsec -t (error in acquisition file)" { + # we can verify the acquisition configuration without running crowdsec + ACQUIS_YAML=$(config_get '.crowdsec_service.acquisition_path') + config_set "$ACQUIS_YAML" 'del(.filenames)' + + # if filenames are missing, it won't be able to detect source type + config_set "$ACQUIS_YAML" '.source="file"' + rune -1 wait-for "$CROWDSEC" + assert_stderr --partial "failed to configure datasource file: no filename or filenames configuration provided" + + config_set "$ACQUIS_YAML" '.filenames=["file.log"]' + config_set "$ACQUIS_YAML" '.meh=3' + rune -1 wait-for "$CROWDSEC" + assert_stderr --partial "field meh not found in type fileacquisition.FileConfiguration" +} diff --git a/test/bats/01_crowdsec_lapi.bats b/test/bats/01_crowdsec_lapi.bats new file mode 100644 index 00000000000..21e1d7a093e --- /dev/null +++ b/test/bats/01_crowdsec_lapi.bats @@ -0,0 +1,50 @@ +#!/usr/bin/env bats +# vim: 
ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +# Tests for LAPI configuration and startup + +@test "lapi (.api.server.enable=false)" { + rune -0 config_set '.api.server.enable=false' + rune -1 "$CROWDSEC" -no-cs + assert_stderr --partial "you must run at least the API Server or crowdsec" +} + +@test "lapi (no .api.server.listen_uri)" { + rune -0 config_set 'del(.api.server.listen_socket) | del(.api.server.listen_uri)' + rune -1 "$CROWDSEC" -no-cs + assert_stderr --partial "no listen_uri or listen_socket specified" +} + +@test "lapi (bad .api.server.listen_uri)" { + rune -0 config_set 'del(.api.server.listen_socket) | .api.server.listen_uri="127.0.0.1:-80"' + rune -1 "$CROWDSEC" -no-cs + assert_stderr --partial "local API server stopped with error: listening on 127.0.0.1:-80: listen tcp: address -80: invalid port" +} + +@test "lapi (listen on random port)" { + config_set '.common.log_media="stdout"' + rune -0 config_set 'del(.api.server.listen_socket) | .api.server.listen_uri="127.0.0.1:0"' + rune -0 wait-for --err "CrowdSec Local API listening on 127.0.0.1:" "$CROWDSEC" -no-cs +} diff --git a/test/bats/01_cscli.bats b/test/bats/01_cscli.bats index 0664c5691af..264870501a5 100644 --- a/test/bats/01_cscli.bats +++ b/test/bats/01_cscli.bats @@ -15,7 +15,7 @@ setup() { load "../lib/setup.sh" load "../lib/bats-file/load.bash" ./instance-data load - ./instance-crowdsec start + # don't run crowdsec here, not all tests require a running instance } teardown() { @@ -40,20 +40,20 @@ teardown() { @test "cscli version" { rune -0 cscli version - assert_stderr --partial "version:" - assert_stderr --partial "Codename:" - assert_stderr --partial "BuildDate:" - assert_stderr --partial "GoVersion:" - assert_stderr --partial "Platform:" - assert_stderr --partial "Constraint_parser:" - assert_stderr --partial "Constraint_scenario:" - assert_stderr --partial "Constraint_api:" - assert_stderr --partial "Constraint_acquis:" + assert_output --partial "version:" + assert_output --partial "Codename:" + assert_output --partial "BuildDate:" + assert_output --partial "GoVersion:" + assert_output --partial "Platform:" + assert_output --partial "Constraint_parser:" + assert_output --partial "Constraint_scenario:" + assert_output --partial "Constraint_api:" + assert_output --partial "Constraint_acquis:" # should work without configuration file - rm "${CONFIG_YAML}" + rm "$CONFIG_YAML" rune -0 cscli version - assert_stderr --partial "version:" + assert_output --partial "version:" } @test "cscli help" { @@ -62,7 +62,7 @@ teardown() { assert_line --regexp ".* help .* Help about any command" # should work without configuration file - rm "${CONFIG_YAML}" + rm "$CONFIG_YAML" rune -0 cscli help assert_line "Available Commands:" } @@ -100,14 +100,65 @@ teardown() { # check that LAPI configuration is loaded (human and json, not shown in raw) + sock=$(config_get '.api.server.listen_socket') + rune -0 cscli config show -o human assert_line --regexp ".*- URL +: http://127.0.0.1:8080/" - assert_line --regexp ".*- Login +: githubciXXXXXXXXXXXXXXXXXXXXXXXX" + assert_line --regexp ".*- Login +: githubciXXXXXXXXXXXXXXXXXXXXXXXX([a-zA-Z0-9]{16})?"
assert_line --regexp ".*- Credentials File +: .*/local_api_credentials.yaml" + assert_line --regexp ".*- Listen URL +: 127.0.0.1:8080" + assert_line --regexp ".*- Listen Socket +: $sock" rune -0 cscli config show -o json - rune -0 jq -c '.API.Client.Credentials | [.url,.login]' <(output) - assert_output '["http://127.0.0.1:8080/","githubciXXXXXXXXXXXXXXXXXXXXXXXX"]' + rune -0 jq -c '.API.Client.Credentials | [.url,.login[0:32]]' <(output) + assert_json '["http://127.0.0.1:8080/","githubciXXXXXXXXXXXXXXXXXXXXXXXX"]' + + # pointer to boolean + + rune -0 cscli config show --key Config.API.Client.InsecureSkipVerify + assert_output "&false" + + # complex type + rune -0 cscli config show --key Config.Prometheus + assert_output - <<-EOT + &csconfig.PrometheusCfg{ + Enabled: true, + Level: "full", + ListenAddr: "127.0.0.1", + ListenPort: 6060, + } + EOT +} + +@test "cscli - required configuration paths" { + config=$(cat "$CONFIG_YAML") + configdir=$(config_get '.config_paths.config_dir') + + # required configuration paths with no defaults + + config_set 'del(.config_paths)' + rune -1 cscli hub list + assert_stderr --partial 'no configuration paths provided' + echo "$config" > "$CONFIG_YAML" + + config_set 'del(.config_paths.data_dir)' + rune -1 cscli hub list + assert_stderr --partial "please provide a data directory with the 'data_dir' directive in the 'config_paths' section" + echo "$config" > "$CONFIG_YAML" + + # defaults + + config_set 'del(.config_paths.hub_dir)' + rune -0 cscli hub list + rune -0 cscli config show --key Config.ConfigPaths.HubDir + assert_output "$configdir/hub" + echo "$config" > "$CONFIG_YAML" + + config_set 'del(.config_paths.index_path)' + rune -0 cscli hub list + rune -0 cscli config show --key Config.ConfigPaths.HubIndexFile + assert_output "$configdir/hub/.index.json" + echo "$config" > "$CONFIG_YAML" } @test "cscli config show-yaml" { @@ -130,106 +181,34 @@ teardown() { assert_stderr --partial "failed to backup config: while creating /dev/null/blah: mkdir /dev/null/blah: not a directory" # pick a dirpath - backupdir=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) + backupdir=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) # succeed the first time - rune -0 cscli config backup "${backupdir}" + rune -0 cscli config backup "$backupdir" assert_stderr --partial "Starting configuration backup" # don't overwrite an existing backup - rune -1 cscli config backup "${backupdir}" + rune -1 cscli config backup "$backupdir" assert_stderr --partial "failed to backup config" assert_stderr --partial "file exists" SIMULATION_YAML="$(config_get '.config_paths.simulation_path')" # restore - rm "${SIMULATION_YAML}" - rune -0 cscli config restore "${backupdir}" - assert_file_exist "${SIMULATION_YAML}" + rm "$SIMULATION_YAML" + rune -0 cscli config restore "$backupdir" + assert_file_exists "$SIMULATION_YAML" # cleanup rm -rf -- "${backupdir:?}" # backup: detect missing files - rm "${SIMULATION_YAML}" - rune -1 cscli config backup "${backupdir}" + rm "$SIMULATION_YAML" + rune -1 cscli config backup "$backupdir" assert_stderr --regexp "failed to backup config: failed copy .* to .*: stat .*: no such file or directory" rm -rf -- "${backupdir:?}" } -@test "cscli lapi status" { - rune -0 cscli lapi status - - assert_stderr --partial "Loaded credentials from" - assert_stderr --partial "Trying to authenticate with username" - assert_stderr --partial " on http://127.0.0.1:8080/" - assert_stderr --partial "You can successfully interact with Local API (LAPI)" -} - -@test "cscli - missing LAPI credentials file" { - 
LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') - rm -f "${LOCAL_API_CREDENTIALS}" - rune -1 cscli lapi status - assert_stderr --partial "loading api client: while reading yaml file: open ${LOCAL_API_CREDENTIALS}: no such file or directory" - - rune -1 cscli alerts list - assert_stderr --partial "loading api client: while reading yaml file: open ${LOCAL_API_CREDENTIALS}: no such file or directory" - - rune -1 cscli decisions list - assert_stderr --partial "loading api client: while reading yaml file: open ${LOCAL_API_CREDENTIALS}: no such file or directory" -} - -@test "cscli - empty LAPI credentials file" { - LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') - : > "${LOCAL_API_CREDENTIALS}" - rune -1 cscli lapi status - assert_stderr --partial "no credentials or URL found in api client configuration '${LOCAL_API_CREDENTIALS}'" - - rune -1 cscli alerts list - assert_stderr --partial "no credentials or URL found in api client configuration '${LOCAL_API_CREDENTIALS}'" - - rune -1 cscli decisions list - assert_stderr --partial "no credentials or URL found in api client configuration '${LOCAL_API_CREDENTIALS}'" -} - -@test "cscli - missing LAPI client settings" { - config_set 'del(.api.client)' - rune -1 cscli lapi status - assert_stderr --partial "loading api client: no API client section in configuration" - - rune -1 cscli alerts list - assert_stderr --partial "loading api client: no API client section in configuration" - - rune -1 cscli decisions list - assert_stderr --partial "loading api client: no API client section in configuration" -} - -@test "cscli - malformed LAPI url" { - LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') - config_set "${LOCAL_API_CREDENTIALS}" '.url="https://127.0.0.1:-80"' - - rune -1 cscli lapi status - assert_stderr --partial 'parsing api url' - assert_stderr --partial 'invalid port \":-80\" after host' - - rune -1 cscli alerts list - assert_stderr --partial 'parsing api url' - assert_stderr --partial 'invalid port \":-80\" after host' - - rune -1 cscli decisions list - assert_stderr --partial 'parsing api url' - assert_stderr --partial 'invalid port \":-80\" after host' -} - -@test "cscli metrics" { - rune -0 cscli lapi status - rune -0 cscli metrics - assert_output --partial "Route" - assert_output --partial '/v1/watchers/login' - assert_output --partial "Local API Metrics:" -} - @test "'cscli completion' with or without configuration file" { rune -0 cscli completion bash assert_output --partial "# bash completion for cscli" @@ -240,55 +219,29 @@ teardown() { rune -0 cscli completion fish assert_output --partial "# fish completion for cscli" - rm "${CONFIG_YAML}" + rm "$CONFIG_YAML" rune -0 cscli completion bash assert_output --partial "# bash completion for cscli" } -@test "cscli hub list" { - # we check for the presence of some objects. There may be others when we - # use $PACKAGE_TESTING, so the order is not important. 
- - rune -0 cscli hub list -o human - assert_line --regexp '^ crowdsecurity/linux' - assert_line --regexp '^ crowdsecurity/sshd' - assert_line --regexp '^ crowdsecurity/dateparse-enrich' - assert_line --regexp '^ crowdsecurity/geoip-enrich' - assert_line --regexp '^ crowdsecurity/sshd-logs' - assert_line --regexp '^ crowdsecurity/syslog-logs' - assert_line --regexp '^ crowdsecurity/ssh-bf' - assert_line --regexp '^ crowdsecurity/ssh-slow-bf' - - rune -0 cscli hub list -o raw - assert_line --regexp '^crowdsecurity/linux,enabled,[0-9]+\.[0-9]+,core linux support : syslog\+geoip\+ssh,collections$' - assert_line --regexp '^crowdsecurity/sshd,enabled,[0-9]+\.[0-9]+,sshd support : parser and brute-force detection,collections$' - assert_line --regexp '^crowdsecurity/dateparse-enrich,enabled,[0-9]+\.[0-9]+,,parsers$' - assert_line --regexp '^crowdsecurity/geoip-enrich,enabled,[0-9]+\.[0-9]+,"Populate event with geoloc info : as, country, coords, source range.",parsers$' - assert_line --regexp '^crowdsecurity/sshd-logs,enabled,[0-9]+\.[0-9]+,Parse openSSH logs,parsers$' - assert_line --regexp '^crowdsecurity/syslog-logs,enabled,[0-9]+\.[0-9]+,,parsers$' - assert_line --regexp '^crowdsecurity/ssh-bf,enabled,[0-9]+\.[0-9]+,Detect ssh bruteforce,scenarios$' - assert_line --regexp '^crowdsecurity/ssh-slow-bf,enabled,[0-9]+\.[0-9]+,Detect slow ssh bruteforce,scenarios$' - - rune -0 cscli hub list -o json - rune -0 jq -r '.collections[].name, .parsers[].name, .scenarios[].name' <(output) - assert_line 'crowdsecurity/linux' - assert_line 'crowdsecurity/sshd' - assert_line 'crowdsecurity/dateparse-enrich' - assert_line 'crowdsecurity/geoip-enrich' - assert_line 'crowdsecurity/sshd-logs' - assert_line 'crowdsecurity/syslog-logs' - assert_line 'crowdsecurity/ssh-bf' - assert_line 'crowdsecurity/ssh-slow-bf' -} - @test "cscli support dump (smoke test)" { rune -0 cscli support dump -f "$BATS_TEST_TMPDIR"/dump.zip - assert_file_exist "$BATS_TEST_TMPDIR"/dump.zip + assert_file_exists "$BATS_TEST_TMPDIR"/dump.zip } @test "cscli explain" { - rune -0 cscli explain --log "Sep 19 18:33:22 scw-d95986 sshd[24347]: pam_unix(sshd:auth): authentication failure; logname= uid=0 euid=0 tty=ssh ruser= rhost=1.2.3.4" --type syslog --crowdsec "$CROWDSEC" + rune -0 ./instance-crowdsec start + line="Sep 19 18:33:22 scw-d95986 sshd[24347]: pam_unix(sshd:auth): authentication failure; logname= uid=0 euid=0 tty=ssh ruser= rhost=1.2.3.4" + + rune -0 cscli parsers install crowdsecurity/syslog-logs + rune -0 cscli collections install crowdsecurity/sshd + + rune -0 cscli explain --log "$line" --type syslog --only-successful-parsers --crowdsec "$CROWDSEC" assert_output - <"$BATS_TEST_DIRNAME"/testdata/explain/explain-log.txt + + rune -0 cscli parsers remove --all --purge + rune -1 cscli explain --log "$line" --type syslog --crowdsec "$CROWDSEC" + assert_stderr --partial "unable to load parser dump result: no parser found. 
Please install the appropriate parser and retry" } @test 'Allow variable expansion and literal $ characters in passwords' { @@ -310,25 +263,31 @@ teardown() { } @test "cscli doc" { - # generating documentation requires a directory named "doc" - cd "$BATS_TEST_TMPDIR" rune -1 cscli doc refute_output - assert_stderr --regexp 'Failed to generate cobra doc: open doc/.*: no such file or directory' + assert_stderr --regexp 'failed to generate cscli documentation: open doc/.*: no such file or directory' mkdir -p doc rune -0 cscli doc - refute_output + assert_output "Documentation generated in ./doc" refute_stderr - assert_file_exist "doc/cscli.md" + assert_file_exists "doc/cscli.md" assert_file_not_exist "doc/cscli_setup.md" # commands guarded by feature flags are not documented unless the feature flag is set export CROWDSEC_FEATURE_CSCLI_SETUP="true" rune -0 cscli doc - assert_file_exist "doc/cscli_setup.md" + assert_file_exists "doc/cscli_setup.md" + + # specify a target directory + mkdir -p "$BATS_TEST_TMPDIR/doc2" + rune -0 cscli doc --target "$BATS_TEST_TMPDIR/doc2" + assert_output "Documentation generated in $BATS_TEST_TMPDIR/doc2" + refute_stderr + assert_file_exists "$BATS_TEST_TMPDIR/doc2/cscli_setup.md" + } @test "feature.yaml for subcommands" { @@ -341,3 +300,24 @@ teardown() { rune -0 cscli setup assert_output --partial 'cscli setup [command]' } + +@test "cscli config feature-flags" { + # disabled + rune -0 cscli config feature-flags + assert_line '✗ cscli_setup: Enable cscli setup command (service detection)' + + # enabled in feature.yaml + CONFIG_DIR=$(dirname "$CONFIG_YAML") + echo ' - cscli_setup' >> "$CONFIG_DIR"/feature.yaml + rune -0 cscli config feature-flags + assert_line '✓ cscli_setup: Enable cscli setup command (service detection)' + + # enabled in environment + # shellcheck disable=SC2031 + export CROWDSEC_FEATURE_CSCLI_SETUP="true" + rune -0 cscli config feature-flags + assert_line '✓ cscli_setup: Enable cscli setup command (service detection)' + + # there are no retired features + rune -0 cscli config feature-flags --retired +} diff --git a/test/bats/01_cscli_lapi.bats b/test/bats/01_cscli_lapi.bats new file mode 100644 index 00000000000..6e876576a6e --- /dev/null +++ b/test/bats/01_cscli_lapi.bats @@ -0,0 +1,213 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + # don't run crowdsec here, not all tests require a running instance +} + +teardown() { + cd "$TEST_DIR" || exit 1 + ./instance-crowdsec stop +} + +#---------- + +@test "cscli lapi status" { + rune -0 ./instance-crowdsec start + rune -0 cscli lapi status + + assert_output --partial "Loaded credentials from" + assert_output --partial "Trying to authenticate with username" + assert_output --partial "You can successfully interact with Local API (LAPI)" +} + +@test "cscli - missing LAPI credentials file" { + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + rm -f "$LOCAL_API_CREDENTIALS" + rune -1 cscli lapi status + assert_stderr --partial "loading api client: while reading yaml file: open $LOCAL_API_CREDENTIALS: no such file or directory" + + rune -1 cscli alerts list + assert_stderr --partial "loading api client: while reading yaml file: open $LOCAL_API_CREDENTIALS: no such file or directory" + + rune -1 cscli decisions list + assert_stderr --partial "loading 
api client: while reading yaml file: open $LOCAL_API_CREDENTIALS: no such file or directory" +} + +@test "cscli - empty LAPI credentials file" { + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + : > "$LOCAL_API_CREDENTIALS" + rune -1 cscli lapi status + assert_stderr --partial "no credentials or URL found in api client configuration '$LOCAL_API_CREDENTIALS'" + + rune -1 cscli alerts list + assert_stderr --partial "no credentials or URL found in api client configuration '$LOCAL_API_CREDENTIALS'" + + rune -1 cscli decisions list + assert_stderr --partial "no credentials or URL found in api client configuration '$LOCAL_API_CREDENTIALS'" +} + +@test "cscli - LAPI credentials file can reference env variables" { + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + URL=$(config_get "$LOCAL_API_CREDENTIALS" '.url') + export URL + LOGIN=$(config_get "$LOCAL_API_CREDENTIALS" '.login') + export LOGIN + PASSWORD=$(config_get "$LOCAL_API_CREDENTIALS" '.password') + export PASSWORD + + # shellcheck disable=SC2016 + echo '{"url":"$URL","login":"$LOGIN","password":"$PASSWORD"}' > "$LOCAL_API_CREDENTIALS".local + + config_set '.crowdsec_service.enable=false' + rune -0 ./instance-crowdsec start + + rune -0 cscli lapi status + assert_output --partial "You can successfully interact with Local API (LAPI)" + + rm "$LOCAL_API_CREDENTIALS".local + + # shellcheck disable=SC2016 + config_set "$LOCAL_API_CREDENTIALS" '.url="$URL"' + # shellcheck disable=SC2016 + config_set "$LOCAL_API_CREDENTIALS" '.login="$LOGIN"' + # shellcheck disable=SC2016 + config_set "$LOCAL_API_CREDENTIALS" '.password="$PASSWORD"' + + rune -0 cscli lapi status + assert_output --partial "You can successfully interact with Local API (LAPI)" + + # but if a variable is not defined, there is no specific error message + unset URL + rune -1 cscli lapi status + # shellcheck disable=SC2016 + assert_stderr --partial 'BaseURL must have a trailing slash' +} + +@test "cscli - missing LAPI client settings" { + config_set 'del(.api.client)' + rune -1 cscli lapi status + assert_stderr --partial "loading api client: no API client section in configuration" + + rune -1 cscli alerts list + assert_stderr --partial "loading api client: no API client section in configuration" + + rune -1 cscli decisions list + assert_stderr --partial "loading api client: no API client section in configuration" +} + +@test "cscli - malformed LAPI url" { + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + config_set "$LOCAL_API_CREDENTIALS" '.url="http://127.0.0.1:-80"' + + rune -1 cscli lapi status -o json + rune -0 jq -r '.msg' <(stderr) + assert_output 'failed to authenticate to Local API (LAPI): parse "http://127.0.0.1:-80/": invalid port ":-80" after host' +} + +@test "cscli - bad LAPI password" { + rune -0 ./instance-crowdsec start + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + config_set "$LOCAL_API_CREDENTIALS" '.password="meh"' + + rune -1 cscli lapi status -o json + rune -0 jq -r '.msg' <(stderr) + assert_output 'failed to authenticate to Local API (LAPI): API error: incorrect Username or Password' +} + +@test "cscli lapi register / machines validate" { + rune -1 cscli lapi register + assert_stderr --partial "connection refused" + + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + + rune -0 ./instance-crowdsec start + rune -0 cscli lapi register + assert_stderr --partial "Successfully registered to Local API" + assert_stderr --partial "Local API credentials written to 
'$LOCAL_API_CREDENTIALS'" + assert_stderr --partial "Run 'sudo systemctl reload crowdsec' for the new configuration to be effective." + + LOGIN=$(config_get "$LOCAL_API_CREDENTIALS" '.login') + + rune -0 cscli machines inspect "$LOGIN" -o json + rune -0 jq -r '.isValidated' <(output) + assert_output "null" + + rune -0 cscli machines validate "$LOGIN" + + rune -0 cscli machines inspect "$LOGIN" -o json + rune -0 jq -r '.isValidated' <(output) + assert_output "true" +} + +@test "cscli lapi register --machine" { + rune -0 ./instance-crowdsec start + rune -0 cscli lapi register --machine newmachine + rune -0 cscli machines validate newmachine + rune -0 cscli machines inspect newmachine -o json + rune -0 jq -r '.isValidated' <(output) + assert_output "true" +} + +@test "cscli lapi register --token (ignored)" { + # A token is ignored if the server is not configured with it + rune -1 cscli lapi register --machine newmachine --token meh + assert_stderr --partial "connection refused" + + rune -0 ./instance-crowdsec start + rune -1 cscli lapi register --machine newmachine --token meh + assert_stderr --partial '422 Unprocessable Entity: API error: http code 422, invalid request:' + assert_stderr --partial 'registration_token in body should be at least 32 chars long' + + rune -0 cscli lapi register --machine newmachine --token 12345678901234567890123456789012 + assert_stderr --partial "Successfully registered to Local API" + + rune -0 cscli machines inspect newmachine -o json + rune -0 jq -r '.isValidated' <(output) + assert_output "null" +} + +@test "cscli lapi register --token" { + config_set '.api.server.auto_registration.enabled=true' + config_set '.api.server.auto_registration.token="12345678901234567890123456789012"' + config_set '.api.server.auto_registration.allowed_ranges=["127.0.0.1/32"]' + + rune -0 ./instance-crowdsec start + + rune -1 cscli lapi register --machine malicious --token 123456789012345678901234badtoken + assert_stderr --partial "401 Unauthorized: API error: invalid token for auto registration" + rune -1 cscli machines inspect malicious -o json + assert_stderr --partial "unable to read machine data 'malicious': user 'malicious': user doesn't exist" + + rune -0 cscli lapi register --machine newmachine --token 12345678901234567890123456789012 + assert_stderr --partial "Successfully registered to Local API" + rune -0 cscli machines inspect newmachine -o json + rune -0 jq -r '.isValidated' <(output) + assert_output "true" +} + +@test "cscli lapi register --token (bad source ip)" { + config_set '.api.server.auto_registration.enabled=true' + config_set '.api.server.auto_registration.token="12345678901234567890123456789012"' + config_set '.api.server.auto_registration.allowed_ranges=["127.0.0.2/32"]' + + rune -0 ./instance-crowdsec start + + rune -1 cscli lapi register --machine outofrange --token 12345678901234567890123456789012 + assert_stderr --partial "401 Unauthorized: API error: IP not in allowed range for auto registration" + rune -1 cscli machines inspect outofrange -o json + assert_stderr --partial "unable to read machine data 'outofrange': user 'outofrange': user doesn't exist" +} diff --git a/test/bats/02_nolapi.bats b/test/bats/02_nolapi.bats index c457900eeb2..cefa6d798b4 100644 --- a/test/bats/02_nolapi.bats +++ b/test/bats/02_nolapi.bats @@ -24,45 +24,54 @@ teardown() { #---------- @test "test without -no-api flag" { - rune -124 timeout 2s "${CROWDSEC}" - # from `man timeout`: If the command times out, and --preserve-status is not set, then exit with status 124. 
+ config_set '.common.log_media="stdout"' + rune -0 wait-for \ + --err "CrowdSec Local API listening" \ + "$CROWDSEC" } @test "crowdsec should not run without LAPI (-no-api flag)" { - # really needs 4 secs on slow boxes - rune -1 timeout 4s "${CROWDSEC}" -no-api + config_set '.common.log_media="stdout"' + rune -1 wait-for "$CROWDSEC" -no-api } @test "crowdsec should not run without LAPI (no api.server in configuration file)" { config_disable_lapi config_log_stderr - # really needs 4 secs on slow boxes - rune -1 timeout 4s "${CROWDSEC}" - assert_stderr --partial "crowdsec local API is disabled" + rune -0 wait-for \ + --err "crowdsec local API is disabled" \ + "$CROWDSEC" } @test "capi status shouldn't be ok without api.server" { config_disable_lapi rune -1 cscli capi status assert_stderr --partial "crowdsec local API is disabled" - assert_stderr --partial "There is no configuration on 'api.server:'" + assert_stderr --partial "local API is disabled -- this command must be run on the local API machine" } -@test "cscli config show -o human" { - config_disable_lapi +@test "no lapi: cscli config show -o human" { + config_set '.api.server.enable=false' + rune -0 cscli config show -o human + assert_output --partial "Global:" + assert_output --partial "Crowdsec:" + assert_output --partial "cscli:" + assert_output --partial "Local API Server (disabled):" + + config_set 'del(.api.server)' rune -0 cscli config show -o human assert_output --partial "Global:" assert_output --partial "Crowdsec:" assert_output --partial "cscli:" - refute_output --partial "Local API Server:" + refute_output --partial "Local API Server" } @test "cscli config backup" { config_disable_lapi - backupdir=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) - rune -0 cscli config backup "${backupdir}" + backupdir=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) + rune -0 cscli config backup "$backupdir" assert_stderr --partial "Starting configuration backup" - rune -1 cscli config backup "${backupdir}" + rune -1 cscli config backup "$backupdir" rm -rf -- "${backupdir:?}" assert_stderr --partial "failed to backup config" @@ -71,9 +80,8 @@ teardown() { @test "lapi status shouldn't be ok without api.server" { config_disable_lapi - ./instance-crowdsec start || true rune -1 cscli machines list - assert_stderr --partial "local API is disabled, please run this command on the local API machine" + assert_stderr --partial "local API is disabled -- this command must be run on the local API machine" } @test "cscli metrics" { @@ -85,5 +93,5 @@ teardown() { assert_output --partial "/v1/watchers/login" assert_stderr --partial "crowdsec local API is disabled" - assert_stderr --partial "local API is disabled, please run this command on the local API machine" + assert_stderr --partial "local API is disabled -- this command must be run on the local API machine" } diff --git a/test/bats/03_noagent.bats b/test/bats/03_noagent.bats index 12c66d2c09d..6be5101cee2 100644 --- a/test/bats/03_noagent.bats +++ b/test/bats/03_noagent.bats @@ -23,38 +23,49 @@ teardown() { #---------- @test "with agent: test without -no-cs flag" { - rune -124 timeout 2s "${CROWDSEC}" - # from `man timeout`: If the command times out, and --preserve-status is not set, then exit with status 124. 
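+ # send the logs to the console so wait-for can match the startup message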
+ config_set '.common.log_media="stdout"' + rune -0 wait-for \ + --err "Starting processing data" \ + "$CROWDSEC" } @test "no agent: crowdsec LAPI should run (-no-cs flag)" { - rune -124 timeout 2s "${CROWDSEC}" -no-cs + config_set '.common.log_media="stdout"' + rune -0 wait-for \ + --err "CrowdSec Local API listening" \ + "$CROWDSEC" -no-cs } @test "no agent: crowdsec LAPI should run (no crowdsec_service in configuration file)" { config_disable_agent config_log_stderr - rune -124 timeout 2s "${CROWDSEC}" - - assert_stderr --partial "crowdsec agent is disabled" + rune -0 wait-for \ + --err "crowdsec agent is disabled" \ + "$CROWDSEC" } @test "no agent: cscli config show" { - config_disable_agent + config_set '.crowdsec_service.enable=false' rune -0 cscli config show -o human assert_output --partial "Global:" assert_output --partial "cscli:" assert_output --partial "Local API Server:" + assert_output --partial "Crowdsec (disabled):" - refute_output --partial "Crowdsec:" + config_set 'del(.crowdsec_service)' + rune -0 cscli config show -o human + assert_output --partial "Global:" + assert_output --partial "cscli:" + assert_output --partial "Local API Server:" + refute_output --partial "Crowdsec" } @test "no agent: cscli config backup" { config_disable_agent - backupdir=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) - rune -0 cscli config backup "${backupdir}" + backupdir=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) + rune -0 cscli config backup "$backupdir" assert_stderr --partial "Starting configuration backup" - rune -1 cscli config backup "${backupdir}" + rune -1 cscli config backup "$backupdir" assert_stderr --partial "failed to backup config" assert_stderr --partial "file exists" @@ -65,7 +76,7 @@ teardown() { config_disable_agent ./instance-crowdsec start rune -0 cscli lapi status - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --partial "You can successfully interact with Local API (LAPI)" } @test "cscli metrics" { diff --git a/test/bats/04_capi.bats b/test/bats/04_capi.bats index f4c9f49e0f7..7ba6bfa4428 100644 --- a/test/bats/04_capi.bats +++ b/test/bats/04_capi.bats @@ -19,14 +19,59 @@ setup() { #---------- -@test "cscli capi status" { +@test "cscli capi status: fails without credentials" { config_enable_capi + ONLINE_API_CREDENTIALS_YAML="$(config_get '.api.server.online_client.credentials_path')" + # bogus values, won't be used + echo '{"login":"login","password":"password","url":"url"}' > "${ONLINE_API_CREDENTIALS_YAML}" + + config_set "$ONLINE_API_CREDENTIALS_YAML" 'del(.url)' + rune -1 cscli capi status + assert_stderr --partial "can't load CAPI credentials from '$ONLINE_API_CREDENTIALS_YAML' (missing url field)" + + config_set "$ONLINE_API_CREDENTIALS_YAML" 'del(.password)' + rune -1 cscli capi status + assert_stderr --partial "can't load CAPI credentials from '$ONLINE_API_CREDENTIALS_YAML' (missing password field)" + + config_set "$ONLINE_API_CREDENTIALS_YAML" 'del(.login)' + rune -1 cscli capi status + assert_stderr --partial "can't load CAPI credentials from '$ONLINE_API_CREDENTIALS_YAML' (missing login field)" + + rm "${ONLINE_API_CREDENTIALS_YAML}" + rune -1 cscli capi status + assert_stderr --partial "failed to load Local API: loading online client credentials: open ${ONLINE_API_CREDENTIALS_YAML}: no such file or directory" + + config_set 'del(.api.server.online_client)' + rune -1 cscli capi status + assert_stderr --regexp "no configuration for Central API \(CAPI\) in '$(echo $CONFIG_YAML|sed s#//#/#g)'" +} + +@test "cscli 
{capi,papi} status" { + ./instance-data load + config_enable_capi + + # should not panic with no credentials, but return an error + rune -1 cscli papi status + assert_stderr --partial "the Central API (CAPI) must be configured with 'cscli capi register'" + rune -0 cscli capi register --schmilblick githubciXXXXXXXXXXXXXXXXXXXXXXXX + rune -1 cscli capi status + assert_stderr --partial "no scenarios or appsec-rules installed, abort" + + rune -1 cscli papi status + assert_stderr --partial "no PAPI URL in configuration" + + rune -0 cscli console enable console_management + rune -1 cscli papi status + assert_stderr --partial "unable to get PAPI permissions" + assert_stderr --partial "Forbidden for plan" + + rune -0 cscli scenarios install crowdsecurity/ssh-bf rune -0 cscli capi status - assert_stderr --partial "Loaded credentials from" - assert_stderr --partial "Trying to authenticate with username" - assert_stderr --partial " on https://api.crowdsec.net/" - assert_stderr --partial "You can successfully interact with Central API (CAPI)" + assert_output --partial "Loaded credentials from" + assert_output --partial "Trying to authenticate with username" + assert_output --partial " on https://api.crowdsec.net/" + assert_output --partial "You can successfully interact with Central API (CAPI)" } @test "cscli alerts list: receive a community pull when capi is enabled" { @@ -34,7 +79,7 @@ setup() { ./instance-crowdsec start for ((i=0; i<15; i++)); do sleep 2 - [[ $(cscli alerts list -a -o json 2>/dev/null || cscli alerts list -o json) != "null" ]] && break + [[ $(cscli alerts list -a -o json) != "[]" ]] && break done rune -0 cscli alerts list -a -o json @@ -45,7 +90,7 @@ setup() { @test "we have exactly one machine, localhost" { rune -0 cscli machines list -o json rune -0 jq -c '[. 
| length, .[0].machineId[0:32], .[0].isValidated, .[0].ipAddress]' <(output) - assert_output '[1,"githubciXXXXXXXXXXXXXXXXXXXXXXXX",true,"127.0.0.1"]' + assert_json '[1,"githubciXXXXXXXXXXXXXXXXXXXXXXXX",true,"127.0.0.1"]' } @test "no agent: capi status should be ok" { @@ -53,12 +98,11 @@ setup() { config_disable_agent ./instance-crowdsec start rune -0 cscli capi status - assert_stderr --partial "You can successfully interact with Central API (CAPI)" + assert_output --partial "You can successfully interact with Central API (CAPI)" } -@test "cscli capi status: fails without credentials" { - ONLINE_API_CREDENTIALS_YAML="$(config_get '.api.server.online_client.credentials_path')" - rm "${ONLINE_API_CREDENTIALS_YAML}" - rune -1 cscli capi status - assert_stderr --partial "local API is disabled, please run this command on the local API machine: loading online client credentials: failed to read api server credentials configuration file '${ONLINE_API_CREDENTIALS_YAML}': open ${ONLINE_API_CREDENTIALS_YAML}: no such file or directory" +@test "capi register must be run from lapi" { + config_disable_lapi + rune -1 cscli capi register --schmilblick githubciXXXXXXXXXXXXXXXXXXXXXXXX + assert_stderr --partial "local API is disabled -- this command must be run on the local API machine" } diff --git a/test/bats/04_nocapi.bats b/test/bats/04_nocapi.bats index 388277cca3e..d22a6f0a953 100644 --- a/test/bats/04_nocapi.bats +++ b/test/bats/04_nocapi.bats @@ -25,40 +25,38 @@ teardown() { @test "without capi: crowdsec LAPI should run without capi (-no-capi flag)" { config_set '.common.log_media="stdout"' - rune -124 timeout 1s "${CROWDSEC}" -no-capi - assert_stderr --partial "Communication with CrowdSec Central API disabled from args" + rune -0 wait-for \ + --err "Communication with CrowdSec Central API disabled from args" \ + "$CROWDSEC" -no-capi } @test "without capi: crowdsec LAPI should still work" { config_disable_capi config_set '.common.log_media="stdout"' - rune -124 timeout 1s "${CROWDSEC}" - # from `man timeout`: If the command times out, and --preserve-status is not set, then exit with status 124. 
- assert_stderr --partial "push and pull to Central API disabled" + rune -0 wait-for \ + --err "push and pull to Central API disabled" \ + "$CROWDSEC" } @test "without capi: cscli capi status -> fail" { config_disable_capi ./instance-crowdsec start rune -1 cscli capi status - assert_stderr --partial "no configuration for Central API in " + assert_stderr --partial "no configuration for Central API (CAPI) in " } @test "no capi: cscli config show" { config_disable_capi rune -0 cscli config show -o human - assert_output --partial "Global:" - assert_output --partial "cscli:" - assert_output --partial "Crowdsec:" - assert_output --partial "Local API Server:" + assert_output --regexp "Global:.*Crowdsec.*cscli:.*Local API Server:" } @test "no agent: cscli config backup" { config_disable_capi - backupdir=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) - rune -0 cscli config backup "${backupdir}" + backupdir=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) + rune -0 cscli config backup "$backupdir" assert_stderr --partial "Starting configuration backup" - rune -1 cscli config backup "${backupdir}" + rune -1 cscli config backup "$backupdir" assert_stderr --partial "failed to backup config" assert_stderr --partial "file exists" rm -rf -- "${backupdir:?}" @@ -68,7 +66,7 @@ teardown() { config_disable_capi ./instance-crowdsec start rune -0 cscli lapi status - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --partial "You can successfully interact with Local API (LAPI)" } @test "cscli metrics" { diff --git a/test/bats/05_config_yaml_local.bats b/test/bats/05_config_yaml_local.bats index 3cc20819bfb..ec7a4201964 100644 --- a/test/bats/05_config_yaml_local.bats +++ b/test/bats/05_config_yaml_local.bats @@ -21,7 +21,7 @@ setup() { load "../lib/setup.sh" ./instance-data load rune -0 config_get '.api.client.credentials_path' - LOCAL_API_CREDENTIALS="${output}" + LOCAL_API_CREDENTIALS="$output" export LOCAL_API_CREDENTIALS } @@ -34,50 +34,50 @@ teardown() { @test "config.yaml.local - cscli (log_level)" { config_set '.common.log_level="warning"' rune -0 cscli config show --key Config.Common.LogLevel - assert_output "warning" + assert_output "&3" echo "{'common':{'log_level':'debug'}}" >"${CONFIG_YAML}.local" rune -0 cscli config show --key Config.Common.LogLevel - assert_output "debug" + assert_output "&5" } @test "config.yaml.local - cscli (log_level - with envvar)" { config_set '.common.log_level="warning"' rune -0 cscli config show --key Config.Common.LogLevel - assert_output "warning" + assert_output "&3" export CROWDSEC_LOG_LEVEL=debug echo "{'common':{'log_level':'${CROWDSEC_LOG_LEVEL}'}}" >"${CONFIG_YAML}.local" rune -0 cscli config show --key Config.Common.LogLevel - assert_output "debug" + assert_output "&5" } @test "config.yaml.local - crowdsec (listen_url)" { # disable the agent or we'll need to patch api client credentials too rune -0 config_disable_agent ./instance-crowdsec start - rune -0 ./bin/wait-for-port -q 8080 + rune -0 wait-for-port -q 8080 ./instance-crowdsec stop - rune -1 ./bin/wait-for-port -q 8080 + rune -1 wait-for-port -q 8080 echo "{'api':{'server':{'listen_uri':127.0.0.1:8083}}}" >"${CONFIG_YAML}.local" ./instance-crowdsec start - rune -0 ./bin/wait-for-port -q 8083 - rune -1 ./bin/wait-for-port -q 8080 + rune -0 wait-for-port -q 8083 + rune -1 wait-for-port -q 8080 ./instance-crowdsec stop rm -f "${CONFIG_YAML}.local" ./instance-crowdsec start - rune -1 ./bin/wait-for-port -q 8083 - rune -0 ./bin/wait-for-port -q 8080 + rune -1 wait-for-port -q 8083 
+ rune -0 wait-for-port -q 8080 } @test "local_api_credentials.yaml.local" { rune -0 config_disable_agent echo "{'api':{'server':{'listen_uri':127.0.0.1:8083}}}" >"${CONFIG_YAML}.local" ./instance-crowdsec start - rune -0 ./bin/wait-for-port -q 8083 + rune -0 wait-for-port -q 8083 rune -1 cscli decisions list echo "{'url':'http://127.0.0.1:8083'}" >"${LOCAL_API_CREDENTIALS}.local" @@ -88,13 +88,13 @@ teardown() { @test "simulation.yaml.local" { rune -0 config_get '.config_paths.simulation_path' refute_output null - SIMULATION="${output}" + SIMULATION="$output" - echo "simulation: off" >"${SIMULATION}" + echo "simulation: off" >"$SIMULATION" rune -0 cscli simulation status -o human assert_stderr --partial "global simulation: disabled" - echo "simulation: on" >"${SIMULATION}" + echo "simulation: on" >"$SIMULATION" rune -0 cscli simulation status -o human assert_stderr --partial "global simulation: enabled" @@ -110,7 +110,7 @@ teardown() { @test "profiles.yaml.local" { rune -0 config_get '.api.server.profiles_path' refute_output null - PROFILES="${output}" + PROFILES="$output" cat <<-EOT >"${PROFILES}.local" name: default_ip_remediation @@ -122,14 +122,17 @@ teardown() { on_success: break EOT - tmpfile=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp) - touch "${tmpfile}" + tmpfile=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp) + touch "$tmpfile" ACQUIS_YAML=$(config_get '.crowdsec_service.acquisition_path') - echo -e "---\nfilename: ${tmpfile}\nlabels:\n type: syslog\n" >>"${ACQUIS_YAML}" + echo -e "---\nfilename: ${tmpfile}\nlabels:\n type: syslog\n" >>"$ACQUIS_YAML" + + rune -0 cscli collections install crowdsecurity/sshd + rune -0 cscli parsers install crowdsecurity/syslog-logs ./instance-crowdsec start sleep .5 - fake_log >>"${tmpfile}" + fake_log >>"$tmpfile" # this could be simplified, but some systems are slow and we don't want to # wait more than required @@ -138,6 +141,6 @@ teardown() { rune -0 cscli decisions list -o json rune -0 jq --exit-status '.[].decisions[0] | [.value,.type] == ["1.1.1.172","captcha"]' <(output) && break done - rm -f -- "${tmpfile}" - [[ "${status}" -eq 0 ]] || fail "captcha not triggered" + rm -f -- "$tmpfile" + [[ "$status" -eq 0 ]] || fail "captcha not triggered" } diff --git a/test/bats/07_setup.bats b/test/bats/07_setup.bats index c63f0702421..f832ac572d2 100644 --- a/test/bats/07_setup.bats +++ b/test/bats/07_setup.bats @@ -7,6 +7,8 @@ setup_file() { load "../lib/setup_file.sh" ./instance-data load HUB_DIR=$(config_get '.config_paths.hub_dir') + # remove trailing slash if any (like in default config.yaml from package) + HUB_DIR=${HUB_DIR%/} export HUB_DIR DETECT_YAML="${HUB_DIR}/detect.yaml" export DETECT_YAML @@ -68,7 +70,11 @@ teardown() { assert_line --partial "--skip-service strings ignore a service, don't recommend hub/datasources (can be repeated)" rune -1 cscli setup detect --detect-config /path/does/not/exist - assert_stderr --partial "detecting services: while reading file: open /path/does/not/exist: no such file or directory" + assert_stderr --partial "open /path/does/not/exist: no such file or directory" + + # - is stdin + rune -1 cscli setup detect --detect-config - <<< "{}" + assert_stderr --partial "detecting services: missing version tag (must be 1.0)" # rm -f "${HUB_DIR}/detect.yaml" } @@ -142,7 +148,7 @@ teardown() { EOT rune -1 cscli setup detect --list-supported-services --detect-config "$tempfile" - assert_stderr --partial "while parsing ${tempfile}: yaml: unmarshal errors:" + assert_stderr --partial "yaml: unmarshal errors:" rm -f "$tempfile" } @@ 
-311,7 +317,7 @@ update-notifier-motd.timer enabled enabled @test "cscli setup detect (process)" { # This is harder to mock, because gopsutil requires proc/ to be a mount # point. So we pick a process that exists for sure. - expected_process=$(basename "$SHELL") + expected_process=cscli cat <<-EOT >"${DETECT_YAML}" version: 1.0 @@ -501,46 +507,49 @@ update-notifier-motd.timer enabled enabled @test "cscli setup install-hub (dry run)" { # it's not installed - rune -0 cscli collections list -o json - rune -0 jq -r '.collections[].name' <(output) - refute_line "crowdsecurity/apache2" + rune -0 cscli collections inspect crowdsecurity/apache2 -o json + rune -0 jq -e '.installed == false' <(output) # we install it rune -0 cscli setup install-hub /dev/stdin --dry-run <<< '{"setup":[{"install":{"collections":["crowdsecurity/apache2"]}}]}' assert_output 'dry-run: would install collection crowdsecurity/apache2' # still not installed - rune -0 cscli collections list -o json - rune -0 jq -r '.collections[].name' <(output) - refute_line "crowdsecurity/apache2" + rune -0 cscli collections inspect crowdsecurity/apache2 -o json + rune -0 jq -e '.installed == false' <(output) + + # same with dependencies + rune -0 cscli collections remove --all + rune -0 cscli setup install-hub /dev/stdin --dry-run <<< '{"setup":[{"install":{"collections":["crowdsecurity/linux"]}}]}' + assert_output 'dry-run: would install collection crowdsecurity/linux' } @test "cscli setup install-hub (dry run: install multiple collections)" { # it's not installed - rune -0 cscli collections list -o json - rune -0 jq -r '.collections[].name' <(output) - refute_line "crowdsecurity/apache2" + rune -0 cscli collections inspect crowdsecurity/apache2 -o json + rune -0 jq -e '.installed == false' <(output) # we install it rune -0 cscli setup install-hub /dev/stdin --dry-run <<< '{"setup":[{"install":{"collections":["crowdsecurity/apache2"]}}]}' assert_output 'dry-run: would install collection crowdsecurity/apache2' # still not installed - rune -0 cscli collections list -o json - rune -0 jq -r '.collections[].name' <(output) - refute_line "crowdsecurity/apache2" + rune -0 cscli collections inspect crowdsecurity/apache2 -o json + rune -0 jq -e '.installed == false' <(output) } @test "cscli setup install-hub (dry run: install multiple collections, parsers, scenarios, postoverflows)" { - rune -0 cscli setup install-hub /dev/stdin --dry-run <<< '{"setup":[{"install":{"collections":["crowdsecurity/foo","johndoe/bar"],"parsers":["crowdsecurity/fooparser","johndoe/barparser"],"scenarios":["crowdsecurity/fooscenario","johndoe/barscenario"],"postoverflows":["crowdsecurity/foopo","johndoe/barpo"]}}]}' - assert_line 'dry-run: would install collection crowdsecurity/foo' - assert_line 'dry-run: would install collection johndoe/bar' - assert_line 'dry-run: would install parser crowdsecurity/fooparser' - assert_line 'dry-run: would install parser johndoe/barparser' - assert_line 'dry-run: would install scenario crowdsecurity/fooscenario' - assert_line 'dry-run: would install scenario johndoe/barscenario' - assert_line 'dry-run: would install postoverflow crowdsecurity/foopo' - assert_line 'dry-run: would install postoverflow johndoe/barpo' + rune -0 cscli setup install-hub /dev/stdin --dry-run <<< '{"setup":[{"install":{"collections":["crowdsecurity/aws-console","crowdsecurity/caddy"],"parsers":["crowdsecurity/asterisk-logs"],"scenarios":["crowdsecurity/smb-fs"],"postoverflows":["crowdsecurity/cdn-whitelist","crowdsecurity/rdns"]}}]}' + assert_line 'dry-run: 
would install collection crowdsecurity/aws-console' + assert_line 'dry-run: would install collection crowdsecurity/caddy' + assert_line 'dry-run: would install parser crowdsecurity/asterisk-logs' + assert_line 'dry-run: would install scenario crowdsecurity/smb-fs' + assert_line 'dry-run: would install postoverflow crowdsecurity/cdn-whitelist' + assert_line 'dry-run: would install postoverflow crowdsecurity/rdns' + + rune -1 cscli setup install-hub /dev/stdin --dry-run <<< '{"setup":[{"install":{"collections":["crowdsecurity/foo"]}}]}' + assert_stderr --partial 'collection crowdsecurity/foo not found' + } @test "cscli setup datasources" { @@ -810,7 +819,6 @@ update-notifier-motd.timer enabled enabled setup: alsdk al; sdf EOT - assert_output "while unmarshaling setup file: yaml: line 2: could not find expected ':'" + assert_output "while parsing setup file: yaml: line 2: could not find expected ':'" assert_stderr --partial "invalid setup file" } - diff --git a/test/bats/08_metrics.bats b/test/bats/08_metrics.bats index 836e220484a..e260e667524 100644 --- a/test/bats/08_metrics.bats +++ b/test/bats/08_metrics.bats @@ -23,10 +23,9 @@ teardown() { #---------- @test "cscli metrics (crowdsec not running)" { - rune -1 cscli metrics - # crowdsec is down - assert_stderr --partial "failed to fetch prometheus metrics" - assert_stderr --partial "connect: connection refused" + rune -0 cscli metrics + # crowdsec is down, we won't get an error because some metrics come from the db instead + assert_stderr --partial 'while fetching metrics: executing GET request for URL \"http://127.0.0.1:6060/metrics\" failed: Get \"http://127.0.0.1:6060/metrics\": dial tcp 127.0.0.1:6060: connect: connection refused' } @test "cscli metrics (bad configuration)" { @@ -43,18 +42,62 @@ teardown() { @test "cscli metrics (missing listen_addr)" { config_set 'del(.prometheus.listen_addr)' - rune -1 cscli metrics - assert_stderr --partial "no prometheus url, please specify" + rune -0 ./instance-crowdsec start + rune -0 cscli metrics --debug + assert_stderr --partial "prometheus.listen_addr is empty, defaulting to 127.0.0.1" } @test "cscli metrics (missing listen_port)" { - config_set 'del(.prometheus.listen_addr)' - rune -1 cscli metrics - assert_stderr --partial "no prometheus url, please specify" + config_set 'del(.prometheus.listen_port)' + rune -0 ./instance-crowdsec start + rune -0 cscli metrics --debug + assert_stderr --partial "prometheus.listen_port is empty or zero, defaulting to 6060" } @test "cscli metrics (missing prometheus section)" { config_set 'del(.prometheus)' rune -1 cscli metrics - assert_stderr --partial "prometheus section missing, can't show metrics" + assert_stderr --partial "prometheus is not enabled, can't show metrics" +} + +@test "cscli metrics" { + rune -0 ./instance-crowdsec start + rune -0 cscli lapi status + rune -0 cscli metrics + assert_output --partial "Route" + assert_output --partial '/v1/watchers/login' + assert_output --partial "Local API Metrics:" + + rune -0 cscli metrics -o json + rune -0 jq 'keys' <(output) + assert_output --partial '"alerts",' + assert_output --partial '"parsers",' +} + +@test "cscli metrics list" { + rune -0 cscli metrics list + assert_output --regexp "Type.*Title.*Description" + + rune -0 cscli metrics list -o json + rune -0 jq -c '.[] | [.type,.title]' <(output) + assert_line '["acquisition","Acquisition Metrics"]' +} + +@test "cscli metrics show" { + rune -0 ./instance-crowdsec start + rune -0 cscli lapi status + + assert_equal "$(cscli metrics)" "$(cscli metrics 
show)" + + rune -1 cscli metrics show foobar + assert_stderr --partial "unknown metrics type: foobar" + + rune -0 cscli metrics show lapi + assert_output --partial "Local API Metrics:" + assert_output --regexp "Route.*Method.*Hits" + assert_output --regexp "/v1/watchers/login.*POST" + + rune -0 cscli metrics show lapi -o json + rune -0 jq -c '.lapi."/v1/watchers/login" | keys' <(output) + assert_json '["POST"]' } diff --git a/test/bats/08_metrics_bouncer.bats b/test/bats/08_metrics_bouncer.bats new file mode 100644 index 00000000000..c4dfebbab1d --- /dev/null +++ b/test/bats/08_metrics_bouncer.bats @@ -0,0 +1,527 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + ./instance-data load + ./instance-crowdsec start +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "cscli metrics show bouncers (empty)" { + # this message is given only if we ask explicitly for bouncers + notfound="No bouncer metrics found." + + rune -0 cscli metrics show bouncers + assert_output "$notfound" + + rune -0 cscli metrics list + refute_output "$notfound" +} + +@test "rc usage metrics (empty payload)" { + # a registered bouncer can send metrics for the lapi and console + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: [] + log_processors: [] + EOT + ) + + rune -22 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial 'error: 400' + assert_json '{message: "Missing remediation component data"}' +} + +@test "rc usage metrics (bad payload)" { + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + log_processors: [] + EOT + ) + + rune -22 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial "error: 422" + rune -0 jq -r '.message' <(output) + assert_output - <<-EOT + validation failure list: + remediation_components.0.utc_startup_timestamp in body is required + EOT + + # validation, like timestamp format + + payload=$(yq -o j '.remediation_components[0].utc_startup_timestamp = "2021-09-01T00:00:00Z"' <<<"$payload") + rune -22 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial "error: 400" + assert_json '{message: "json: cannot unmarshal string into Go struct field AllMetrics.remediation_components of type int64"}' + + payload=$(yq -o j '.remediation_components[0].utc_startup_timestamp = 1707399316' <<<"$payload") + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + refute_output + + payload=$(yq -o j '.remediation_components[0].metrics = [{"meta": {}}]' <<<"$payload") + rune -22 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial "error: 422" + rune -0 jq -r '.message' <(output) + assert_output - <<-EOT + validation failure list: + remediation_components.0.metrics.0.items in body is required + validation failure list: + remediation_components.0.metrics.0.meta.utc_now_timestamp in body is required + remediation_components.0.metrics.0.meta.window_size_seconds in body is required + EOT +} + +@test "rc usage metrics (good payload)" { + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + utc_startup_timestamp: 1707399316 + 
log_processors: [] + EOT + ) + + # bouncers have feature flags too + + payload=$(yq -o j ' + .remediation_components[0].feature_flags = ["huey", "dewey", "louie"] | + .remediation_components[0].os = {"name": "Multics", "version": "MR12.5"} + ' <<<"$payload") + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + rune -0 cscli bouncer inspect testbouncer -o json + rune -0 yq -o j '[.os,.featureflags]' <(output) + assert_json '["Multics/MR12.5",["huey","dewey","louie"]]' + + payload=$(yq -o j ' + .remediation_components[0].metrics = [ + { + "meta": {"utc_now_timestamp": 1707399316, "window_size_seconds":600}, + "items":[ + {"name": "foo", "unit": "pound", "value": 3.1415926}, + {"name": "foo", "unit": "pound", "value": 2.7182818}, + {"name": "foo", "unit": "dogyear", "value": 2.7182818} + ] + } + ] + ' <<<"$payload") + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + rune -0 cscli metrics show bouncers -o json + # aggregation is ok -- we are truncating, not rounding, because the float is mandated by swagger. + # but without labels the origin string is empty + assert_json '{bouncers:{testbouncer:{"": {foo: {dogyear: 2, pound: 5}}}}}' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (testbouncer) since 2024-02-08 13:35:16 +0000 UTC: + +--------+-----------------+ + | Origin | foo | + | | dogyear | pound | + +--------+---------+-------+ + | Total | 2 | 5 | + +--------+---------+-------+ + EOT + + # some more realistic values, at least for the labels + # we don't use the same now_timestamp or the payload will be silently discarded + + payload=$(yq -o j ' + .remediation_components[0].metrics = [ + { + "meta": {"utc_now_timestamp": 1707399916, "window_size_seconds":600}, + "items":[ + {"name": "active_decisions", "unit": "ip", "value": 500, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_voipbl"}}, + {"name": "active_decisions", "unit": "ip", "value": 1, "labels": {"ip_type": "ipv6", "origin": "cscli"}}, + {"name": "dropped", "unit": "byte", "value": 3800, "labels": {"ip_type": "ipv4", "origin": "CAPI"}}, + {"name": "dropped", "unit": "byte", "value": 0, "labels": {"ip_type": "ipv4", "origin": "cscli"}}, + {"name": "dropped", "unit": "byte", "value": 1034, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_cruzit_web_attacks"}}, + {"name": "dropped", "unit": "byte", "value": 3847, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_voipbl"}}, + {"name": "dropped", "unit": "byte", "value": 380, "labels": {"ip_type": "ipv6", "origin": "cscli"}}, + {"name": "dropped", "unit": "packet", "value": 100, "labels": {"ip_type": "ipv4", "origin": "CAPI"}}, + {"name": "dropped", "unit": "packet", "value": 10, "labels": {"ip_type": "ipv4", "origin": "cscli"}}, + {"name": "dropped", "unit": "packet", "value": 23, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_cruzit_web_attacks"}}, + {"name": "dropped", "unit": "packet", "value": 58, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_voipbl"}}, + {"name": "dropped", "unit": "packet", "value": 0, "labels": {"ip_type": "ipv4", "origin": "lists:anotherlist"}}, + {"name": "dropped", "unit": "byte", "value": 0, "labels": {"ip_type": "ipv4", "origin": "lists:anotherlist"}}, + {"name": "dropped", "unit": "packet", "value": 0, "labels": {"ip_type": "ipv6", "origin": "cscli"}} + ] + } + ] | + .remediation_components[0].type = "crowdsec-firewall-bouncer" + ' <<<"$payload") + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + rune -0 cscli metrics show 
bouncers -o json + assert_json '{ + "bouncers": { + "testbouncer": { + "": { + "foo": { + "dogyear": 2, + "pound": 5 + } + }, + "CAPI": { + "dropped": { + "byte": 3800, + "packet": 100 + } + }, + "cscli": { + "active_decisions": { + "ip": 1 + }, + "dropped": { + "byte": 380, + "packet": 10 + } + }, + "lists:firehol_cruzit_web_attacks": { + "dropped": { + "byte": 1034, + "packet": 23 + } + }, + "lists:firehol_voipbl": { + "active_decisions": { + "ip": 500 + }, + "dropped": { + "byte": 3847, + "packet": 58 + }, + }, + "lists:anotherlist": { + "dropped": { + "byte": 0, + "packet": 0 + } + } + } + } + }' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (testbouncer) since 2024-02-08 13:35:16 +0000 UTC: + +----------------------------------+------------------+-------------------+-----------------+ + | Origin | active_decisions | dropped | foo | + | | IPs | bytes | packets | dogyear | pound | + +----------------------------------+------------------+---------+---------+---------+-------+ + | CAPI (community blocklist) | - | 3.80k | 100 | - | - | + | cscli (manual decisions) | 1 | 380 | 10 | - | - | + | lists:anotherlist | - | 0 | 0 | - | - | + | lists:firehol_cruzit_web_attacks | - | 1.03k | 23 | - | - | + | lists:firehol_voipbl | 500 | 3.85k | 58 | - | - | + +----------------------------------+------------------+---------+---------+---------+-------+ + | Total | 501 | 9.06k | 191 | 2 | 5 | + +----------------------------------+------------------+---------+---------+---------+-------+ + EOT + + # active_decisions is actually a gauge: values should not be aggregated, keep only the latest one + + payload=$(yq -o j ' + .remediation_components[0].metrics = [ + { + "meta": {"utc_now_timestamp": 1707450000, "window_size_seconds":600}, + "items":[ + {"name": "active_decisions", "unit": "ip", "value": 250, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_voipbl"}}, + {"name": "active_decisions", "unit": "ip", "value": 10, "labels": {"ip_type": "ipv6", "origin": "cscli"}} + ] + } + ] | + .remediation_components[0].type = "crowdsec-firewall-bouncer" + ' <<<"$payload") + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + rune -0 cscli metrics show bouncers -o json + assert_json '{ + "bouncers": { + "testbouncer": { + "": { + "foo": { + "dogyear": 2, + "pound": 5 + } + }, + "CAPI": { + "dropped": { + "byte": 3800, + "packet": 100 + } + }, + "cscli": { + "active_decisions": { + "ip": 10 + }, + "dropped": { + "byte": 380, + "packet": 10 + } + }, + "lists:firehol_cruzit_web_attacks": { + "dropped": { + "byte": 1034, + "packet": 23 + } + }, + "lists:firehol_voipbl": { + "active_decisions": { + "ip": 250 + }, + "dropped": { + "byte": 3847, + "packet": 58 + }, + }, + "lists:anotherlist": { + "dropped": { + "byte": 0, + "packet": 0 + } + } + } + } + }' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (testbouncer) since 2024-02-08 13:35:16 +0000 UTC: + +----------------------------------+------------------+-------------------+-----------------+ + | Origin | active_decisions | dropped | foo | + | | IPs | bytes | packets | dogyear | pound | + +----------------------------------+------------------+---------+---------+---------+-------+ + | CAPI (community blocklist) | - | 3.80k | 100 | - | - | + | cscli (manual decisions) | 10 | 380 | 10 | - | - | + | lists:anotherlist | - | 0 | 0 | - | - | + | lists:firehol_cruzit_web_attacks | - | 1.03k | 23 | - | - | + | lists:firehol_voipbl | 250 | 3.85k | 58 | - | - | + 
+----------------------------------+------------------+---------+---------+---------+-------+ + | Total | 260 | 9.06k | 191 | 2 | 5 | + +----------------------------------+------------------+---------+---------+---------+-------+ + EOT +} + +@test "rc usage metrics (unknown metrics)" { + # new metrics are introduced in a new bouncer version, unknown by this version of cscli: some are gauges, some are not + + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + utc_startup_timestamp: 1707369316 + log_processors: [] + EOT + ) + + payload=$(yq -o j ' + .remediation_components[0].metrics = [ + { + "meta": {"utc_now_timestamp": 1707460000, "window_size_seconds":600}, + "items":[ + {"name": "ima_gauge", "unit": "second", "value": 30, "labels": {"origin": "cscli"}}, + {"name": "notagauge", "unit": "inch", "value": 15, "labels": {"origin": "cscli"}} + ] + }, { + "meta": {"utc_now_timestamp": 1707450000, "window_size_seconds":600}, + "items":[ + {"name": "ima_gauge", "unit": "second", "value": 20, "labels": {"origin": "cscli"}}, + {"name": "notagauge", "unit": "inch", "value": 10, "labels": {"origin": "cscli"}} + ] + } + ] | + .remediation_components[0].type = "crowdsec-firewall-bouncer" + ' <<<"$payload") + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + + rune -0 cscli metrics show bouncers -o json + assert_json '{bouncers: {testbouncer: {cscli: {ima_gauge: {second: 30}, notagauge: {inch: 25}}}}}' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (testbouncer) since 2024-02-09 03:40:00 +0000 UTC: + +--------------------------+--------+-----------+ + | Origin | ima | notagauge | + | | second | inch | + +--------------------------+--------+-----------+ + | cscli (manual decisions) | 30 | 25 | + +--------------------------+--------+-----------+ + | Total | 30 | 25 | + +--------------------------+--------+-----------+ + EOT +} + +@test "rc usage metrics (ipv4/ipv6)" { + # gauge metrics are not aggregated over time, but they are over ip type + + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + utc_startup_timestamp: 1707369316 + log_processors: [] + EOT + ) + + payload=$(yq -o j ' + .remediation_components[0].metrics = [ + { + "meta": {"utc_now_timestamp": 1707460000, "window_size_seconds":600}, + "items":[ + {"name": "active_decisions", "unit": "ip", "value": 200, "labels": {"ip_type": "ipv4", "origin": "cscli"}}, + {"name": "active_decisions", "unit": "ip", "value": 30, "labels": {"ip_type": "ipv6", "origin": "cscli"}} + ] + }, { + "meta": {"utc_now_timestamp": 1707450000, "window_size_seconds":600}, + "items":[ + {"name": "active_decisions", "unit": "ip", "value": 400, "labels": {"ip_type": "ipv4", "origin": "cscli"}}, + {"name": "active_decisions", "unit": "ip", "value": 50, "labels": {"ip_type": "ipv6", "origin": "cscli"}} + ] + } + ] | + .remediation_components[0].type = "crowdsec-firewall-bouncer" + ' <<<"$payload") + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + + rune -0 cscli metrics show bouncers -o json + assert_json '{bouncers: {testbouncer: {cscli: {active_decisions: {ip: 230}}}}}' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (testbouncer) since 2024-02-09 03:40:00 +0000 UTC: + +--------------------------+------------------+ + | Origin | active_decisions | + | | IPs | + 
+--------------------------+------------------+ + | cscli (manual decisions) | 230 | + +--------------------------+------------------+ + | Total | 230 | + +--------------------------+------------------+ + EOT +} + +@test "rc usage metrics (multiple bouncers)" { + # multiple bouncers have separate totals and can have different types of metrics and units -> different columns + + API_KEY=$(cscli bouncers add bouncer1 -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + utc_startup_timestamp: 1707369316 + metrics: + - meta: + utc_now_timestamp: 1707399316 + window_size_seconds: 600 + items: + - name: dropped + unit: byte + value: 1000 + labels: + origin: CAPI + - name: dropped + unit: byte + value: 800 + labels: + origin: lists:somelist + - name: processed + unit: byte + value: 12340 + - name: processed + unit: packet + value: 100 + EOT + ) + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + + API_KEY=$(cscli bouncers add bouncer2 -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + utc_startup_timestamp: 1707379316 + metrics: + - meta: + utc_now_timestamp: 1707389316 + window_size_seconds: 600 + items: + - name: dropped + unit: byte + value: 1500 + labels: + origin: lists:somelist + - name: dropped + unit: byte + value: 2000 + labels: + origin: CAPI + - name: dropped + unit: packet + value: 20 + labels: + origin: lists:somelist + EOT + ) + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + + rune -0 cscli metrics show bouncers -o json + assert_json '{bouncers:{bouncer1:{"":{processed:{byte:12340,packet:100}},CAPI:{dropped:{byte:1000}},"lists:somelist":{dropped:{byte:800}}},bouncer2:{"lists:somelist":{dropped:{byte:1500,packet:20}},CAPI:{dropped:{byte:2000}}}}}' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (bouncer1) since 2024-02-08 13:35:16 +0000 UTC: + +----------------------------+---------+-----------------------+ + | Origin | dropped | processed | + | | bytes | bytes | packets | + +----------------------------+---------+-----------+-----------+ + | CAPI (community blocklist) | 1.00k | - | - | + | lists:somelist | 800 | - | - | + +----------------------------+---------+-----------+-----------+ + | Total | 1.80k | 12.34k | 100 | + +----------------------------+---------+-----------+-----------+ + + Bouncer Metrics (bouncer2) since 2024-02-08 10:48:36 +0000 UTC: + +----------------------------+-------------------+ + | Origin | dropped | + | | bytes | packets | + +----------------------------+---------+---------+ + | CAPI (community blocklist) | 2.00k | - | + | lists:somelist | 1.50k | 20 | + +----------------------------+---------+---------+ + | Total | 3.50k | 20 | + +----------------------------+---------+---------+ + EOT +} diff --git a/test/bats/08_metrics_machines.bats b/test/bats/08_metrics_machines.bats new file mode 100644 index 00000000000..3b73839e753 --- /dev/null +++ b/test/bats/08_metrics_machines.bats @@ -0,0 +1,100 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + ./instance-data load + ./instance-crowdsec start +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "lp usage metrics (empty payload)" { + # a registered log processor can send metrics for the lapi and console + TOKEN=$(lp-get-token) + export 
TOKEN + + payload=$(yq -o j <<-EOT + remediation_components: [] + log_processors: [] + EOT + ) + + rune -22 curl-with-token '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial 'error: 400' + assert_json '{message: "Missing log processor data"}' +} + +@test "lp usage metrics (bad payload)" { + TOKEN=$(lp-get-token) + export TOKEN + + payload=$(yq -o j <<-EOT + remediation_components: [] + log_processors: + - version: "v1.0" + EOT + ) + + rune -22 curl-with-token '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial "error: 422" + rune -0 jq -r '.message' <(output) + assert_output - <<-EOT + validation failure list: + log_processors.0.utc_startup_timestamp in body is required + log_processors.0.datasources in body is required + log_processors.0.hub_items in body is required + EOT +} + +@test "lp usage metrics (full payload)" { + TOKEN=$(lp-get-token) + export TOKEN + + # base payload without any measurement + + payload=$(yq -o j <<-EOT + remediation_components: [] + log_processors: + - version: "v1.0" + utc_startup_timestamp: 1707399316 + hub_items: {} + feature_flags: + - marshmallows + os: + name: CentOS + version: "8" + metrics: + - name: logs_parsed + value: 5000 + unit: count + labels: {} + items: [] + meta: + window_size_seconds: 600 + utc_now_timestamp: 1707485349 + console_options: + - share_context + datasources: + syslog: 1 + file: 4 + EOT + ) + + rune -0 curl-with-token '/v1/usage-metrics' -X POST --data "$payload" + refute_output +} diff --git a/test/bats/09_console.bats b/test/bats/09_console.bats new file mode 100644 index 00000000000..2e2f9bf058d --- /dev/null +++ b/test/bats/09_console.bats @@ -0,0 +1,100 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + config_enable_capi + + config_set "$(config_get '.api.server.online_client.credentials_path')" ' + .url="https://api.crowdsec.net/" | + .login="test" | + .password="test" + ' +} + +#---------- + +@test "cscli console status" { + rune -0 cscli console status + assert_output --partial "Option Name" + assert_output --partial "Activated" + assert_output --partial "Description" + assert_output --partial "custom" + assert_output --partial "manual" + assert_output --partial "tainted" + assert_output --partial "context" + assert_output --partial "console_management" + rune -0 cscli console status -o json + assert_json - <<- EOT + { + "console_management": false, + "context": false, + "custom": true, + "manual": false, + "tainted": true + } + EOT + rune -0 cscli console status -o raw + assert_output - <<-EOT + option,enabled + manual,false + custom,true + tainted,true + context,false + console_management,false + EOT +} + +@test "cscli console enable" { + rune -0 cscli console enable manual --debug + assert_stderr --partial "manual set to true" + assert_stderr --partial "[manual] have been enabled" + rune -0 cscli console enable manual --debug + assert_stderr --partial "manual already set to true" + assert_stderr --partial "[manual] have been enabled" + rune -0 cscli console enable manual context --debug + assert_stderr --partial "context set to true" + assert_stderr --partial "[manual context] have been enabled" + rune -0 cscli console enable --all --debug + assert_stderr --partial "custom already set to true" + assert_stderr --partial "manual already set to true" 
+ assert_stderr --partial "tainted already set to true" + assert_stderr --partial "context already set to true" + assert_stderr --partial "console_management set to true" + assert_stderr --partial "All features have been enabled successfully" + rune -1 cscli console enable tralala + assert_stderr --partial "unknown flag tralala" +} + +@test "cscli console disable" { + rune -0 cscli console disable tainted --debug + assert_stderr --partial "tainted set to false" + assert_stderr --partial "[tainted] have been disabled" + rune -0 cscli console disable tainted --debug + assert_stderr --partial "tainted already set to false" + assert_stderr --partial "[tainted] have been disabled" + rune -0 cscli console disable tainted custom --debug + assert_stderr --partial "custom set to false" + assert_stderr --partial "[tainted custom] have been disabled" + rune -0 cscli console disable --all --debug + assert_stderr --partial "custom already set to false" + assert_stderr --partial "manual already set to false" + assert_stderr --partial "tainted already set to false" + assert_stderr --partial "context already set to false" + assert_stderr --partial "console_management already set to false" + assert_stderr --partial "All features have been disabled" + rune -1 cscli console disable tralala + assert_stderr --partial "unknown flag tralala" +} diff --git a/test/bats/09_context.bats b/test/bats/09_context.bats new file mode 100644 index 00000000000..71aabc68d29 --- /dev/null +++ b/test/bats/09_context.bats @@ -0,0 +1,113 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + CONFIG_DIR=$(config_get '.config_paths.config_dir') + export CONFIG_DIR + CONTEXT_YAML="$CONFIG_DIR/console/context.yaml" + export CONTEXT_YAML +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + config_set '.common.log_media="stdout"' + mkdir -p "$CONFIG_DIR/console" +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "detect available context" { + rune -0 cscli lapi context detect -a + rune -0 yq -o json <(output) + assert_json '{"Acquisition":["evt.Line.Module","evt.Line.Raw","evt.Line.Src"]}' + + rune -0 cscli parsers install crowdsecurity/dateparse-enrich + rune -0 cscli lapi context detect crowdsecurity/dateparse-enrich + rune -0 yq -o json '.crowdsecurity/dateparse-enrich' <(output) + assert_json '["evt.MarshaledTime","evt.Meta.timestamp"]' +} + +@test "attempt to load from default context file, ignore if missing" { + rune -0 rm -f "$CONTEXT_YAML" + rune -0 "$CROWDSEC" -t --trace + assert_stderr --partial "loading console context from $CONTEXT_YAML" +} + +@test "no error if context file is missing but not explicitly set" { + config_set "del(.crowdsec_service.console_context_path)" + rune -0 rm -f "$CONTEXT_YAML" + rune -0 cscli lapi context status --error + refute_stderr + assert_output --partial "No context found on this agent." 
+ rune -0 "$CROWDSEC" -t + refute_stderr --partial "no such file or directory" +} + +@test "error if context file is explicitly set but does not exist" { + config_set ".crowdsec_service.console_context_path=strenv(CONTEXT_YAML)" + rune -0 rm -f "$CONTEXT_YAML" + rune -1 cscli lapi context status --error + assert_stderr --partial "context.yaml: no such file or directory" + rune -1 "$CROWDSEC" -t + assert_stderr --partial "while checking console_context_path: stat $CONTEXT_YAML: no such file or directory" +} + +@test "cscli lapi context delete" { + rune -1 cscli lapi context delete + assert_stderr --partial "command 'delete' has been removed, please manually edit the context file" +} + +@test "context file is bad" { + echo "bad yaml" > "$CONTEXT_YAML" + rune -1 "$CROWDSEC" -t + assert_stderr --partial "while loading context: $CONTEXT_YAML: yaml: unmarshal errors" +} + +@test "context file is good" { + echo '{"source_ip":["evt.Parsed.source_ip"]}' > "$CONTEXT_YAML" + rune -0 "$CROWDSEC" -t --debug + # the log content may have quotes escaped or not, depending on tty detection + assert_stderr --regexp 'console context to send: .*source_ip.*evt.Parsed.source_ip' +} + +@test "context file is from hub (local item)" { + mkdir -p "$CONFIG_DIR/contexts" + config_set "del(.crowdsec_service.console_context_path)" + echo '{"context":{"source_ip":["evt.Parsed.source_ip"]}}' > "$CONFIG_DIR/contexts/foobar.yaml" + rune -0 "$CROWDSEC" -t --trace + assert_stderr --partial "loading console context from $CONFIG_DIR/contexts/foobar.yaml" + assert_stderr --regexp 'console context to send: .*source_ip.*evt.Parsed.source_ip' +} + +@test "merge multiple contexts" { + mkdir -p "$CONFIG_DIR/contexts" + echo '{"context":{"one":["evt.Parsed.source_ip"]}}' > "$CONFIG_DIR/contexts/one.yaml" + echo '{"context":{"two":["evt.Parsed.source_ip"]}}' > "$CONFIG_DIR/contexts/two.yaml" + rune -0 "$CROWDSEC" -t --trace + assert_stderr --partial "loading console context from $CONFIG_DIR/contexts/one.yaml" + assert_stderr --partial "loading console context from $CONFIG_DIR/contexts/two.yaml" + assert_stderr --regexp 'console context to send: .*one.*evt.Parsed.source_ip.*two.*evt.Parsed.source_ip' +} + +@test "merge contexts from hub and context.yaml file" { + mkdir -p "$CONFIG_DIR/contexts" + echo '{"context":{"one":["evt.Parsed.source_ip"]}}' > "$CONFIG_DIR/contexts/one.yaml" + echo '{"one":["evt.Parsed.source_ip_2"]}' > "$CONFIG_DIR/console/context.yaml" + rune -0 "$CROWDSEC" -t --trace + assert_stderr --partial "loading console context from $CONFIG_DIR/contexts/one.yaml" + assert_stderr --partial "loading console context from $CONFIG_DIR/console/context.yaml" + assert_stderr --regexp 'console context to send: .*one.*evt.Parsed.source_ip.*evt.Parsed.source_ip_2' +} diff --git a/test/bats/09_socket.bats b/test/bats/09_socket.bats new file mode 100644 index 00000000000..f861d8a40dc --- /dev/null +++ b/test/bats/09_socket.bats @@ -0,0 +1,158 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + sockdir=$(TMPDIR="$BATS_FILE_TMPDIR" mktemp -u) + export sockdir + mkdir -p "$sockdir" + socket="$sockdir/crowdsec_api.sock" + export socket + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + export LOCAL_API_CREDENTIALS +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + config_set ".api.server.listen_socket=strenv(socket)" +} + +teardown() {
+ ./instance-crowdsec stop +} + +#---------- + +@test "cscli - connects from existing machine with socket" { + config_set "$LOCAL_API_CREDENTIALS" ".url=strenv(socket)" + + ./instance-crowdsec start + + rune -0 cscli lapi status + assert_output --regexp "Trying to authenticate with username .* on $socket" + assert_output --partial "You can successfully interact with Local API (LAPI)" +} + +@test "crowdsec - listen on both socket and TCP" { + ./instance-crowdsec start + + rune -0 cscli lapi status + assert_output --regexp "Trying to authenticate with username .* on http://127.0.0.1:8080/" + assert_output --partial "You can successfully interact with Local API (LAPI)" + + config_set "$LOCAL_API_CREDENTIALS" ".url=strenv(socket)" + + rune -0 cscli lapi status + assert_output --regexp "Trying to authenticate with username .* on $socket" + assert_output --partial "You can successfully interact with Local API (LAPI)" +} + +@test "cscli - authenticate new machine with socket" { + # verify that if a listen_uri and a socket are set, the socket is used + # by default when creating a local machine. + + rune -0 cscli machines delete "$(cscli machines list -o json | jq -r '.[].machineId')" + + # this one should be using the socket + rune -0 cscli machines add --auto --force + + using=$(config_get "$LOCAL_API_CREDENTIALS" ".url") + + assert [ "$using" = "$socket" ] + + # disable the agent because it counts as a first authentication + config_disable_agent + ./instance-crowdsec start + + # the machine does not have an IP yet + + rune -0 cscli machines list -o json + rune -0 jq -r '.[].ipAddress' <(output) + assert_output null + + # upon first authentication, it's assigned to localhost + + rune -0 cscli lapi status + + rune -0 cscli machines list -o json + rune -0 jq -r '.[].ipAddress' <(output) + assert_output 127.0.0.1 +} + +bouncer_http() { + URI="$1" + curl -fs -H "X-Api-Key: $API_KEY" "http://localhost:8080$URI" +} + +bouncer_socket() { + URI="$1" + curl -fs -H "X-Api-Key: $API_KEY" --unix-socket "$socket" "http://localhost$URI" +} + +@test "lapi - connects from existing bouncer with socket" { + ./instance-crowdsec start + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + # the bouncer does not have an IP yet + + rune -0 cscli bouncers list -o json + rune -0 jq -r '.[].ip_address' <(output) + assert_output "" + + # upon first authentication, it's assigned to localhost + + rune -0 bouncer_socket '/v1/decisions' + assert_output 'null' + refute_stderr + + rune -0 cscli bouncers list -o json + rune -0 jq -r '.[].ip_address' <(output) + assert_output "127.0.0.1" + + # we can still use TCP of course + + rune -0 bouncer_http '/v1/decisions' + assert_output 'null' + refute_stderr +} + +@test "lapi - listen on socket only" { + config_set "del(.api.server.listen_uri)" + + mkdir -p "$sockdir" + + # agent is not able to connect right now + config_disable_agent + ./instance-crowdsec start + + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + # now we can't + + rune -1 cscli lapi status + assert_stderr --partial "connection refused" + + rune -7 bouncer_http '/v1/decisions' + refute_output + refute_stderr + + # here we can + + config_set "$LOCAL_API_CREDENTIALS" ".url=strenv(socket)" + + rune -0 cscli lapi status + + rune -0 bouncer_socket '/v1/decisions' + assert_output 'null' + refute_stderr +} diff --git a/test/bats/10_bouncers.bats b/test/bats/10_bouncers.bats index 79ba0eda82d..f99913dcee5 100644 --- a/test/bats/10_bouncers.bats +++ b/test/bats/10_bouncers.bats @@ -25,7 
+25,13 @@ teardown() { @test "there are 0 bouncers" { rune -0 cscli bouncers list -o json - assert_output "[]" + assert_json '[]' + + rune -0 cscli bouncers list -o human + assert_output --partial "Name" + + rune -0 cscli bouncers list -o raw + assert_output --partial 'name' } @test "we can add one bouncer, and delete it" { @@ -33,7 +39,80 @@ teardown() { assert_output --partial "API key for 'ciTestBouncer':" rune -0 cscli bouncers delete ciTestBouncer rune -0 cscli bouncers list -o json - assert_output '[]' + assert_json '[]' +} + +@test "bouncer api-key auth" { + rune -0 cscli bouncers add ciTestBouncer --key "goodkey" + + # connect with good credentials + rune -0 curl-tcp "/v1/decisions" -sS --fail-with-body -H "X-Api-Key: goodkey" + assert_output null + + # connect with bad credentials + rune -22 curl-tcp "/v1/decisions" -sS --fail-with-body -H "X-Api-Key: badkey" + assert_stderr --partial 'error: 403' + assert_json '{message:"access forbidden"}' + + # connect with no credentials + rune -22 curl-tcp "/v1/decisions" -sS --fail-with-body + assert_stderr --partial 'error: 403' + assert_json '{message:"access forbidden"}' +} + +@test "delete non-existent bouncer" { + # this is a fatal error, which is not consistent with "machines delete" + rune -1 cscli bouncers delete something + assert_stderr --partial "unable to delete bouncer: 'something' does not exist" + rune -0 cscli bouncers delete something --ignore-missing + refute_stderr +} + +@test "bouncers delete has autocompletion" { + rune -0 cscli bouncers add foo1 + rune -0 cscli bouncers add foo2 + rune -0 cscli bouncers add bar + rune -0 cscli bouncers add baz + rune -0 cscli __complete bouncers delete 'foo' + assert_line --index 0 'foo1' + assert_line --index 1 'foo2' + refute_line 'bar' + refute_line 'baz' +} + +@test "cscli bouncers list" { + export API_KEY=bouncerkey + rune -0 cscli bouncers add ciTestBouncer --key "$API_KEY" + + rune -0 cscli bouncers list -o json + rune -0 jq -c '.[] | [.ip_address,.last_pull,.name]' <(output) + assert_json '["",null,"ciTestBouncer"]' + rune -0 cscli bouncers list -o raw + assert_line 'name,ip,revoked,last_pull,type,version,auth_type' + assert_line 'ciTestBouncer,,validated,,,,api-key' + rune -0 cscli bouncers list -o human + assert_output --regexp 'ciTestBouncer.*api-key.*' + + # the first connection sets last_pull and ip address + rune -0 curl-with-key '/v1/decisions' + rune -0 cscli bouncers list -o json + rune -0 jq -r '.[] | .ip_address' <(output) + assert_output 127.0.0.1 + rune -0 cscli bouncers list -o json + rune -0 jq -r '.[] | .last_pull' <(output) + refute_output null +} + +@test "we can create a bouncer with a known key" { + # also test the output formats since we know the key + rune -0 cscli bouncers add ciTestBouncer --key "foobarbaz" -o human + assert_output --partial 'foobarbaz' + rune -0 cscli bouncers delete ciTestBouncer + rune -0 cscli bouncers add ciTestBouncer --key "foobarbaz" -o json + assert_output '"foobarbaz"' + rune -0 cscli bouncers delete ciTestBouncer + rune -0 cscli bouncers add ciTestBouncer --key "foobarbaz" -o raw + assert_output foobarbaz } @test "we can't add the same bouncer twice" { @@ -56,3 +135,12 @@ teardown() { rune -1 cscli bouncers delete ciTestBouncer rune -1 cscli bouncers delete foobarbaz } + +@test "cscli bouncers prune" { + rune -0 cscli bouncers prune + assert_output 'No bouncers to prune.' + rune -0 cscli bouncers add ciTestBouncer + + rune -0 cscli bouncers prune + assert_output 'No bouncers to prune.' 
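# The 11_bouncers_tls.bats fixture below builds a small PKI (a root CA,
# intermediate CAs, leaf client certs) plus a concatenated CRL bundle.
# Two sketches, hedged on the test-lib internals:
#
# cert_serial_number, used when generating the CRLs, presumably does what
# the deleted inline code did: extract the hex serial and convert it to
# decimal, since cfssl gencrl expects decimal serial numbers.
cert_serial_number() {
    serial=$(openssl x509 -noout -serial -in "$1" | cut -d '=' -f2)
    echo "ibase=16; $serial" | bc
}

# An out-of-band way to check the fixture with openssl (crowdsec performs
# the equivalent validation internally when a client connects):
openssl verify -CAfile "$tmpdir/root.pem" -untrusted "$tmpdir/inter.pem" \
    "$tmpdir/leaf.pem"       # succeeds: valid chain
openssl verify -crl_check -CRLfile "$tmpdir/crl.pem" \
    -CAfile "$tmpdir/root.pem" -untrusted "$tmpdir/inter.pem" \
    "$tmpdir/leaf_rev1.pem"  # fails: revoked by CRL1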
+} diff --git a/test/bats/11_bouncers_tls.bats b/test/bats/11_bouncers_tls.bats index 8fb4579259d..554308ae962 100644 --- a/test/bats/11_bouncers_tls.bats +++ b/test/bats/11_bouncers_tls.bats @@ -3,36 +3,116 @@ set -u +# root: root CA +# inter: intermediate CA +# inter_rev: intermediate CA revoked by root (CRL3) +# leaf: valid client cert +# leaf_rev1: client cert revoked by inter (CRL1) +# leaf_rev2: client cert revoked by inter (CRL2) +# leaf_rev3: client cert (indirectly) revoked by root +# +# CRL1: inter revokes leaf_rev1 +# CRL2: inter revokes leaf_rev2 +# CRL3: root revokes inter_rev +# CRL4: root revokes leaf, but is ignored + setup_file() { load "../lib/setup_file.sh" ./instance-data load - tmpdir="${BATS_FILE_TMPDIR}" + tmpdir="$BATS_FILE_TMPDIR" export tmpdir - CFDIR="${BATS_TEST_DIRNAME}/testdata/cfssl" + CFDIR="$BATS_TEST_DIRNAME/testdata/cfssl" export CFDIR - #gen the CA - cfssl gencert --initca "${CFDIR}/ca.json" 2>/dev/null | cfssljson --bare "${tmpdir}/ca" - #gen an intermediate - cfssl gencert --initca "${CFDIR}/intermediate.json" 2>/dev/null | cfssljson --bare "${tmpdir}/inter" - cfssl sign -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config "${CFDIR}/profiles.json" -profile intermediate_ca "${tmpdir}/inter.csr" 2>/dev/null | cfssljson --bare "${tmpdir}/inter" - #gen server cert for crowdsec with the intermediate - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=server "${CFDIR}/server.json" 2>/dev/null | cfssljson --bare "${tmpdir}/server" - #gen client cert for the bouncer - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/bouncer.json" 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer" - #gen client cert for the bouncer with an invalid OU - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/bouncer_invalid.json" 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_bad_ou" - #gen client cert for the bouncer directly signed by the CA, it should be refused by crowdsec as uses the intermediate - cfssl gencert -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/bouncer.json" 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_invalid" - - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/bouncer.json" 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_revoked" - serial="$(openssl x509 -noout -serial -in "${tmpdir}/bouncer_revoked.pem" | cut -d '=' -f2)" - echo "ibase=16; ${serial}" | bc >"${tmpdir}/serials.txt" - cfssl gencrl "${tmpdir}/serials.txt" "${tmpdir}/ca.pem" "${tmpdir}/ca-key.pem" | base64 -d | openssl crl -inform DER -out "${tmpdir}/crl.pem" - - cat "${tmpdir}/ca.pem" "${tmpdir}/inter.pem" > "${tmpdir}/bundle.pem" + # Root CA + cfssl gencert -loglevel 2 \ + --initca "$CFDIR/ca_root.json" \ + | cfssljson --bare "$tmpdir/root" + + # Intermediate CAs (valid or revoked) + for cert in "inter" "inter_rev"; do + cfssl gencert -loglevel 2 \ + --initca "$CFDIR/ca_intermediate.json" \ + | cfssljson --bare "$tmpdir/$cert" + + cfssl sign -loglevel 2 \ + -ca "$tmpdir/root.pem" -ca-key "$tmpdir/root-key.pem" \ + -config "$CFDIR/profiles.json" -profile intermediate_ca "$tmpdir/$cert.csr" \ + | cfssljson --bare "$tmpdir/$cert" + done + + # Server cert for crowdsec with the intermediate + cfssl gencert 
-loglevel 2 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=server "$CFDIR/server.json" \ + | cfssljson --bare "$tmpdir/server" + + # Client certs (valid or revoked) + for cert in "leaf" "leaf_rev1" "leaf_rev2"; do + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/bouncer.json" \ + | cfssljson --bare "$tmpdir/$cert" + done + + # Client cert (by revoked inter) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter_rev.pem" -ca-key "$tmpdir/inter_rev-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/bouncer.json" \ + | cfssljson --bare "$tmpdir/leaf_rev3" + + # Bad client cert (invalid OU) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/bouncer_invalid.json" \ + | cfssljson --bare "$tmpdir/leaf_bad_ou" + + # Bad client cert (directly signed by the CA, it should be refused by crowdsec as it uses the intermediate) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/root.pem" -ca-key "$tmpdir/root-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/bouncer.json" \ + | cfssljson --bare "$tmpdir/leaf_invalid" + + truncate -s 0 "$tmpdir/crl.pem" + + # Revoke certs + { + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf_rev1.pem") \ + "$tmpdir/inter.pem" \ + "$tmpdir/inter-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf_rev2.pem") \ + "$tmpdir/inter.pem" \ + "$tmpdir/inter-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/inter_rev.pem") \ + "$tmpdir/root.pem" \ + "$tmpdir/root-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf.pem") \ + "$tmpdir/root.pem" \ + "$tmpdir/root-key.pem" + echo '-----END X509 CRL-----' + } >> "$tmpdir/crl.pem" + + cat "$tmpdir/root.pem" "$tmpdir/inter.pem" > "$tmpdir/bundle.pem" config_set ' .api.server.tls.cert_file=strenv(tmpdir) + "/server.pem" | @@ -65,9 +145,14 @@ teardown() { assert_output "[]" } -@test "simulate one bouncer request with a valid cert" { - rune -0 curl -s --cert "${tmpdir}/bouncer.pem" --key "${tmpdir}/bouncer-key.pem" --cacert "${tmpdir}/bundle.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 +@test "simulate a bouncer request with a valid cert" { + rune -0 curl --fail-with-body -sS \ + --cert "$tmpdir/leaf.pem" \ + --key "$tmpdir/leaf-key.pem" \ + --cacert "$tmpdir/bundle.pem" \ + https://localhost:8080/v1/decisions\?ip=42.42.42.42 assert_output "null" + refute_stderr rune -0 cscli bouncers list -o json rune -0 jq '. 
| length' <(output) assert_output '1' @@ -77,21 +162,86 @@ teardown() { rune cscli bouncers delete localhost@127.0.0.1 } -@test "simulate one bouncer request with an invalid cert" { - rune curl -s --cert "${tmpdir}/bouncer_invalid.pem" --key "${tmpdir}/bouncer_invalid-key.pem" --cacert "${tmpdir}/ca-key.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 - rune -0 cscli bouncers list -o json - assert_output "[]" +@test "a bouncer authenticated with TLS can send metrics" { + payload=$(yq -o j <<-EOT + remediation_components: [] + log_processors: [] + EOT + ) + + # with mutual authentication there is no api key, so it's detected as RC if user agent != crowdsec + + rune -22 curl --fail-with-body -sS \ + --cert "$tmpdir/leaf.pem" \ + --key "$tmpdir/leaf-key.pem" \ + --cacert "$tmpdir/bundle.pem" \ + https://localhost:8080/v1/usage-metrics -X POST --data "$payload" + assert_stderr --partial 'error: 400' + assert_json '{message: "Missing remediation component data"}' + + rune -22 curl --fail-with-body -sS \ + --cert "$tmpdir/leaf.pem" \ + --key "$tmpdir/leaf-key.pem" \ + --cacert "$tmpdir/bundle.pem" \ + --user-agent "crowdsec/someversion" \ + https://localhost:8080/v1/usage-metrics -X POST --data "$payload" + assert_stderr --partial 'error: 401' + assert_json '{code:401, message: "cookie token is empty"}' + + rune cscli bouncers delete localhost@127.0.0.1 } -@test "simulate one bouncer request with an invalid OU" { - rune curl -s --cert "${tmpdir}/bouncer_bad_ou.pem" --key "${tmpdir}/bouncer_bad_ou-key.pem" --cacert "${tmpdir}/bundle.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 +@test "simulate a bouncer request with an invalid cert" { + rune -77 curl --fail-with-body -sS \ + --cert "$tmpdir/leaf_invalid.pem" \ + --key "$tmpdir/leaf_invalid-key.pem" \ + --cacert "$tmpdir/root-key.pem" \ + https://localhost:8080/v1/decisions\?ip=42.42.42.42 + assert_stderr --partial 'error setting certificate file' rune -0 cscli bouncers list -o json assert_output "[]" } -@test "simulate one bouncer request with a revoked certificate" { - rune -0 curl -i -s --cert "${tmpdir}/bouncer_revoked.pem" --key "${tmpdir}/bouncer_revoked-key.pem" --cacert "${tmpdir}/bundle.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 - assert_output --partial "access forbidden" +@test "simulate a bouncer request with an invalid OU" { + rune -22 curl --fail-with-body -sS \ + --cert "$tmpdir/leaf_bad_ou.pem" \ + --key "$tmpdir/leaf_bad_ou-key.pem" \ + --cacert "$tmpdir/bundle.pem" \ + https://localhost:8080/v1/decisions\?ip=42.42.42.42 + assert_json '{message: "access forbidden"}' + assert_stderr --partial 'error: 403' rune -0 cscli bouncers list -o json assert_output "[]" } + +@test "simulate a bouncer request with a revoked certificate" { + # we have two certificates revoked by different CRL blocks + # we connect twice to test the cache too + for cert in "leaf_rev1" "leaf_rev2" "leaf_rev1" "leaf_rev2"; do + truncate_log + rune -22 curl --fail-with-body -sS \ + --cert "$tmpdir/$cert.pem" \ + --key "$tmpdir/$cert-key.pem" \ + --cacert "$tmpdir/bundle.pem" \ + https://localhost:8080/v1/decisions\?ip=42.42.42.42 + assert_log --partial "certificate revoked by CRL" + assert_json '{message: "access forbidden"}' + assert_stderr --partial "error: 403" + rune -0 cscli bouncers list -o json + assert_output "[]" + done +} + +# vvv this test must be last, or it can break the ones that follow + +@test "allowed_ou can't contain an empty string" { + ./instance-crowdsec stop + config_set ' + .common.log_media="stdout" | + 
.api.server.tls.bouncers_allowed_ou=["bouncer-ou", ""] + ' + rune -1 wait-for "$CROWDSEC" + assert_stderr --partial "allowed_ou configuration contains invalid empty string" +} + +# ^^^ this test must be last, or it can break the ones that follow diff --git a/test/bats/12_notifications.bats b/test/bats/12_notifications.bats new file mode 100644 index 00000000000..86032bf8212 --- /dev/null +++ b/test/bats/12_notifications.bats @@ -0,0 +1,39 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + ./instance-crowdsec start +} + +teardown() { + cd "$TEST_DIR" || exit 1 + ./instance-crowdsec stop +} + +#---------- + +@test "cscli notifications list" { + rune -0 cscli notifications list + assert_output --partial "Name" + assert_output --partial "Type" + assert_output --partial "Profile name" +} + +@test "cscli notifications must be run from lapi" { + config_disable_lapi + rune -1 cscli notifications list + assert_stderr --partial "local API is disabled -- this command must be run on the local API machine" +} diff --git a/test/bats/13_capi_whitelists.bats b/test/bats/13_capi_whitelists.bats new file mode 100644 index 00000000000..ed7ef2ac560 --- /dev/null +++ b/test/bats/13_capi_whitelists.bats @@ -0,0 +1,90 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + CONFIG_DIR=$(dirname "$CONFIG_YAML") + CAPI_WHITELISTS_YAML="$CONFIG_DIR/capi-whitelists.yaml" + export CAPI_WHITELISTS_YAML +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + config_set '.common.log_media="stdout"' + config_set '.api.server.capi_whitelists_path=strenv(CAPI_WHITELISTS_YAML)' +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "capi_whitelists: file missing" { + rune -0 wait-for \ + --err "while opening capi whitelist file: open $CAPI_WHITELISTS_YAML: no such file or directory" \ + "$CROWDSEC" +} + +@test "capi_whitelists: error on open" { + echo > "$CAPI_WHITELISTS_YAML" + chmod 000 "$CAPI_WHITELISTS_YAML" + if is_package_testing; then + rune -0 wait-for \ + --err "while parsing capi whitelist file .*: empty file" \ + "$CROWDSEC" + else + rune -0 wait-for \ + --err "while opening capi whitelist file: open $CAPI_WHITELISTS_YAML: permission denied" \ + "$CROWDSEC" + fi +} + +@test "capi_whitelists: empty file" { + echo > "$CAPI_WHITELISTS_YAML" + rune -0 wait-for \ + --err "while parsing capi whitelist file '$CAPI_WHITELISTS_YAML': empty file" \ + "$CROWDSEC" +} + +@test "capi_whitelists: empty lists" { + echo '{"ips": [], "cidrs": []}' > "$CAPI_WHITELISTS_YAML" + rune -0 wait-for \ + --err "Starting processing data" \ + "$CROWDSEC" +} + +@test "capi_whitelists: bad ip" { + echo '{"ips": ["blahblah"], "cidrs": []}' > "$CAPI_WHITELISTS_YAML" + rune -0 wait-for \ + --err "while parsing capi whitelist file '$CAPI_WHITELISTS_YAML': invalid IP address: blahblah" \ + "$CROWDSEC" +} + +@test "capi_whitelists: bad cidr" { + echo '{"ips": [], "cidrs": ["blahblah"]}' > "$CAPI_WHITELISTS_YAML" + rune -0 wait-for \ + --err "while parsing capi whitelist file '$CAPI_WHITELISTS_YAML': invalid CIDR address: blahblah" \ + "$CROWDSEC" +} + +@test "capi_whitelists: file with ip and cidr values" { + cat <<-EOT > 
"$CAPI_WHITELISTS_YAML" + ips: + - 1.2.3.4 + - 2.3.4.5 + cidrs: + - 1.2.3.0/24 + EOT + + config_set '.common.log_level="trace"' + rune -0 ./instance-crowdsec start +} diff --git a/test/bats/20_collections.bats b/test/bats/20_collections.bats deleted file mode 100644 index aa1fa6b21d0..00000000000 --- a/test/bats/20_collections.bats +++ /dev/null @@ -1,145 +0,0 @@ -#!/usr/bin/env bats -# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: - -set -u - -setup_file() { - load "../lib/setup_file.sh" -} - -teardown_file() { - load "../lib/teardown_file.sh" -} - -setup() { - load "../lib/setup.sh" - ./instance-data load - ./instance-crowdsec start -} - -teardown() { - ./instance-crowdsec stop -} - -#---------- - -@test "we can list collections" { - rune -0 cscli collections list -} - -@test "there are 2 collections (linux and sshd)" { - rune -0 cscli collections list -o json - rune -0 jq '.collections | length' <(output) - assert_output 2 -} - -@test "can install a collection (as a regular user) and remove it" { - # collection is not installed - rune -0 cscli collections list -o json - rune -0 jq -r '.collections[].name' <(output) - refute_line "crowdsecurity/mysql" - - # we install it - rune -0 cscli collections install crowdsecurity/mysql -o human - assert_stderr --partial "Enabled crowdsecurity/mysql" - - # it has been installed - rune -0 cscli collections list -o json - rune -0 jq -r '.collections[].name' <(output) - assert_line "crowdsecurity/mysql" - - # we install it - rune -0 cscli collections remove crowdsecurity/mysql -o human - assert_stderr --partial "Removed symlink [crowdsecurity/mysql]" - - # it has been removed - rune -0 cscli collections list -o json - rune -0 jq -r '.collections[].name' <(output) - refute_line "crowdsecurity/mysql" -} - -@test "must use --force to remove a collection that belongs to another, which becomes tainted" { - # we expect no error since we may have multiple collections, some removed and some not - rune -0 cscli collections remove crowdsecurity/sshd - assert_stderr --partial "crowdsecurity/sshd belongs to other collections" - assert_stderr --partial "[crowdsecurity/linux]" - - rune -0 cscli collections remove crowdsecurity/sshd --force - assert_stderr --partial "Removed symlink [crowdsecurity/sshd]" - rune -0 cscli collections inspect crowdsecurity/linux -o json - rune -0 jq -r '.tainted' <(output) - assert_output "true" -} - -@test "can remove a collection" { - rune -0 cscli collections remove crowdsecurity/linux - assert_stderr --partial "Removed" - assert_stderr --regexp ".*for the new configuration to be effective." - rune -0 cscli collections inspect crowdsecurity/linux -o human - assert_line 'installed: false' -} - -@test "collections delete is an alias for collections remove" { - rune -0 cscli collections delete crowdsecurity/linux - assert_stderr --partial "Removed" - assert_stderr --regexp ".*for the new configuration to be effective." -} - -@test "removing a collection that does not exist is noop" { - rune -0 cscli collections remove crowdsecurity/apache2 - refute_stderr --partial "Removed" - assert_stderr --regexp ".*for the new configuration to be effective." 
-} - -@test "can remove a removed collection" { - rune -0 cscli collections install crowdsecurity/mysql - rune -0 cscli collections remove crowdsecurity/mysql - assert_stderr --partial "Removed" - rune -0 cscli collections remove crowdsecurity/mysql - refute_stderr --partial "Removed" -} - -@test "can remove all collections" { - # we may have this too, from package installs - rune cscli parsers delete crowdsecurity/whitelists - rune -0 cscli collections remove --all - assert_stderr --partial "Removed symlink [crowdsecurity/sshd]" - assert_stderr --partial "Removed symlink [crowdsecurity/linux]" - rune -0 cscli hub list -o json - assert_json '{collections:[],parsers:[],postoverflows:[],scenarios:[]}' - rune -0 cscli collections remove --all - assert_stderr --partial 'Disabled 0 items' -} - -@test "a taint bubbles up to the top collection" { - coll=crowdsecurity/nginx - subcoll=crowdsecurity/base-http-scenarios - scenario=crowdsecurity/http-crawl-non_statics - - # install a collection with dependencies - rune -0 cscli collections install "$coll" - - # the collection, subcollection and scenario are installed and not tainted - # we have to default to false because tainted is (as of 1.4.6) returned - # only when true - rune -0 cscli collections inspect "$coll" -o json - rune -0 jq -e '(.installed,.tainted|false)==(true,false)' <(output) - rune -0 cscli collections inspect "$subcoll" -o json - rune -0 jq -e '(.installed,.tainted|false)==(true,false)' <(output) - rune -0 cscli scenarios inspect "$scenario" -o json - rune -0 jq -e '(.installed,.tainted|false)==(true,false)' <(output) - - # we taint the scenario - HUB_DIR=$(config_get '.config_paths.hub_dir') - yq e '.description="I am tainted"' -i "$HUB_DIR/scenarios/$scenario.yaml" - - # the collection, subcollection and scenario are now tainted - rune -0 cscli scenarios inspect "$scenario" -o json - rune -0 jq -e '(.installed,.tainted)==(true,true)' <(output) - rune -0 cscli collections inspect "$subcoll" -o json - rune -0 jq -e '(.installed,.tainted)==(true,true)' <(output) - rune -0 cscli collections inspect "$coll" -o json - rune -0 jq -e '(.installed,.tainted)==(true,true)' <(output) -} - -# TODO test download-only diff --git a/test/bats/20_hub.bats b/test/bats/20_hub.bats new file mode 100644 index 00000000000..b8fa1e9efca --- /dev/null +++ b/test/bats/20_hub.bats @@ -0,0 +1,166 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + ./instance-data load + INDEX_PATH=$(config_get '.config_paths.index_path') + export INDEX_PATH + CONFIG_DIR=$(config_get '.config_paths.config_dir') + export CONFIG_DIR +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + hub_strip_index +} + +teardown() { + : +} + +#---------- + +@test "cscli hub list" { + hub_purge_all + + # no items + rune -0 cscli hub list + assert_output "No items to display" + rune -0 cscli hub list -o json + assert_json '{"appsec-configs":[],"appsec-rules":[],parsers:[],scenarios:[],collections:[],contexts:[],postoverflows:[]}' + rune -0 cscli hub list -o raw + assert_output 'name,status,version,description,type' + + # some items: with output=human, show only non-empty tables + rune -0 cscli parsers install crowdsecurity/whitelists + rune -0 cscli scenarios install crowdsecurity/telnet-bf + rune -0 cscli hub list + assert_output --regexp 
".*PARSERS.*crowdsecurity/whitelists.*SCENARIOS.*crowdsecurity/telnet-bf.*" + refute_output --partial 'POSTOVERFLOWS' + refute_output --partial 'COLLECTIONS' + + rune -0 cscli hub list -o json + rune -0 jq -e '(.parsers | length == 1) and (.scenarios | length == 1)' <(output) + rune -0 cscli hub list -o raw + assert_output --partial 'crowdsecurity/whitelists' + assert_output --partial 'crowdsecurity/telnet-bf' + refute_output --partial 'crowdsecurity/iptables' + + # all items + mkdir -p "$CONFIG_DIR/contexts" + # there are no contexts yet, so we create a local one + touch "$CONFIG_DIR/contexts/mycontext.yaml" + rune -0 cscli hub list -a + assert_output --regexp ".*PARSERS.*crowdsecurity/whitelists.*POSTOVERFLOWS.*SCENARIOS.*crowdsecurity/telnet-bf.*CONTEXTS.*mycontext.yaml.*COLLECTIONS.*crowdsecurity/iptables.*" + rune -0 cscli hub list -a -o json + rune -0 jq -e '(.parsers | length > 1) and (.scenarios | length > 1)' <(output) + rune -0 cscli hub list -a -o raw + assert_output --partial 'crowdsecurity/whitelists' + assert_output --partial 'crowdsecurity/telnet-bf' + assert_output --partial 'crowdsecurity/iptables' +} + +@test "cscli hub list (invalid index)" { + new_hub=$(jq <"$INDEX_PATH" '."appsec-rules"."crowdsecurity/vpatch-laravel-debug-mode".version="999"') + echo "$new_hub" >"$INDEX_PATH" + rune -0 cscli hub list --error + assert_stderr --partial "invalid hub item appsec-rules:crowdsecurity/vpatch-laravel-debug-mode: latest version missing from index" + + rune -1 cscli appsec-rules install crowdsecurity/vpatch-laravel-debug-mode --force + assert_stderr --partial "error while installing 'crowdsecurity/vpatch-laravel-debug-mode': latest hash missing from index. The index file is invalid, please run 'cscli hub update' and try again" +} + +@test "missing reference in hub index" { + new_hub=$(jq <"$INDEX_PATH" 'del(.parsers."crowdsecurity/smb-logs") | del (.scenarios."crowdsecurity/mysql-bf")') + echo "$new_hub" >"$INDEX_PATH" + rune -0 cscli hub list --error + assert_stderr --partial "can't find crowdsecurity/smb-logs in parsers, required by crowdsecurity/smb" + assert_stderr --partial "can't find crowdsecurity/mysql-bf in scenarios, required by crowdsecurity/mysql" +} + +@test "loading hub reports tainted items (subitem is tainted)" { + rune -0 cscli collections install crowdsecurity/sshd + rune -0 cscli hub list + refute_stderr --partial "tainted" + rune -0 truncate -s0 "$CONFIG_DIR/parsers/s01-parse/sshd-logs.yaml" + rune -0 cscli hub list + assert_stderr --partial "crowdsecurity/sshd is tainted by parsers:crowdsecurity/sshd-logs" +} + +@test "loading hub reports tainted items (subitem is not installed)" { + rune -0 cscli collections install crowdsecurity/sshd + rune -0 cscli hub list + refute_stderr --partial "tainted" + rune -0 rm "$CONFIG_DIR/parsers/s01-parse/sshd-logs.yaml" + rune -0 cscli hub list + assert_stderr --partial "crowdsecurity/sshd is tainted by missing parsers:crowdsecurity/sshd-logs" +} + +@test "cscli hub update" { + rm -f "$INDEX_PATH" + rune -0 cscli hub update + assert_stderr --partial "Wrote index to $INDEX_PATH" + rune -0 cscli hub update + assert_stderr --partial "hub index is up to date" +} + +@test "cscli hub upgrade" { + rune -0 cscli hub upgrade + assert_stderr --partial "Upgrading parsers" + assert_stderr --partial "Upgraded 0 parsers" + assert_stderr --partial "Upgrading postoverflows" + assert_stderr --partial "Upgraded 0 postoverflows" + assert_stderr --partial "Upgrading scenarios" + assert_stderr --partial "Upgraded 0 scenarios" + assert_stderr 
--partial "Upgrading contexts" + assert_stderr --partial "Upgraded 0 contexts" + assert_stderr --partial "Upgrading collections" + assert_stderr --partial "Upgraded 0 collections" + assert_stderr --partial "Upgrading appsec-configs" + assert_stderr --partial "Upgraded 0 appsec-configs" + assert_stderr --partial "Upgrading appsec-rules" + assert_stderr --partial "Upgraded 0 appsec-rules" + assert_stderr --partial "Upgrading collections" + assert_stderr --partial "Upgraded 0 collections" + + rune -0 cscli parsers install crowdsecurity/syslog-logs + rune -0 cscli hub upgrade + assert_stderr --partial "crowdsecurity/syslog-logs: up-to-date" + + rune -0 cscli hub upgrade --force + assert_stderr --partial "crowdsecurity/syslog-logs: up-to-date" + assert_stderr --partial "crowdsecurity/syslog-logs: updated" + assert_stderr --partial "Upgraded 1 parsers" + # this is used by the cron script to know if the hub was updated + assert_output --partial "updated crowdsecurity/syslog-logs" +} + +@test "cscli hub upgrade (with local items)" { + mkdir -p "$CONFIG_DIR/collections" + touch "$CONFIG_DIR/collections/foo.yaml" + rune -0 cscli hub upgrade + assert_stderr --partial "not upgrading foo.yaml: local item" +} + +@test "cscli hub types" { + rune -0 cscli hub types -o raw + assert_line "parsers" + assert_line "postoverflows" + assert_line "scenarios" + assert_line "contexts" + assert_line "collections" + rune -0 cscli hub types -o human + rune -0 yq -o json <(output) + assert_json '["parsers","postoverflows","scenarios","contexts","appsec-configs","appsec-rules","collections"]' + rune -0 cscli hub types -o json + assert_json '["parsers","postoverflows","scenarios","contexts","appsec-configs","appsec-rules","collections"]' +} diff --git a/test/bats/20_hub_collections.bats b/test/bats/20_hub_collections.bats new file mode 100644 index 00000000000..6822339ae40 --- /dev/null +++ b/test/bats/20_hub_collections.bats @@ -0,0 +1,381 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + ./instance-data load + HUB_DIR=$(config_get '.config_paths.hub_dir') + export HUB_DIR + INDEX_PATH=$(config_get '.config_paths.index_path') + export INDEX_PATH + CONFIG_DIR=$(config_get '.config_paths.config_dir') + export CONFIG_DIR +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + hub_strip_index +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "cscli collections list" { + hub_purge_all + + # no items + rune -0 cscli collections list + assert_output --partial "COLLECTIONS" + rune -0 cscli collections list -o json + assert_json '{collections:[]}' + rune -0 cscli collections list -o raw + assert_output 'name,status,version,description' + + # some items + rune -0 cscli collections install crowdsecurity/sshd crowdsecurity/smb + + rune -0 cscli collections list + assert_output --partial crowdsecurity/sshd + assert_output --partial crowdsecurity/smb + rune -0 grep -c enabled <(output) + assert_output "2" + + rune -0 cscli collections list -o json + assert_output --partial crowdsecurity/sshd + assert_output --partial crowdsecurity/smb + rune -0 jq '.collections | length' <(output) + assert_output "2" + + rune -0 cscli collections list -o raw + assert_output --partial crowdsecurity/sshd + assert_output --partial crowdsecurity/smb + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "2" +} + +@test "cscli 
collections list -a" { + expected=$(jq <"$INDEX_PATH" -r '.collections | length') + + rune -0 cscli collections list -a + rune -0 grep -c disabled <(output) + assert_output "$expected" + + rune -0 cscli collections list -o json -a + rune -0 jq '.collections | length' <(output) + assert_output "$expected" + + rune -0 cscli collections list -o raw -a + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "$expected" + + # the list should be the same in all formats, and sorted (not case sensitive) + + list_raw=$(cscli collections list -o raw -a | tail -n +2 | cut -d, -f1) + list_human=$(cscli collections list -o human -a | tail -n +6 | head -n -1 | cut -d' ' -f2) + list_json=$(cscli collections list -o json -a | jq -r '.collections[].name') + + rune -0 sort -f <<<"$list_raw" + assert_output "$list_raw" + + assert_equal "$list_raw" "$list_json" + assert_equal "$list_raw" "$list_human" +} + +@test "cscli collections list [collection]..." { + # non-existent + rune -1 cscli collections install foo/bar + assert_stderr --partial "can't find 'foo/bar' in collections" + + # not installed + rune -0 cscli collections list crowdsecurity/smb + assert_output --regexp 'crowdsecurity/smb.*disabled' + + # install two items + rune -0 cscli collections install crowdsecurity/sshd crowdsecurity/smb + + # list an installed item + rune -0 cscli collections list crowdsecurity/sshd + assert_output --regexp "crowdsecurity/sshd" + refute_output --partial "crowdsecurity/smb" + + # list multiple installed and non installed items + rune -0 cscli collections list crowdsecurity/sshd crowdsecurity/smb crowdsecurity/nginx + assert_output --partial "crowdsecurity/sshd" + assert_output --partial "crowdsecurity/smb" + assert_output --partial "crowdsecurity/nginx" + + rune -0 cscli collections list crowdsecurity/sshd -o json + rune -0 jq '.collections | length' <(output) + assert_output "1" + rune -0 cscli collections list crowdsecurity/sshd crowdsecurity/smb crowdsecurity/nginx -o json + rune -0 jq '.collections | length' <(output) + assert_output "3" + + rune -0 cscli collections list crowdsecurity/sshd -o raw + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "1" + rune -0 cscli collections list crowdsecurity/sshd crowdsecurity/smb -o raw + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "2" +} + +@test "cscli collections install" { + rune -1 cscli collections install + assert_stderr --partial 'requires at least 1 arg(s), only received 0' + + # not in hub + rune -1 cscli collections install crowdsecurity/blahblah + assert_stderr --partial "can't find 'crowdsecurity/blahblah' in collections" + + # simple install + rune -0 cscli collections install crowdsecurity/sshd + rune -0 cscli collections inspect crowdsecurity/sshd --no-metrics + assert_output --partial 'crowdsecurity/sshd' + assert_output --partial 'installed: true' + + # autocorrect + rune -1 cscli collections install crowdsecurity/ssshd + assert_stderr --partial "can't find 'crowdsecurity/ssshd' in collections, did you mean 'crowdsecurity/sshd'?" 
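# A sketch of the assert_json helper seen in these tests (the real one is
# in the test lib and may differ): canonicalize both sides with jq, using
# sorted keys and compact output, so key order and whitespace don't matter.
# Expected values such as '{collections:[]}' are valid jq programs, hence
# jq -n on that side; an argument of "-" reads the expected document from
# stdin, matching the heredoc form used elsewhere.
assert_json() {
    local expected="$1"
    [[ "$expected" == "-" ]] && expected=$(cat)
    assert_equal "$(jq -cnS "$expected")" "$(jq -cS . <<<"$output")"
}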
+ + # install multiple + rune -0 cscli collections install crowdsecurity/sshd crowdsecurity/smb + rune -0 cscli collections inspect crowdsecurity/sshd --no-metrics + assert_output --partial 'crowdsecurity/sshd' + assert_output --partial 'installed: true' + rune -0 cscli collections inspect crowdsecurity/smb --no-metrics + assert_output --partial 'crowdsecurity/smb' + assert_output --partial 'installed: true' +} + +@test "cscli collections install (file location and download-only)" { + rune -0 cscli collections install crowdsecurity/linux --download-only + rune -0 cscli collections inspect crowdsecurity/linux --no-metrics + assert_output --partial 'crowdsecurity/linux' + assert_output --partial 'installed: false' + assert_file_exists "$HUB_DIR/collections/crowdsecurity/linux.yaml" + assert_file_not_exists "$CONFIG_DIR/collections/linux.yaml" + + rune -0 cscli collections install crowdsecurity/linux + rune -0 cscli collections inspect crowdsecurity/linux --no-metrics + assert_output --partial 'installed: true' + assert_file_exists "$CONFIG_DIR/collections/linux.yaml" +} + +@test "cscli collections install --force (tainted)" { + rune -0 cscli collections install crowdsecurity/sshd + echo "dirty" >"$CONFIG_DIR/collections/sshd.yaml" + + rune -1 cscli collections install crowdsecurity/sshd + assert_stderr --partial "error while installing 'crowdsecurity/sshd': while enabling crowdsecurity/sshd: crowdsecurity/sshd is tainted, won't overwrite unless --force" + + rune -0 cscli collections install crowdsecurity/sshd --force + assert_stderr --partial "Enabled crowdsecurity/sshd" +} + +@test "cscli collections install --ignore (skip on errors)" { + rune -1 cscli collections install foo/bar crowdsecurity/sshd + assert_stderr --partial "can't find 'foo/bar' in collections" + refute_stderr --partial "Enabled collections: crowdsecurity/sshd" + + rune -0 cscli collections install foo/bar crowdsecurity/sshd --ignore + assert_stderr --partial "can't find 'foo/bar' in collections" + assert_stderr --partial "Enabled collections: crowdsecurity/sshd" +} + +@test "cscli collections inspect" { + rune -1 cscli collections inspect + assert_stderr --partial 'requires at least 1 arg(s), only received 0' + # required for metrics + ./instance-crowdsec start + + rune -1 cscli collections inspect blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in collections" + + # one item + rune -0 cscli collections inspect crowdsecurity/sshd --no-metrics + assert_line 'type: collections' + assert_line 'name: crowdsecurity/sshd' + assert_line 'author: crowdsecurity' + assert_line 'path: collections/crowdsecurity/sshd.yaml' + assert_line 'installed: false' + refute_line --partial 'Current metrics:' + + # one item, with metrics + rune -0 cscli collections inspect crowdsecurity/sshd + assert_line --partial 'Current metrics:' + + # one item, json + rune -0 cscli collections inspect crowdsecurity/sshd -o json + rune -0 jq -c '[.type, .name, .author, .path, .installed]' <(output) + assert_json '["collections","crowdsecurity/sshd","crowdsecurity","collections/crowdsecurity/sshd.yaml",false]' + + # one item, raw + rune -0 cscli collections inspect crowdsecurity/sshd -o raw + assert_line 'type: collections' + assert_line 'name: crowdsecurity/sshd' + assert_line 'author: crowdsecurity' + assert_line 'path: collections/crowdsecurity/sshd.yaml' + assert_line 'installed: false' + refute_line --partial 'Current metrics:' + + # multiple items + rune -0 cscli collections inspect crowdsecurity/sshd crowdsecurity/smb 
--no-metrics + assert_output --partial 'crowdsecurity/sshd' + assert_output --partial 'crowdsecurity/smb' + rune -1 grep -c 'Current metrics:' <(output) + assert_output "0" + + # multiple items, with metrics + rune -0 cscli collections inspect crowdsecurity/sshd crowdsecurity/smb + rune -0 grep -c 'Current metrics:' <(output) + assert_output "2" + + # multiple items, json + rune -0 cscli collections inspect crowdsecurity/sshd crowdsecurity/smb -o json + rune -0 jq -sc '[.[] | [.type, .name, .author, .path, .installed]]' <(output) + assert_json '[["collections","crowdsecurity/sshd","crowdsecurity","collections/crowdsecurity/sshd.yaml",false],["collections","crowdsecurity/smb","crowdsecurity","collections/crowdsecurity/smb.yaml",false]]' + + # multiple items, raw + rune -0 cscli collections inspect crowdsecurity/sshd crowdsecurity/smb -o raw + assert_output --partial 'crowdsecurity/sshd' + assert_output --partial 'crowdsecurity/smb' + rune -1 grep -c 'Current metrics:' <(output) + assert_output "0" +} + +@test "cscli collections remove" { + rune -1 cscli collections remove + assert_stderr --partial "specify at least one collection to remove or '--all'" + rune -1 cscli collections remove blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in collections" + + rune -0 cscli collections install crowdsecurity/sshd --download-only + rune -0 cscli collections remove crowdsecurity/sshd + assert_stderr --partial 'removing crowdsecurity/sshd: not installed -- no need to remove' + + rune -0 cscli collections install crowdsecurity/sshd + rune -0 cscli collections remove crowdsecurity/sshd + assert_stderr --partial 'Removed crowdsecurity/sshd' + + rune -0 cscli collections remove crowdsecurity/sshd --purge + assert_stderr --partial 'Removed source file [crowdsecurity/sshd]' + + rune -0 cscli collections remove crowdsecurity/sshd + assert_stderr --partial 'removing crowdsecurity/sshd: not installed -- no need to remove' + + rune -0 cscli collections remove crowdsecurity/sshd --purge --debug + assert_stderr --partial 'removing crowdsecurity/sshd: not downloaded -- no need to remove' + refute_stderr --partial 'Removed source file [crowdsecurity/sshd]' + + # install, then remove, check files + rune -0 cscli collections install crowdsecurity/sshd + assert_file_exists "$CONFIG_DIR/collections/sshd.yaml" + rune -0 cscli collections remove crowdsecurity/sshd + assert_file_not_exists "$CONFIG_DIR/collections/sshd.yaml" + + # delete is an alias for remove + rune -0 cscli collections install crowdsecurity/sshd + assert_file_exists "$CONFIG_DIR/collections/sshd.yaml" + rune -0 cscli collections delete crowdsecurity/sshd + assert_file_not_exists "$CONFIG_DIR/collections/sshd.yaml" + + # purge + assert_file_exists "$HUB_DIR/collections/crowdsecurity/sshd.yaml" + rune -0 cscli collections remove crowdsecurity/sshd --purge + assert_file_not_exists "$HUB_DIR/collections/crowdsecurity/sshd.yaml" + + rune -0 cscli collections install crowdsecurity/sshd crowdsecurity/smb + + # --all + rune -0 cscli collections list -o raw + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "2" + + rune -0 cscli collections remove --all + + rune -0 cscli collections list -o raw + rune -1 grep -vc 'name,status,version,description' <(output) + assert_output "0" +} + +@test "cscli collections remove --force" { + # remove a collection that belongs to another collection + rune -0 cscli collections install crowdsecurity/linux + rune -0 cscli collections remove crowdsecurity/sshd + assert_stderr
--partial "crowdsecurity/sshd belongs to collections: [crowdsecurity/linux]" + assert_stderr --partial "Run 'sudo cscli collections remove crowdsecurity/sshd --force' if you want to force remove this collection" +} + +@test "cscli collections upgrade" { + rune -1 cscli collections upgrade + assert_stderr --partial "specify at least one collection to upgrade or '--all'" + rune -1 cscli collections upgrade blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in collections" + rune -0 cscli collections remove crowdsecurity/exim --purge + rune -1 cscli collections upgrade crowdsecurity/exim + assert_stderr --partial "can't upgrade crowdsecurity/exim: not installed" + rune -0 cscli collections install crowdsecurity/exim --download-only + rune -1 cscli collections upgrade crowdsecurity/exim + assert_stderr --partial "can't upgrade crowdsecurity/exim: downloaded but not installed" + + # hash of the string "v0.0" + sha256_0_0="dfebecf42784a31aa3d009dbcec0c657154a034b45f49cf22a895373f6dbf63d" + + # add version 0.0 to all collections + new_hub=$(jq --arg DIGEST "$sha256_0_0" <"$INDEX_PATH" '.collections |= with_entries(.value.versions["0.0"] = {"digest": $DIGEST, "deprecated": false})') + echo "$new_hub" >"$INDEX_PATH" + + rune -0 cscli collections install crowdsecurity/sshd + + echo "v0.0" > "$CONFIG_DIR/collections/sshd.yaml" + rune -0 cscli collections inspect crowdsecurity/sshd -o json + rune -0 jq -e '.local_version=="0.0"' <(output) + + # upgrade + rune -0 cscli collections upgrade crowdsecurity/sshd + rune -0 cscli collections inspect crowdsecurity/sshd -o json + rune -0 jq -e '.local_version==.version' <(output) + + # taint + echo "dirty" >"$CONFIG_DIR/collections/sshd.yaml" + # XXX: should return error + rune -0 cscli collections upgrade crowdsecurity/sshd + assert_stderr --partial "crowdsecurity/sshd is tainted, --force to overwrite" + rune -0 cscli collections inspect crowdsecurity/sshd -o json + rune -0 jq -e '.local_version=="?"' <(output) + + # force upgrade with taint + rune -0 cscli collections upgrade crowdsecurity/sshd --force + rune -0 cscli collections inspect crowdsecurity/sshd -o json + rune -0 jq -e '.local_version==.version' <(output) + + # multiple items + rune -0 cscli collections install crowdsecurity/smb + echo "v0.0" >"$CONFIG_DIR/collections/sshd.yaml" + echo "v0.0" >"$CONFIG_DIR/collections/smb.yaml" + rune -0 cscli collections list -o json + rune -0 jq -e '[.collections[].local_version]==["0.0","0.0"]' <(output) + rune -0 cscli collections upgrade crowdsecurity/sshd crowdsecurity/smb + rune -0 cscli collections list -o json + rune -0 jq -e 'any(.collections[].local_version; .=="0.0") | not' <(output) + + # upgrade all + echo "v0.0" >"$CONFIG_DIR/collections/sshd.yaml" + echo "v0.0" >"$CONFIG_DIR/collections/smb.yaml" + rune -0 cscli collections list -o json + rune -0 jq -e '[.collections[].local_version]==["0.0","0.0"]' <(output) + rune -0 cscli collections upgrade --all + rune -0 cscli collections list -o json + rune -0 jq -e 'any(.collections[].local_version; .=="0.0") | not' <(output) +} diff --git a/test/bats/20_hub_collections_dep.bats b/test/bats/20_hub_collections_dep.bats new file mode 100644 index 00000000000..673b812dc0d --- /dev/null +++ b/test/bats/20_hub_collections_dep.bats @@ -0,0 +1,126 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + ./instance-data load + INDEX_PATH=$(config_get '.config_paths.index_path') + export INDEX_PATH + 
CONFIG_DIR=$(config_get '.config_paths.config_dir') + export CONFIG_DIR +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + hub_strip_index +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "cscli collections (dependencies)" { + # inject a dependency: smb requires sshd + hub_dep=$(jq <"$INDEX_PATH" '. * {collections:{"crowdsecurity/smb":{collections:["crowdsecurity/sshd"]}}}') + echo "$hub_dep" >"$INDEX_PATH" + + # verify that installing smb brings sshd + rune -0 cscli collections install crowdsecurity/smb + rune -0 cscli collections list -o json + rune -0 jq -e '[.collections[].name]==["crowdsecurity/smb","crowdsecurity/sshd"]' <(output) + + # verify that removing smb removes sshd too + rune -0 cscli collections remove crowdsecurity/smb + rune -0 cscli collections list -o json + rune -0 jq -e '.collections | length == 0' <(output) + + # we can't remove sshd without --force + rune -0 cscli collections install crowdsecurity/smb + # XXX: should this be an error? + rune -0 cscli collections remove crowdsecurity/sshd + assert_stderr --partial "crowdsecurity/sshd belongs to collections: [crowdsecurity/smb]" + assert_stderr --partial "Run 'sudo cscli collections remove crowdsecurity/sshd --force' if you want to force remove this collection" + rune -0 cscli collections list -o json + rune -0 jq -c '[.collections[].name]' <(output) + assert_json '["crowdsecurity/smb","crowdsecurity/sshd"]' + + # use the --force + rune -0 cscli collections remove crowdsecurity/sshd --force + rune -0 cscli collections list -o json + rune -0 jq -c '[.collections[].name]' <(output) + assert_json '["crowdsecurity/smb"]' + + # and now smb is tainted! + rune -0 cscli collections inspect crowdsecurity/smb -o json + rune -0 jq -e '.tainted==true' <(output) + rune -0 cscli collections remove crowdsecurity/smb --force + + # empty + rune -0 cscli collections list -o json + rune -0 jq -e '.collections | length == 0' <(output) + + # reinstall + rune -0 cscli collections install crowdsecurity/smb --force + + # taint on sshd means smb is tainted as well + rune -0 cscli collections inspect crowdsecurity/smb -o json + rune -0 jq -e '.tainted==false' <(output) + echo "dirty" >"$CONFIG_DIR/collections/sshd.yaml" + rune -0 cscli collections inspect crowdsecurity/smb -o json + rune -0 jq -e '.tainted==true' <(output) + + # now we can't remove smb without --force + rune -1 cscli collections remove crowdsecurity/smb + assert_stderr --partial "crowdsecurity/smb is tainted, use '--force' to remove" +} + +@test "cscli collections (dependencies II: the revenge)" { + rune -0 cscli collections install crowdsecurity/wireguard baudneo/gotify + rune -0 cscli collections remove crowdsecurity/wireguard + assert_stderr --partial "crowdsecurity/syslog-logs was not removed because it also belongs to baudneo/gotify" + rune -0 cscli collections inspect crowdsecurity/wireguard -o json + rune -0 jq -e '.installed==false' <(output) +} + +@test "cscli collections (dependencies III: origins)" { + # it is perfectly fine to remove an item belonging to a collection that we are removing anyway + + # inject a dependency: sshd requires the syslog-logs parsers, but linux does too + hub_dep=$(jq <"$INDEX_PATH" '. 
* {collections:{"crowdsecurity/sshd":{parsers:["crowdsecurity/syslog-logs"]}}}') + echo "$hub_dep" >"$INDEX_PATH" + + # verify that installing sshd brings syslog-logs + rune -0 cscli collections install crowdsecurity/sshd + rune -0 cscli parsers inspect crowdsecurity/syslog-logs -o json + rune -0 jq -e '.installed==true' <(output) + + rune -0 cscli collections install crowdsecurity/linux + + # removing linux should remove syslog-logs even though sshd depends on it + rune -0 cscli collections remove crowdsecurity/linux + refute_stderr --partial "crowdsecurity/syslog-logs was not removed" + # we must also consider indirect dependencies + refute_stderr --partial "crowdsecurity/ssh-bf was not removed" + rune -0 cscli parsers list -o json + rune -0 jq -e '.parsers | length == 0' <(output) +} + +@test "cscli collections (dependencies IV: looper)" { + hub_dep=$(jq <"$INDEX_PATH" '. * {collections:{"crowdsecurity/sshd":{collections:["crowdsecurity/linux"]}}}') + echo "$hub_dep" >"$INDEX_PATH" + + rune -1 cscli hub list + assert_stderr --partial "circular dependency detected" + rune -1 wait-for "$CROWDSEC" + assert_stderr --partial "circular dependency detected" +} diff --git a/test/bats/20_hub_items.bats b/test/bats/20_hub_items.bats new file mode 100644 index 00000000000..4b390c90ed4 --- /dev/null +++ b/test/bats/20_hub_items.bats @@ -0,0 +1,282 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + ./instance-data load + HUB_DIR=$(config_get '.config_paths.hub_dir') + export HUB_DIR + INDEX_PATH=$(config_get '.config_paths.index_path') + export INDEX_PATH + CONFIG_DIR=$(config_get '.config_paths.config_dir') + export CONFIG_DIR +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + hub_strip_index +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- +# +# Tests that don't need to be repeated for each hub type +# + +@test "hub versions are correctly sorted during sync" { + # hash of an empty file + sha256_empty="e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + + # add two versions with the same hash, that don't sort the same way + # in a lexical vs semver sort. CrowdSec should report the latest version + + new_hub=$( \ + jq --arg DIGEST "$sha256_empty" <"$INDEX_PATH" \ + '. * {collections:{"crowdsecurity/sshd":{"versions":{"1.2":{"digest":$DIGEST, "deprecated": false}, "1.10": {"digest":$DIGEST, "deprecated": false}}}}}' \ + ) + echo "$new_hub" >"$INDEX_PATH" + + rune -0 cscli collections install crowdsecurity/sshd + + truncate -s 0 "$CONFIG_DIR/collections/sshd.yaml" + + rune -0 cscli collections inspect crowdsecurity/sshd -o json + # XXX: is this supposed to be tainted or up to date? + rune -0 jq -c '[.local_version,.up_to_date,.tainted]' <(output) + assert_json '["1.10",false,false]' +} + +@test "do not unmarshal state attributes" { + new_hub=$( \ + jq <"$INDEX_PATH" \ + '. * {parsers:{"crowdsecurity/syslog-logs":{"tainted":true, "installed":true, "local":true}}}' + ) + echo "$new_hub" >"$INDEX_PATH" + + rune -0 cscli parsers inspect crowdsecurity/syslog-logs --no-metrics + assert_output --partial 'tainted: false' + assert_output --partial 'installed: false' + assert_output --partial 'local: false' +} + +@test "hub index with invalid (non semver) version numbers" { + rune -0 cscli collections remove crowdsecurity/sshd --purge + + new_hub=$( \ + jq <"$INDEX_PATH" \ + '. 
* {collections:{"crowdsecurity/sshd":{"versions":{"1.2.3.4":{"digest":"foo", "deprecated": false}}}}}' \ + ) + echo "$new_hub" >"$INDEX_PATH" + + rune -0 cscli collections install crowdsecurity/sshd + rune -1 cscli collections inspect crowdsecurity/sshd --no-metrics -o json + # XXX: we are on the verbose side here... + rune -0 jq -r ".msg" <(stderr) + assert_output --regexp "failed to read Hub index: failed to sync hub items: failed to scan .*: while syncing collections sshd.yaml: 1.2.3.4: Invalid Semantic Version. Run 'sudo cscli hub update' to download the index again" +} + +@test "removing or purging an item already removed by hand" { + rune -0 cscli parsers install crowdsecurity/syslog-logs + rune -0 cscli parsers inspect crowdsecurity/syslog-logs -o json + rune -0 jq -r '.local_path' <(output) + rune -0 rm "$(output)" + + rune -0 cscli parsers remove crowdsecurity/syslog-logs --debug + assert_stderr --partial "removing crowdsecurity/syslog-logs: not installed -- no need to remove" + + rune -0 cscli parsers inspect crowdsecurity/syslog-logs -o json + rune -0 jq -r '.path' <(output) + rune -0 rm "$HUB_DIR/$(output)" + + rune -0 cscli parsers remove crowdsecurity/syslog-logs --purge --debug + assert_stderr --partial "removing crowdsecurity/syslog-logs: not downloaded -- no need to remove" + + rune -0 cscli parsers remove crowdsecurity/linux --all --error --purge --force + rune -0 cscli collections remove crowdsecurity/linux --all --error --purge --force + refute_output + refute_stderr +} + +@test "a local item is not tainted" { + # not from cscli... inspect + rune -0 mkdir -p "$CONFIG_DIR/collections" + rune -0 touch "$CONFIG_DIR/collections/foobar.yaml" + rune -0 cscli collections inspect foobar.yaml -o json + rune -0 jq -e '[.tainted,.local==false,true]' <(output) + + rune -0 cscli collections install crowdsecurity/sshd + rune -0 truncate -s0 "$CONFIG_DIR/collections/sshd.yaml" + rune -0 cscli collections inspect crowdsecurity/sshd -o json + rune -0 jq -e '[.tainted,.local==true,false]' <(output) + + # and not from hub update + rune -0 cscli hub update + assert_stderr --partial "collection crowdsecurity/sshd is tainted" + refute_stderr --partial "collection foobar.yaml is tainted" +} + +@test "a local item's name defaults to its filename" { + rune -0 mkdir -p "$CONFIG_DIR/collections" + rune -0 touch "$CONFIG_DIR/collections/foobar.yaml" + rune -0 cscli collections list -o json + rune -0 jq -r '.[][].name' <(output) + assert_output "foobar.yaml" + rune -0 cscli collections list foobar.yaml + rune -0 cscli collections inspect foobar.yaml -o json + rune -0 jq -e '[.installed,.local==true,true]' <(output) +} + +@test "a local item can provide its own name" { + rune -0 mkdir -p "$CONFIG_DIR/collections" + echo "name: hi-its-me" > "$CONFIG_DIR/collections/foobar.yaml" + rune -0 cscli collections list -o json + rune -0 jq -r '.[][].name' <(output) + assert_output "hi-its-me" + rune -0 cscli collections list hi-its-me + rune -0 cscli collections inspect hi-its-me -o json + rune -0 jq -e '[.installed,.local]==[true,true]' <(output) +} + +@test "a local item cannot be downloaded by cscli" { + rune -0 mkdir -p "$CONFIG_DIR/collections" + rune -0 touch "$CONFIG_DIR/collections/foobar.yaml" + rune -1 cscli collections install foobar.yaml + assert_stderr --partial "foobar.yaml is local, can't download" + rune -1 cscli collections install foobar.yaml --force + assert_stderr --partial "foobar.yaml is local, can't download" +} + +@test "a local item cannot be removed by cscli" { + rune -0 mkdir -p 
"$CONFIG_DIR/collections" + rune -0 touch "$CONFIG_DIR/collections/foobar.yaml" + rune -0 cscli collections remove foobar.yaml + assert_stderr --partial "foobar.yaml is a local item, please delete manually" + rune -0 cscli collections remove foobar.yaml --purge + assert_stderr --partial "foobar.yaml is a local item, please delete manually" + rune -0 cscli collections remove foobar.yaml --force + assert_stderr --partial "foobar.yaml is a local item, please delete manually" + rune -0 cscli collections remove --all + assert_stderr --partial "foobar.yaml is a local item, please delete manually" + rune -0 cscli collections remove --all --purge + assert_stderr --partial "foobar.yaml is a local item, please delete manually" +} + +@test "a dangling link is reported with a warning" { + rune -0 mkdir -p "$CONFIG_DIR/collections" + rune -0 ln -s /this/does/not/exist.yaml "$CONFIG_DIR/collections/foobar.yaml" + rune -0 cscli hub list + assert_stderr --partial "Ignoring file $CONFIG_DIR/collections/foobar.yaml: lstat /this/does/not/exist.yaml: no such file or directory" + rune -0 cscli hub list -o json + rune -0 jq '.collections' <(output) + assert_json '[]' +} + +@test "tainted hub file, not enabled, install --force should repair" { + rune -0 cscli scenarios install crowdsecurity/ssh-bf + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf -o json + local_path="$(jq -r '.local_path' <(output))" + echo >> "$local_path" + rm "$local_path" + rune -0 cscli scenarios install crowdsecurity/ssh-bf --force + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf -o json + rune -0 jq -c '.tainted' <(output) + assert_output 'false' +} + +@test "don't traverse hidden directories (starting with a dot)" { + rune -0 mkdir -p "$CONFIG_DIR/scenarios/.foo" + rune -0 touch "$CONFIG_DIR/scenarios/.foo/bar.yaml" + rune -0 cscli hub list --trace + assert_stderr --partial "skipping hidden directory $CONFIG_DIR/scenarios/.foo" +} + +@test "allow symlink to target inside a hidden directory" { + # k8s config maps use hidden directories and links when mounted + rune -0 mkdir -p "$CONFIG_DIR/scenarios/.foo" + + # ignored + rune -0 touch "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 0 + + # real file + rune -0 touch "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 1 + + rune -0 rm "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 0 + + # link to ignored is not ignored, and the name comes from the link + rune -0 ln -s "$CONFIG_DIR/scenarios/.foo/hidden.yaml" "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq -c '[.scenarios[].name] | sort' <(output) + assert_json '["myfoo.yaml"]' +} + +@test "item files can be links to links" { + rune -0 mkdir -p "$CONFIG_DIR"/scenarios/{.foo,.bar} + + rune -0 ln -s "$CONFIG_DIR/scenarios/.foo/hidden.yaml" "$CONFIG_DIR/scenarios/.bar/hidden.yaml" + + # link to a danling link + rune -0 ln -s "$CONFIG_DIR/scenarios/.bar/hidden.yaml" "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list + assert_stderr --partial "Ignoring file $CONFIG_DIR/scenarios/myfoo.yaml: lstat $CONFIG_DIR/scenarios/.foo/hidden.yaml: no such file or directory" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 0 + + # detect link loops + rune -0 ln -s 
"$CONFIG_DIR/scenarios/.bar/hidden.yaml" "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 cscli scenarios list + assert_stderr --partial "Ignoring file $CONFIG_DIR/scenarios/myfoo.yaml: too many levels of symbolic links" + + rune -0 rm "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 touch "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 1 +} + +@test "item files can be in a subdirectory" { + rune -0 mkdir -p "$CONFIG_DIR/scenarios/sub/sub2/sub3" + rune -0 touch "$CONFIG_DIR/scenarios/sub/imlocal.yaml" + # subdir name is now part of the item name + rune -0 cscli scenarios inspect sub/imlocal.yaml -o json + rune -0 jq -e '[.tainted,.local==false,true]' <(output) + rune -0 rm "$CONFIG_DIR/scenarios/sub/imlocal.yaml" + + rune -0 ln -s "$HUB_DIR/scenarios/crowdsecurity/smb-bf.yaml" "$CONFIG_DIR/scenarios/sub/smb-bf.yaml" + rune -0 cscli scenarios inspect crowdsecurity/smb-bf -o json + rune -0 jq -e '[.tainted,.local==false,false]' <(output) + rune -0 rm "$CONFIG_DIR/scenarios/sub/smb-bf.yaml" + + rune -0 ln -s "$HUB_DIR/scenarios/crowdsecurity/smb-bf.yaml" "$CONFIG_DIR/scenarios/sub/sub2/sub3/smb-bf.yaml" + rune -0 cscli scenarios inspect crowdsecurity/smb-bf -o json + rune -0 jq -e '[.tainted,.local==false,false]' <(output) +} + +@test "same file name for local items in different subdirectories" { + rune -0 mkdir -p "$CONFIG_DIR"/scenarios/{foo,bar} + rune -0 touch "$CONFIG_DIR/scenarios/foo/local.yaml" + rune -0 touch "$CONFIG_DIR/scenarios/bar/local.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq -c '[.scenarios[].name] | sort' <(output) + assert_json '["bar/local.yaml","foo/local.yaml"]' +} diff --git a/test/bats/20_hub_parsers.bats b/test/bats/20_hub_parsers.bats new file mode 100644 index 00000000000..791b1a2177f --- /dev/null +++ b/test/bats/20_hub_parsers.bats @@ -0,0 +1,383 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + ./instance-data load + HUB_DIR=$(config_get '.config_paths.hub_dir') + export HUB_DIR + INDEX_PATH=$(config_get '.config_paths.index_path') + export INDEX_PATH + CONFIG_DIR=$(config_get '.config_paths.config_dir') + export CONFIG_DIR +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + hub_strip_index +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "cscli parsers list" { + hub_purge_all + + # no items + rune -0 cscli parsers list + assert_output --partial "PARSERS" + rune -0 cscli parsers list -o json + assert_json '{parsers:[]}' + rune -0 cscli parsers list -o raw + assert_output 'name,status,version,description' + + # some items + rune -0 cscli parsers install crowdsecurity/whitelists crowdsecurity/windows-auth + + rune -0 cscli parsers list + assert_output --partial crowdsecurity/whitelists + assert_output --partial crowdsecurity/windows-auth + rune -0 grep -c enabled <(output) + assert_output "2" + + rune -0 cscli parsers list -o json + assert_output --partial crowdsecurity/whitelists + assert_output --partial crowdsecurity/windows-auth + rune -0 jq '.parsers | length' <(output) + assert_output "2" + + rune -0 cscli parsers list -o raw + assert_output --partial crowdsecurity/whitelists + assert_output --partial crowdsecurity/windows-auth + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "2" +} + +@test "cscli parsers list 
-a" { + expected=$(jq <"$INDEX_PATH" -r '.parsers | length') + + rune -0 cscli parsers list -a + rune -0 grep -c disabled <(output) + assert_output "$expected" + + rune -0 cscli parsers list -o json -a + rune -0 jq '.parsers | length' <(output) + assert_output "$expected" + + rune -0 cscli parsers list -o raw -a + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "$expected" + + # the list should be the same in all formats, and sorted (not case sensitive) + + list_raw=$(cscli parsers list -o raw -a | tail -n +2 | cut -d, -f1) + list_human=$(cscli parsers list -o human -a | tail -n +6 | head -n -1 | cut -d' ' -f2) + list_json=$(cscli parsers list -o json -a | jq -r '.parsers[].name') + + rune -0 sort -f <<<"$list_raw" + assert_output "$list_raw" + + assert_equal "$list_raw" "$list_json" + assert_equal "$list_raw" "$list_human" +} + +@test "cscli parsers list [parser]..." { + # non-existent + rune -1 cscli parsers install foo/bar + assert_stderr --partial "can't find 'foo/bar' in parsers" + + # not installed + rune -0 cscli parsers list crowdsecurity/whitelists + assert_output --regexp 'crowdsecurity/whitelists.*disabled' + + # install two items + rune -0 cscli parsers install crowdsecurity/whitelists crowdsecurity/windows-auth + + # list an installed item + rune -0 cscli parsers list crowdsecurity/whitelists + assert_output --regexp "crowdsecurity/whitelists.*enabled" + refute_output --partial "crowdsecurity/windows-auth" + + # list multiple installed and non installed items + rune -0 cscli parsers list crowdsecurity/whitelists crowdsecurity/windows-auth crowdsecurity/traefik-logs + assert_output --partial "crowdsecurity/whitelists" + assert_output --partial "crowdsecurity/windows-auth" + assert_output --partial "crowdsecurity/traefik-logs" + + rune -0 cscli parsers list crowdsecurity/whitelists -o json + rune -0 jq '.parsers | length' <(output) + assert_output "1" + rune -0 cscli parsers list crowdsecurity/whitelists crowdsecurity/windows-auth crowdsecurity/traefik-logs -o json + rune -0 jq '.parsers | length' <(output) + assert_output "3" + + rune -0 cscli parsers list crowdsecurity/whitelists -o raw + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "1" + rune -0 cscli parsers list crowdsecurity/whitelists crowdsecurity/windows-auth crowdsecurity/traefik-logs -o raw + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "3" +} + +@test "cscli parsers install" { + rune -1 cscli parsers install + assert_stderr --partial 'requires at least 1 arg(s), only received 0' + + # not in hub + rune -1 cscli parsers install crowdsecurity/blahblah + assert_stderr --partial "can't find 'crowdsecurity/blahblah' in parsers" + + # simple install + rune -0 cscli parsers install crowdsecurity/whitelists + rune -0 cscli parsers inspect crowdsecurity/whitelists --no-metrics + assert_output --partial 'crowdsecurity/whitelists' + assert_output --partial 'installed: true' + + # autocorrect + rune -1 cscli parsers install crowdsecurity/sshd-logz + assert_stderr --partial "can't find 'crowdsecurity/sshd-logz' in parsers, did you mean 'crowdsecurity/sshd-logs'?" 
+ + # install multiple + rune -0 cscli parsers install crowdsecurity/pgsql-logs crowdsecurity/postfix-logs + rune -0 cscli parsers inspect crowdsecurity/pgsql-logs --no-metrics + assert_output --partial 'crowdsecurity/pgsql-logs' + assert_output --partial 'installed: true' + rune -0 cscli parsers inspect crowdsecurity/postfix-logs --no-metrics + assert_output --partial 'crowdsecurity/postfix-logs' + assert_output --partial 'installed: true' +} + +@test "cscli parsers install (file location and download-only)" { + rune -0 cscli parsers install crowdsecurity/whitelists --download-only + rune -0 cscli parsers inspect crowdsecurity/whitelists --no-metrics + assert_output --partial 'crowdsecurity/whitelists' + assert_output --partial 'installed: false' + assert_file_exists "$HUB_DIR/parsers/s02-enrich/crowdsecurity/whitelists.yaml" + assert_file_not_exists "$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" + + rune -0 cscli parsers install crowdsecurity/whitelists + rune -0 cscli parsers inspect crowdsecurity/whitelists --no-metrics + assert_output --partial 'installed: true' + assert_file_exists "$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" +} + +@test "cscli parsers install --force (tainted)" { + rune -0 cscli parsers install crowdsecurity/whitelists + echo "dirty" >"$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" + + rune -1 cscli parsers install crowdsecurity/whitelists + assert_stderr --partial "error while installing 'crowdsecurity/whitelists': while enabling crowdsecurity/whitelists: crowdsecurity/whitelists is tainted, won't overwrite unless --force" + + rune -0 cscli parsers install crowdsecurity/whitelists --force + assert_stderr --partial "Enabled crowdsecurity/whitelists" +} + +@test "cscli parsers install --ignore (skip on errors)" { + rune -1 cscli parsers install foo/bar crowdsecurity/whitelists + assert_stderr --partial "can't find 'foo/bar' in parsers" + refute_stderr --partial "Enabled parsers: crowdsecurity/whitelists" + + rune -0 cscli parsers install foo/bar crowdsecurity/whitelists --ignore + assert_stderr --partial "can't find 'foo/bar' in parsers" + assert_stderr --partial "Enabled parsers: crowdsecurity/whitelists" +} + +@test "cscli parsers inspect" { + rune -1 cscli parsers inspect + assert_stderr --partial 'requires at least 1 arg(s), only received 0' + # required for metrics + ./instance-crowdsec start + + rune -1 cscli parsers inspect blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in parsers" + + # one item + rune -0 cscli parsers inspect crowdsecurity/sshd-logs --no-metrics + assert_line 'type: parsers' + assert_line 'stage: s01-parse' + assert_line 'name: crowdsecurity/sshd-logs' + assert_line 'author: crowdsecurity' + assert_line 'path: parsers/s01-parse/crowdsecurity/sshd-logs.yaml' + assert_line 'installed: false' + refute_line --partial 'Current metrics:' + + # one item, with metrics + rune -0 cscli parsers inspect crowdsecurity/sshd-logs + assert_line --partial 'Current metrics:' + + # one item, json + rune -0 cscli parsers inspect crowdsecurity/sshd-logs -o json + rune -0 jq -c '[.type, .stage, .name, .author, .path, .installed]' <(output) + assert_json '["parsers","s01-parse","crowdsecurity/sshd-logs","crowdsecurity","parsers/s01-parse/crowdsecurity/sshd-logs.yaml",false]' + + # one item, raw + rune -0 cscli parsers inspect crowdsecurity/sshd-logs -o raw + assert_line 'type: parsers' + assert_line 'name: crowdsecurity/sshd-logs' + assert_line 'stage: s01-parse' + assert_line 'author: crowdsecurity' + assert_line 'path: 
parsers/s01-parse/crowdsecurity/sshd-logs.yaml' + assert_line 'installed: false' + refute_line --partial 'Current metrics:' + + # multiple items + rune -0 cscli parsers inspect crowdsecurity/sshd-logs crowdsecurity/whitelists --no-metrics + assert_output --partial 'crowdsecurity/sshd-logs' + assert_output --partial 'crowdsecurity/whitelists' + rune -1 grep -c 'Current metrics:' <(output) + assert_output "0" + + # multiple items, with metrics + rune -0 cscli parsers inspect crowdsecurity/sshd-logs crowdsecurity/whitelists + rune -0 grep -c 'Current metrics:' <(output) + assert_output "2" + + # multiple items, json + rune -0 cscli parsers inspect crowdsecurity/sshd-logs crowdsecurity/whitelists -o json + rune -0 jq -sc '[.[] | [.type, .stage, .name, .author, .path, .installed]]' <(output) + assert_json '[["parsers","s01-parse","crowdsecurity/sshd-logs","crowdsecurity","parsers/s01-parse/crowdsecurity/sshd-logs.yaml",false],["parsers","s02-enrich","crowdsecurity/whitelists","crowdsecurity","parsers/s02-enrich/crowdsecurity/whitelists.yaml",false]]' + + # multiple items, raw + rune -0 cscli parsers inspect crowdsecurity/sshd-logs crowdsecurity/whitelists -o raw + assert_output --partial 'crowdsecurity/sshd-logs' + assert_output --partial 'crowdsecurity/whitelists' + rune -1 grep -c 'Current metrics:' <(output) + assert_output "0" +} + +@test "cscli parsers remove" { + rune -1 cscli parsers remove + assert_stderr --partial "specify at least one parser to remove or '--all'" + rune -1 cscli parsers remove blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in parsers" + + rune -0 cscli parsers install crowdsecurity/whitelists --download-only + rune -0 cscli parsers remove crowdsecurity/whitelists + assert_stderr --partial "removing crowdsecurity/whitelists: not installed -- no need to remove" + + rune -0 cscli parsers install crowdsecurity/whitelists + rune -0 cscli parsers remove crowdsecurity/whitelists + assert_stderr --partial "Removed crowdsecurity/whitelists" + + rune -0 cscli parsers remove crowdsecurity/whitelists --purge + assert_stderr --partial 'Removed source file [crowdsecurity/whitelists]' + + rune -0 cscli parsers remove crowdsecurity/whitelists + assert_stderr --partial "removing crowdsecurity/whitelists: not installed -- no need to remove" + + rune -0 cscli parsers remove crowdsecurity/whitelists --purge --debug + assert_stderr --partial 'removing crowdsecurity/whitelists: not downloaded -- no need to remove' + refute_stderr --partial 'Removed source file [crowdsecurity/whitelists]' + + # install, then remove, check files + rune -0 cscli parsers install crowdsecurity/whitelists + assert_file_exists "$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" + rune -0 cscli parsers remove crowdsecurity/whitelists + assert_file_not_exists "$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" + + # delete is an alias for remove + rune -0 cscli parsers install crowdsecurity/whitelists + assert_file_exists "$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" + rune -0 cscli parsers delete crowdsecurity/whitelists + assert_file_not_exists "$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" + + # purge + assert_file_exists "$HUB_DIR/parsers/s02-enrich/crowdsecurity/whitelists.yaml" + rune -0 cscli parsers remove crowdsecurity/whitelists --purge + assert_file_not_exists "$HUB_DIR/parsers/s02-enrich/crowdsecurity/whitelists.yaml" + + rune -0 cscli parsers install crowdsecurity/whitelists crowdsecurity/windows-auth + + # --all + rune -0 cscli parsers list -o raw + rune -0 grep -vc 
'name,status,version,description' <(output) + assert_output "2" + + rune -0 cscli parsers remove --all + + rune -0 cscli parsers list -o raw + rune -1 grep -vc 'name,status,version,description' <(output) + assert_output "0" +} + +@test "cscli parsers remove --force" { + # remove a parser that belongs to a collection + rune -0 cscli collections install crowdsecurity/sshd + rune -0 cscli parsers remove crowdsecurity/sshd-logs + assert_stderr --partial "crowdsecurity/sshd-logs belongs to collections: [crowdsecurity/sshd]" + assert_stderr --partial "Run 'sudo cscli parsers remove crowdsecurity/sshd-logs --force' if you want to force remove this parser" +} + +@test "cscli parsers upgrade" { + rune -1 cscli parsers upgrade + assert_stderr --partial "specify at least one parser to upgrade or '--all'" + rune -1 cscli parsers upgrade blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in parsers" + rune -0 cscli parsers remove crowdsecurity/pam-logs --purge + rune -1 cscli parsers upgrade crowdsecurity/pam-logs + assert_stderr --partial "can't upgrade crowdsecurity/pam-logs: not installed" + rune -0 cscli parsers install crowdsecurity/pam-logs --download-only + rune -1 cscli parsers upgrade crowdsecurity/pam-logs + assert_stderr --partial "can't upgrade crowdsecurity/pam-logs: downloaded but not installed" + + # hash of the string "v0.0" + sha256_0_0="dfebecf42784a31aa3d009dbcec0c657154a034b45f49cf22a895373f6dbf63d" + + # add version 0.0 to all parsers + new_hub=$(jq --arg DIGEST "$sha256_0_0" <"$INDEX_PATH" '.parsers |= with_entries(.value.versions["0.0"] = {"digest": $DIGEST, "deprecated": false})') + echo "$new_hub" >"$INDEX_PATH" + + rune -0 cscli parsers install crowdsecurity/whitelists + + echo "v0.0" > "$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" + rune -0 cscli parsers inspect crowdsecurity/whitelists -o json + rune -0 jq -e '.local_version=="0.0"' <(output) + + # upgrade + rune -0 cscli parsers upgrade crowdsecurity/whitelists + rune -0 cscli parsers inspect crowdsecurity/whitelists -o json + rune -0 jq -e '.local_version==.version' <(output) + + # taint + echo "dirty" >"$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" + # XXX: should return error + rune -0 cscli parsers upgrade crowdsecurity/whitelists + assert_stderr --partial "crowdsecurity/whitelists is tainted, --force to overwrite" + rune -0 cscli parsers inspect crowdsecurity/whitelists -o json + rune -0 jq -e '.local_version=="?"' <(output) + + # force upgrade with taint + rune -0 cscli parsers upgrade crowdsecurity/whitelists --force + rune -0 cscli parsers inspect crowdsecurity/whitelists -o json + rune -0 jq -e '.local_version==.version' <(output) + + # multiple items + rune -0 cscli parsers install crowdsecurity/windows-auth + echo "v0.0" >"$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" + echo "v0.0" >"$CONFIG_DIR/parsers/s01-parse/windows-auth.yaml" + rune -0 cscli parsers list -o json + rune -0 jq -e '[.parsers[].local_version]==["0.0","0.0"]' <(output) + rune -0 cscli parsers upgrade crowdsecurity/whitelists crowdsecurity/windows-auth + rune -0 cscli parsers list -o json + rune -0 jq -e 'any(.parsers[].local_version; .=="0.0") | not' <(output) + + # upgrade all + echo "v0.0" >"$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" + echo "v0.0" >"$CONFIG_DIR/parsers/s01-parse/windows-auth.yaml" + rune -0 cscli parsers list -o json + rune -0 jq -e '[.parsers[].local_version]==["0.0","0.0"]' <(output) + rune -0 cscli parsers upgrade --all + rune -0 cscli parsers list -o json + rune -0 jq -e 
'any(.parsers[].local_version; .=="0.0") | not' <(output) +} diff --git a/test/bats/20_hub_postoverflows.bats b/test/bats/20_hub_postoverflows.bats new file mode 100644 index 00000000000..37337b08caa --- /dev/null +++ b/test/bats/20_hub_postoverflows.bats @@ -0,0 +1,383 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + ./instance-data load + HUB_DIR=$(config_get '.config_paths.hub_dir') + export HUB_DIR + INDEX_PATH=$(config_get '.config_paths.index_path') + export INDEX_PATH + CONFIG_DIR=$(config_get '.config_paths.config_dir') + export CONFIG_DIR +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + hub_strip_index +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "cscli postoverflows list" { + hub_purge_all + + # no items + rune -0 cscli postoverflows list + assert_output --partial "POSTOVERFLOWS" + rune -0 cscli postoverflows list -o json + assert_json '{postoverflows:[]}' + rune -0 cscli postoverflows list -o raw + assert_output 'name,status,version,description' + + # some items + rune -0 cscli postoverflows install crowdsecurity/rdns crowdsecurity/cdn-whitelist + + rune -0 cscli postoverflows list + assert_output --partial crowdsecurity/rdns + assert_output --partial crowdsecurity/cdn-whitelist + rune -0 grep -c enabled <(output) + assert_output "2" + + rune -0 cscli postoverflows list -o json + assert_output --partial crowdsecurity/rdns + assert_output --partial crowdsecurity/cdn-whitelist + rune -0 jq '.postoverflows | length' <(output) + assert_output "2" + + rune -0 cscli postoverflows list -o raw + assert_output --partial crowdsecurity/rdns + assert_output --partial crowdsecurity/cdn-whitelist + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "2" +} + +@test "cscli postoverflows list -a" { + expected=$(jq <"$INDEX_PATH" -r '.postoverflows | length') + + rune -0 cscli postoverflows list -a + rune -0 grep -c disabled <(output) + assert_output "$expected" + + rune -0 cscli postoverflows list -o json -a + rune -0 jq '.postoverflows | length' <(output) + assert_output "$expected" + + rune -0 cscli postoverflows list -o raw -a + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "$expected" + + # the list should be the same in all formats, and sorted (not case sensitive) + + list_raw=$(cscli postoverflows list -o raw -a | tail -n +2 | cut -d, -f1) + list_human=$(cscli postoverflows list -o human -a | tail -n +6 | head -n -1 | cut -d' ' -f2) + list_json=$(cscli postoverflows list -o json -a | jq -r '.postoverflows[].name') + + rune -0 sort -f <<<"$list_raw" + assert_output "$list_raw" + + assert_equal "$list_raw" "$list_json" + assert_equal "$list_raw" "$list_human" +} + +@test "cscli postoverflows list [postoverflow]..." 
{ + # non-existent + rune -1 cscli postoverflows install foo/bar + assert_stderr --partial "can't find 'foo/bar' in postoverflows" + + # not installed + rune -0 cscli postoverflows list crowdsecurity/rdns + assert_output --regexp 'crowdsecurity/rdns.*disabled' + + # install two items + rune -0 cscli postoverflows install crowdsecurity/rdns crowdsecurity/cdn-whitelist + + # list an installed item + rune -0 cscli postoverflows list crowdsecurity/rdns + assert_output --regexp "crowdsecurity/rdns.*enabled" + refute_output --partial "crowdsecurity/cdn-whitelist" + + # list multiple installed and non installed items + rune -0 cscli postoverflows list crowdsecurity/rdns crowdsecurity/cdn-whitelist crowdsecurity/ipv6_to_range + assert_output --partial "crowdsecurity/rdns" + assert_output --partial "crowdsecurity/cdn-whitelist" + assert_output --partial "crowdsecurity/ipv6_to_range" + + rune -0 cscli postoverflows list crowdsecurity/rdns -o json + rune -0 jq '.postoverflows | length' <(output) + assert_output "1" + rune -0 cscli postoverflows list crowdsecurity/rdns crowdsecurity/cdn-whitelist crowdsecurity/ipv6_to_range -o json + rune -0 jq '.postoverflows | length' <(output) + assert_output "3" + + rune -0 cscli postoverflows list crowdsecurity/rdns -o raw + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "1" + rune -0 cscli postoverflows list crowdsecurity/rdns crowdsecurity/cdn-whitelist crowdsecurity/ipv6_to_range -o raw + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "3" +} + +@test "cscli postoverflows install" { + rune -1 cscli postoverflows install + assert_stderr --partial 'requires at least 1 arg(s), only received 0' + + # not in hub + rune -1 cscli postoverflows install crowdsecurity/blahblah + assert_stderr --partial "can't find 'crowdsecurity/blahblah' in postoverflows" + + # simple install + rune -0 cscli postoverflows install crowdsecurity/rdns + rune -0 cscli postoverflows inspect crowdsecurity/rdns --no-metrics + assert_output --partial 'crowdsecurity/rdns' + assert_output --partial 'installed: true' + + # autocorrect + rune -1 cscli postoverflows install crowdsecurity/rdnf + assert_stderr --partial "can't find 'crowdsecurity/rdnf' in postoverflows, did you mean 'crowdsecurity/rdns'?" 
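The upgrade test at the end of this file rewrites the local hub index so that every item gains a fake version "0.0" whose digest matches fixture content later written to disk. A standalone sketch of that jq merge, using a placeholder digest; `$INDEX_PATH` is the variable exported in `setup_file()`:

    # add a fake "0.0" version to every postoverflow in the index; DIGEST must
    # be the sha256 of the exact bytes the test then writes to the item file
    DIGEST="0000000000000000000000000000000000000000000000000000000000000000"  # placeholder
    new_hub=$(jq --arg DIGEST "$DIGEST" <"$INDEX_PATH" \
        '.postoverflows |= with_entries(.value.versions["0.0"] = {"digest": $DIGEST, "deprecated": false})')
    echo "$new_hub" >"$INDEX_PATH"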
+ + # install multiple + rune -0 cscli postoverflows install crowdsecurity/rdns crowdsecurity/cdn-whitelist + rune -0 cscli postoverflows inspect crowdsecurity/rdns --no-metrics + assert_output --partial 'crowdsecurity/rdns' + assert_output --partial 'installed: true' + rune -0 cscli postoverflows inspect crowdsecurity/cdn-whitelist --no-metrics + assert_output --partial 'crowdsecurity/cdn-whitelist' + assert_output --partial 'installed: true' +} + +@test "cscli postoverflows install (file location and download-only)" { + rune -0 cscli postoverflows install crowdsecurity/rdns --download-only + rune -0 cscli postoverflows inspect crowdsecurity/rdns --no-metrics + assert_output --partial 'crowdsecurity/rdns' + assert_output --partial 'installed: false' + assert_file_exists "$HUB_DIR/postoverflows/s00-enrich/crowdsecurity/rdns.yaml" + assert_file_not_exists "$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" + + rune -0 cscli postoverflows install crowdsecurity/rdns + rune -0 cscli postoverflows inspect crowdsecurity/rdns --no-metrics + assert_output --partial 'installed: true' + assert_file_exists "$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" +} + +@test "cscli postoverflows install --force (tainted)" { + rune -0 cscli postoverflows install crowdsecurity/rdns + echo "dirty" >"$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" + + rune -1 cscli postoverflows install crowdsecurity/rdns + assert_stderr --partial "error while installing 'crowdsecurity/rdns': while enabling crowdsecurity/rdns: crowdsecurity/rdns is tainted, won't overwrite unless --force" + + rune -0 cscli postoverflows install crowdsecurity/rdns --force + assert_stderr --partial "Enabled crowdsecurity/rdns" +} + +@test "cscli postoverflows install --ignore (skip on errors)" { + rune -1 cscli postoverflows install foo/bar crowdsecurity/rdns + assert_stderr --partial "can't find 'foo/bar' in postoverflows" + refute_stderr --partial "Enabled postoverflows: crowdsecurity/rdns" + + rune -0 cscli postoverflows install foo/bar crowdsecurity/rdns --ignore + assert_stderr --partial "can't find 'foo/bar' in postoverflows" + assert_stderr --partial "Enabled postoverflows: crowdsecurity/rdns" +} + +@test "cscli postoverflows inspect" { + rune -1 cscli postoverflows inspect + assert_stderr --partial 'requires at least 1 arg(s), only received 0' + # required for metrics + ./instance-crowdsec start + + rune -1 cscli postoverflows inspect blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in postoverflows" + + # one item + rune -0 cscli postoverflows inspect crowdsecurity/rdns --no-metrics + assert_line 'type: postoverflows' + assert_line 'stage: s00-enrich' + assert_line 'name: crowdsecurity/rdns' + assert_line 'author: crowdsecurity' + assert_line 'path: postoverflows/s00-enrich/crowdsecurity/rdns.yaml' + assert_line 'installed: false' + refute_line --partial 'Current metrics:' + + # one item, with metrics + rune -0 cscli postoverflows inspect crowdsecurity/rdns + assert_line --partial 'Current metrics:' + + # one item, json + rune -0 cscli postoverflows inspect crowdsecurity/rdns -o json + rune -0 jq -c '[.type, .stage, .name, .author, .path, .installed]' <(output) + assert_json '["postoverflows","s00-enrich","crowdsecurity/rdns","crowdsecurity","postoverflows/s00-enrich/crowdsecurity/rdns.yaml",false]' + + # one item, raw + rune -0 cscli postoverflows inspect crowdsecurity/rdns -o raw + assert_line 'type: postoverflows' + assert_line 'name: crowdsecurity/rdns' + assert_line 'stage: s00-enrich' + assert_line 'author: 
crowdsecurity' + assert_line 'path: postoverflows/s00-enrich/crowdsecurity/rdns.yaml' + assert_line 'installed: false' + refute_line --partial 'Current metrics:' + + # multiple items + rune -0 cscli postoverflows inspect crowdsecurity/rdns crowdsecurity/cdn-whitelist --no-metrics + assert_output --partial 'crowdsecurity/rdns' + assert_output --partial 'crowdsecurity/cdn-whitelist' + rune -1 grep -c 'Current metrics:' <(output) + assert_output "0" + + # multiple items, with metrics + rune -0 cscli postoverflows inspect crowdsecurity/rdns crowdsecurity/cdn-whitelist + rune -0 grep -c 'Current metrics:' <(output) + assert_output "2" + + # multiple items, json + rune -0 cscli postoverflows inspect crowdsecurity/rdns crowdsecurity/cdn-whitelist -o json + rune -0 jq -sc '[.[] | [.type, .stage, .name, .author, .path, .installed]]' <(output) + assert_json '[["postoverflows","s00-enrich","crowdsecurity/rdns","crowdsecurity","postoverflows/s00-enrich/crowdsecurity/rdns.yaml",false],["postoverflows","s01-whitelist","crowdsecurity/cdn-whitelist","crowdsecurity","postoverflows/s01-whitelist/crowdsecurity/cdn-whitelist.yaml",false]]' + + # multiple items, raw + rune -0 cscli postoverflows inspect crowdsecurity/rdns crowdsecurity/cdn-whitelist -o raw + assert_output --partial 'crowdsecurity/rdns' + assert_output --partial 'crowdsecurity/cdn-whitelist' + rune -1 grep -c 'Current metrics:' <(output) + assert_output "0" +} + +@test "cscli postoverflows remove" { + rune -1 cscli postoverflows remove + assert_stderr --partial "specify at least one postoverflow to remove or '--all'" + rune -1 cscli postoverflows remove blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in postoverflows" + + rune -0 cscli postoverflows install crowdsecurity/rdns --download-only + rune -0 cscli postoverflows remove crowdsecurity/rdns + assert_stderr --partial "removing crowdsecurity/rdns: not installed -- no need to remove" + + rune -0 cscli postoverflows install crowdsecurity/rdns + rune -0 cscli postoverflows remove crowdsecurity/rdns + assert_stderr --partial 'Removed crowdsecurity/rdns' + + rune -0 cscli postoverflows remove crowdsecurity/rdns --purge + assert_stderr --partial 'Removed source file [crowdsecurity/rdns]' + + rune -0 cscli postoverflows remove crowdsecurity/rdns + assert_stderr --partial 'removing crowdsecurity/rdns: not installed -- no need to remove' + + rune -0 cscli postoverflows remove crowdsecurity/rdns --purge --debug + assert_stderr --partial 'removing crowdsecurity/rdns: not downloaded -- no need to remove' + refute_stderr --partial 'Removed source file [crowdsecurity/rdns]' + + # install, then remove, check files + rune -0 cscli postoverflows install crowdsecurity/rdns + assert_file_exists "$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" + rune -0 cscli postoverflows remove crowdsecurity/rdns + assert_file_not_exists "$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" + + # delete is an alias for remove + rune -0 cscli postoverflows install crowdsecurity/rdns + assert_file_exists "$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" + rune -0 cscli postoverflows delete crowdsecurity/rdns + assert_file_not_exists "$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" + + # purge + assert_file_exists "$HUB_DIR/postoverflows/s00-enrich/crowdsecurity/rdns.yaml" + rune -0 cscli postoverflows remove crowdsecurity/rdns --purge + assert_file_not_exists "$HUB_DIR/postoverflows/s00-enrich/crowdsecurity/rdns.yaml" + + rune -0 cscli postoverflows install crowdsecurity/rdns crowdsecurity/cdn-whitelist + 
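The row counts around the `--all` removal below depend on a small `grep` subtlety: `grep -vc HEADER` prints the number of lines that are not the CSV header, and exits non-zero when that number is 0, which is why an empty listing is asserted with `rune -1` while a populated one uses `rune -0`. A plain-shell illustration:

    # header plus one data row: prints 1, exits 0
    printf 'name,status\nfoo,enabled\n' | grep -vc 'name,status'
    # header only: prints 0, exits 1 (grep -v selected no lines)
    printf 'name,status\n' | grep -vc 'name,status'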
+ # --all + rune -0 cscli postoverflows list -o raw + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "2" + + rune -0 cscli postoverflows remove --all + + rune -0 cscli postoverflows list -o raw + rune -1 grep -vc 'name,status,version,description' <(output) + assert_output "0" +} + +@test "cscli postoverflows remove --force" { + # remove a postoverflow that belongs to a collection + rune -0 cscli collections install crowdsecurity/auditd + rune -0 cscli postoverflows remove crowdsecurity/auditd-whitelisted-process + assert_stderr --partial "crowdsecurity/auditd-whitelisted-process belongs to collections: [crowdsecurity/auditd]" + assert_stderr --partial "Run 'sudo cscli postoverflows remove crowdsecurity/auditd-whitelisted-process --force' if you want to force remove this postoverflow" +} + +@test "cscli postoverflows upgrade" { + rune -1 cscli postoverflows upgrade + assert_stderr --partial "specify at least one postoverflow to upgrade or '--all'" + rune -1 cscli postoverflows upgrade blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in postoverflows" + rune -0 cscli postoverflows remove crowdsecurity/discord-crawler-whitelist --purge + rune -1 cscli postoverflows upgrade crowdsecurity/discord-crawler-whitelist + assert_stderr --partial "can't upgrade crowdsecurity/discord-crawler-whitelist: not installed" + rune -0 cscli postoverflows install crowdsecurity/discord-crawler-whitelist --download-only + rune -1 cscli postoverflows upgrade crowdsecurity/discord-crawler-whitelist + assert_stderr --partial "can't upgrade crowdsecurity/discord-crawler-whitelist: downloaded but not installed" + + # hash of the string "v0.0" + sha256_0_0="dfebecf42784a31aa3d009dbcec0c657154a034b45f49cf22a895373f6dbf63d" + + # add version 0.0 to all postoverflows + new_hub=$(jq --arg DIGEST "$sha256_0_0" <"$INDEX_PATH" '.postoverflows |= with_entries(.value.versions["0.0"] = {"digest": $DIGEST, "deprecated": false})') + echo "$new_hub" >"$INDEX_PATH" + + rune -0 cscli postoverflows install crowdsecurity/rdns + + echo "v0.0" > "$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" + rune -0 cscli postoverflows inspect crowdsecurity/rdns -o json + rune -0 jq -e '.local_version=="0.0"' <(output) + + # upgrade + rune -0 cscli postoverflows upgrade crowdsecurity/rdns + rune -0 cscli postoverflows inspect crowdsecurity/rdns -o json + rune -0 jq -e '.local_version==.version' <(output) + + # taint + echo "dirty" >"$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" + # XXX: should return error + rune -0 cscli postoverflows upgrade crowdsecurity/rdns + assert_stderr --partial "crowdsecurity/rdns is tainted, --force to overwrite" + rune -0 cscli postoverflows inspect crowdsecurity/rdns -o json + rune -0 jq -e '.local_version=="?"' <(output) + + # force upgrade with taint + rune -0 cscli postoverflows upgrade crowdsecurity/rdns --force + rune -0 cscli postoverflows inspect crowdsecurity/rdns -o json + rune -0 jq -e '.local_version==.version' <(output) + + # multiple items + rune -0 cscli postoverflows install crowdsecurity/cdn-whitelist + echo "v0.0" >"$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" + echo "v0.0" >"$CONFIG_DIR/postoverflows/s01-whitelist/cdn-whitelist.yaml" + rune -0 cscli postoverflows list -o json + rune -0 jq -e '[.postoverflows[].local_version]==["0.0","0.0"]' <(output) + rune -0 cscli postoverflows upgrade crowdsecurity/rdns crowdsecurity/cdn-whitelist + rune -0 cscli postoverflows list -o json + rune -0 jq -e 'any(.postoverflows[].local_version; .=="0.0") | 
not' <(output) + + # upgrade all + echo "v0.0" >"$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" + echo "v0.0" >"$CONFIG_DIR/postoverflows/s01-whitelist/cdn-whitelist.yaml" + rune -0 cscli postoverflows list -o json + rune -0 jq -e '[.postoverflows[].local_version]==["0.0","0.0"]' <(output) + rune -0 cscli postoverflows upgrade --all + rune -0 cscli postoverflows list -o json + rune -0 jq -e 'any(.postoverflows[].local_version; .=="0.0") | not' <(output) +} diff --git a/test/bats/20_hub_scenarios.bats b/test/bats/20_hub_scenarios.bats new file mode 100644 index 00000000000..3ab3d944c93 --- /dev/null +++ b/test/bats/20_hub_scenarios.bats @@ -0,0 +1,382 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + ./instance-data load + HUB_DIR=$(config_get '.config_paths.hub_dir') + export HUB_DIR + INDEX_PATH=$(config_get '.config_paths.index_path') + export INDEX_PATH + CONFIG_DIR=$(config_get '.config_paths.config_dir') + export CONFIG_DIR +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + hub_strip_index +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "cscli scenarios list" { + hub_purge_all + + # no items + rune -0 cscli scenarios list + assert_output --partial "SCENARIOS" + rune -0 cscli scenarios list -o json + assert_json '{scenarios:[]}' + rune -0 cscli scenarios list -o raw + assert_output 'name,status,version,description' + + # some items + rune -0 cscli scenarios install crowdsecurity/ssh-bf crowdsecurity/telnet-bf + + rune -0 cscli scenarios list + assert_output --partial crowdsecurity/ssh-bf + assert_output --partial crowdsecurity/telnet-bf + rune -0 grep -c enabled <(output) + assert_output "2" + + rune -0 cscli scenarios list -o json + assert_output --partial crowdsecurity/ssh-bf + assert_output --partial crowdsecurity/telnet-bf + rune -0 jq '.scenarios | length' <(output) + assert_output "2" + + rune -0 cscli scenarios list -o raw + assert_output --partial crowdsecurity/ssh-bf + assert_output --partial crowdsecurity/telnet-bf + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "2" +} + +@test "cscli scenarios list -a" { + expected=$(jq <"$INDEX_PATH" -r '.scenarios | length') + + rune -0 cscli scenarios list -a + rune -0 grep -c disabled <(output) + assert_output "$expected" + + rune -0 cscli scenarios list -o json -a + rune -0 jq '.scenarios | length' <(output) + assert_output "$expected" + + rune -0 cscli scenarios list -o raw -a + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "$expected" + + # the list should be the same in all formats, and sorted (not case sensitive) + + list_raw=$(cscli scenarios list -o raw -a | tail -n +2 | cut -d, -f1) + list_human=$(cscli scenarios list -o human -a | tail -n +6 | head -n -1 | cut -d' ' -f2) + list_json=$(cscli scenarios list -o json -a | jq -r '.scenarios[].name') + + rune -0 sort -f <<<"$list_raw" + assert_output "$list_raw" + + assert_equal "$list_raw" "$list_json" + assert_equal "$list_raw" "$list_human" +} + +@test "cscli scenarios list [scenario]..." 
{ + # non-existent + rune -1 cscli scenarios install foo/bar + assert_stderr --partial "can't find 'foo/bar' in scenarios" + + # not installed + rune -0 cscli scenarios list crowdsecurity/ssh-bf + assert_output --regexp 'crowdsecurity/ssh-bf.*disabled' + + # install two items + rune -0 cscli scenarios install crowdsecurity/ssh-bf crowdsecurity/telnet-bf + + # list an installed item + rune -0 cscli scenarios list crowdsecurity/ssh-bf + assert_output --regexp "crowdsecurity/ssh-bf.*enabled" + refute_output --partial "crowdsecurity/telnet-bf" + + # list multiple installed and non installed items + rune -0 cscli scenarios list crowdsecurity/ssh-bf crowdsecurity/telnet-bf crowdsecurity/aws-bf + assert_output --partial "crowdsecurity/ssh-bf" + assert_output --partial "crowdsecurity/telnet-bf" + assert_output --partial "crowdsecurity/aws-bf" + + rune -0 cscli scenarios list crowdsecurity/ssh-bf -o json + rune -0 jq '.scenarios | length' <(output) + assert_output "1" + rune -0 cscli scenarios list crowdsecurity/ssh-bf crowdsecurity/telnet-bf crowdsecurity/aws-bf -o json + rune -0 jq '.scenarios | length' <(output) + assert_output "3" + + rune -0 cscli scenarios list crowdsecurity/ssh-bf -o raw + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "1" + rune -0 cscli scenarios list crowdsecurity/ssh-bf crowdsecurity/telnet-bf crowdsecurity/aws-bf -o raw + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "3" +} + +@test "cscli scenarios install" { + rune -1 cscli scenarios install + assert_stderr --partial 'requires at least 1 arg(s), only received 0' + + # not in hub + rune -1 cscli scenarios install crowdsecurity/blahblah + assert_stderr --partial "can't find 'crowdsecurity/blahblah' in scenarios" + + # simple install + rune -0 cscli scenarios install crowdsecurity/ssh-bf + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf --no-metrics + assert_output --partial 'crowdsecurity/ssh-bf' + assert_output --partial 'installed: true' + + # autocorrect + rune -1 cscli scenarios install crowdsecurity/ssh-tf + assert_stderr --partial "can't find 'crowdsecurity/ssh-tf' in scenarios, did you mean 'crowdsecurity/ssh-bf'?" 
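The `list -a` test earlier in this file checks ordering by re-sorting the extracted name column and asserting that nothing moved: a list is sorted case-insensitively exactly when sorting it again is a no-op. A self-contained sketch of the idiom, assuming GNU sort (`-f` folds case when comparing):

    list=$'Alpha\nbeta\nGamma'
    if [ "$list" = "$(sort -f <<<"$list")" ]; then
        echo "already sorted (case-insensitively)"
    fi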
+ + # install multiple + rune -0 cscli scenarios install crowdsecurity/ssh-bf crowdsecurity/telnet-bf + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf --no-metrics + assert_output --partial 'crowdsecurity/ssh-bf' + assert_output --partial 'installed: true' + rune -0 cscli scenarios inspect crowdsecurity/telnet-bf --no-metrics + assert_output --partial 'crowdsecurity/telnet-bf' + assert_output --partial 'installed: true' +} + +@test "cscli scenarios install (file location and download-only)" { + # simple install + rune -0 cscli scenarios install crowdsecurity/ssh-bf --download-only + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf --no-metrics + assert_output --partial 'crowdsecurity/ssh-bf' + assert_output --partial 'installed: false' + assert_file_exists "$HUB_DIR/scenarios/crowdsecurity/ssh-bf.yaml" + assert_file_not_exists "$CONFIG_DIR/scenarios/ssh-bf.yaml" + + rune -0 cscli scenarios install crowdsecurity/ssh-bf + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf --no-metrics + assert_output --partial 'installed: true' + assert_file_exists "$CONFIG_DIR/scenarios/ssh-bf.yaml" +} + +@test "cscli scenarios install --force (tainted)" { + rune -0 cscli scenarios install crowdsecurity/ssh-bf + echo "dirty" >"$CONFIG_DIR/scenarios/ssh-bf.yaml" + + rune -1 cscli scenarios install crowdsecurity/ssh-bf + assert_stderr --partial "error while installing 'crowdsecurity/ssh-bf': while enabling crowdsecurity/ssh-bf: crowdsecurity/ssh-bf is tainted, won't overwrite unless --force" + + rune -0 cscli scenarios install crowdsecurity/ssh-bf --force + assert_stderr --partial "Enabled crowdsecurity/ssh-bf" +} + +@test "cscli scenarios install --ignore (skip on errors)" { + rune -1 cscli scenarios install foo/bar crowdsecurity/ssh-bf + assert_stderr --partial "can't find 'foo/bar' in scenarios" + refute_stderr --partial "Enabled scenarios: crowdsecurity/ssh-bf" + + rune -0 cscli scenarios install foo/bar crowdsecurity/ssh-bf --ignore + assert_stderr --partial "can't find 'foo/bar' in scenarios" + assert_stderr --partial "Enabled scenarios: crowdsecurity/ssh-bf" +} + +@test "cscli scenarios inspect" { + rune -1 cscli scenarios inspect + assert_stderr --partial 'requires at least 1 arg(s), only received 0' + # required for metrics + ./instance-crowdsec start + + rune -1 cscli scenarios inspect blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in scenarios" + + # one item + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf --no-metrics + assert_line 'type: scenarios' + assert_line 'name: crowdsecurity/ssh-bf' + assert_line 'author: crowdsecurity' + assert_line 'path: scenarios/crowdsecurity/ssh-bf.yaml' + assert_line 'installed: false' + refute_line --partial 'Current metrics:' + + # one item, with metrics + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf + assert_line --partial 'Current metrics:' + + # one item, json + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf -o json + rune -0 jq -c '[.type, .name, .author, .path, .installed]' <(output) + assert_json '["scenarios","crowdsecurity/ssh-bf","crowdsecurity","scenarios/crowdsecurity/ssh-bf.yaml",false]' + + # one item, raw + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf -o raw + assert_line 'type: scenarios' + assert_line 'name: crowdsecurity/ssh-bf' + assert_line 'author: crowdsecurity' + assert_line 'path: scenarios/crowdsecurity/ssh-bf.yaml' + assert_line 'installed: false' + refute_line --partial 'Current metrics:' + + # multiple items + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf 
crowdsecurity/telnet-bf --no-metrics + assert_output --partial 'crowdsecurity/ssh-bf' + assert_output --partial 'crowdsecurity/telnet-bf' + rune -1 grep -c 'Current metrics:' <(output) + assert_output "0" + + # multiple items, with metrics + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf crowdsecurity/telnet-bf + rune -0 grep -c 'Current metrics:' <(output) + assert_output "2" + + # multiple items, json + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf crowdsecurity/telnet-bf -o json + rune -0 jq -sc '[.[] | [.type, .name, .author, .path, .installed]]' <(output) + assert_json '[["scenarios","crowdsecurity/ssh-bf","crowdsecurity","scenarios/crowdsecurity/ssh-bf.yaml",false],["scenarios","crowdsecurity/telnet-bf","crowdsecurity","scenarios/crowdsecurity/telnet-bf.yaml",false]]' + + # multiple items, raw + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf crowdsecurity/telnet-bf -o raw + assert_output --partial 'crowdsecurity/ssh-bf' + assert_output --partial 'crowdsecurity/telnet-bf' + rune -1 grep -c 'Current metrics:' <(output) + assert_output "0" +} + +@test "cscli scenarios remove" { + rune -1 cscli scenarios remove + assert_stderr --partial "specify at least one scenario to remove or '--all'" + rune -1 cscli scenarios remove blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in scenarios" + + rune -0 cscli scenarios install crowdsecurity/ssh-bf --download-only + rune -0 cscli scenarios remove crowdsecurity/ssh-bf + assert_stderr --partial "removing crowdsecurity/ssh-bf: not installed -- no need to remove" + + rune -0 cscli scenarios install crowdsecurity/ssh-bf + rune -0 cscli scenarios remove crowdsecurity/ssh-bf + assert_stderr --partial "Removed crowdsecurity/ssh-bf" + + rune -0 cscli scenarios remove crowdsecurity/ssh-bf --purge + assert_stderr --partial 'Removed source file [crowdsecurity/ssh-bf]' + + rune -0 cscli scenarios remove crowdsecurity/ssh-bf + assert_stderr --partial "removing crowdsecurity/ssh-bf: not installed -- no need to remove" + + rune -0 cscli scenarios remove crowdsecurity/ssh-bf --purge --debug + assert_stderr --partial 'removing crowdsecurity/ssh-bf: not downloaded -- no need to remove' + refute_stderr --partial 'Removed source file [crowdsecurity/ssh-bf]' + + # install, then remove, check files + rune -0 cscli scenarios install crowdsecurity/ssh-bf + assert_file_exists "$CONFIG_DIR/scenarios/ssh-bf.yaml" + rune -0 cscli scenarios remove crowdsecurity/ssh-bf + assert_file_not_exists "$CONFIG_DIR/scenarios/ssh-bf.yaml" + + # delete is an alias for remove + rune -0 cscli scenarios install crowdsecurity/ssh-bf + assert_file_exists "$CONFIG_DIR/scenarios/ssh-bf.yaml" + rune -0 cscli scenarios delete crowdsecurity/ssh-bf + assert_file_not_exists "$CONFIG_DIR/scenarios/ssh-bf.yaml" + + # purge + assert_file_exists "$HUB_DIR/scenarios/crowdsecurity/ssh-bf.yaml" + rune -0 cscli scenarios remove crowdsecurity/ssh-bf --purge + assert_file_not_exists "$HUB_DIR/scenarios/crowdsecurity/ssh-bf.yaml" + + rune -0 cscli scenarios install crowdsecurity/ssh-bf crowdsecurity/telnet-bf + + # --all + rune -0 cscli scenarios list -o raw + rune -0 grep -vc 'name,status,version,description' <(output) + assert_output "2" + + rune -0 cscli scenarios remove --all + + rune -0 cscli scenarios list -o raw + rune -1 grep -vc 'name,status,version,description' <(output) + assert_output "0" +} + +@test "cscli scenarios remove --force" { + # remove a scenario that belongs to a collection + rune -0 cscli collections install crowdsecurity/sshd + rune -0 cscli 
scenarios remove crowdsecurity/ssh-bf + assert_stderr --partial "crowdsecurity/ssh-bf belongs to collections: [crowdsecurity/sshd]" + assert_stderr --partial "Run 'sudo cscli scenarios remove crowdsecurity/ssh-bf --force' if you want to force remove this scenario" +} + +@test "cscli scenarios upgrade" { + rune -1 cscli scenarios upgrade + assert_stderr --partial "specify at least one scenario to upgrade or '--all'" + rune -1 cscli scenarios upgrade blahblah/blahblah + assert_stderr --partial "can't find 'blahblah/blahblah' in scenarios" + rune -0 cscli scenarios remove crowdsecurity/vsftpd-bf --purge + rune -1 cscli scenarios upgrade crowdsecurity/vsftpd-bf + assert_stderr --partial "can't upgrade crowdsecurity/vsftpd-bf: not installed" + rune -0 cscli scenarios install crowdsecurity/vsftpd-bf --download-only + rune -1 cscli scenarios upgrade crowdsecurity/vsftpd-bf + assert_stderr --partial "can't upgrade crowdsecurity/vsftpd-bf: downloaded but not installed" + + # hash of the string "v0.0" + sha256_0_0="dfebecf42784a31aa3d009dbcec0c657154a034b45f49cf22a895373f6dbf63d" + + # add version 0.0 to all scenarios + new_hub=$(jq --arg DIGEST "$sha256_0_0" <"$INDEX_PATH" '.scenarios |= with_entries(.value.versions["0.0"] = {"digest": $DIGEST, "deprecated": false})') + echo "$new_hub" >"$INDEX_PATH" + + rune -0 cscli scenarios install crowdsecurity/ssh-bf + + echo "v0.0" > "$CONFIG_DIR/scenarios/ssh-bf.yaml" + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf -o json + rune -0 jq -e '.local_version=="0.0"' <(output) + + # upgrade + rune -0 cscli scenarios upgrade crowdsecurity/ssh-bf + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf -o json + rune -0 jq -e '.local_version==.version' <(output) + + # taint + echo "dirty" >"$CONFIG_DIR/scenarios/ssh-bf.yaml" + # XXX: should return error + rune -0 cscli scenarios upgrade crowdsecurity/ssh-bf + assert_stderr --partial "crowdsecurity/ssh-bf is tainted, --force to overwrite" + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf -o json + rune -0 jq -e '.local_version=="?"' <(output) + + # force upgrade with taint + rune -0 cscli scenarios upgrade crowdsecurity/ssh-bf --force + rune -0 cscli scenarios inspect crowdsecurity/ssh-bf -o json + rune -0 jq -e '.local_version==.version' <(output) + + # multiple items + rune -0 cscli scenarios install crowdsecurity/telnet-bf + echo "v0.0" >"$CONFIG_DIR/scenarios/ssh-bf.yaml" + echo "v0.0" >"$CONFIG_DIR/scenarios/telnet-bf.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq -e '[.scenarios[].local_version]==["0.0","0.0"]' <(output) + rune -0 cscli scenarios upgrade crowdsecurity/ssh-bf crowdsecurity/telnet-bf + rune -0 cscli scenarios list -o json + rune -0 jq -e 'any(.scenarios[].local_version; .=="0.0") | not' <(output) + + # upgrade all + echo "v0.0" >"$CONFIG_DIR/scenarios/ssh-bf.yaml" + echo "v0.0" >"$CONFIG_DIR/scenarios/telnet-bf.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq -e '[.scenarios[].local_version]==["0.0","0.0"]' <(output) + rune -0 cscli scenarios upgrade --all + rune -0 cscli scenarios list -o json + rune -0 jq -e 'any(.scenarios[].local_version; .=="0.0") | not' <(output) +} diff --git a/test/bats/30_machines.bats b/test/bats/30_machines.bats index d5ddf840f4c..d4cce67d0b0 100644 --- a/test/bats/30_machines.bats +++ b/test/bats/30_machines.bats @@ -23,20 +23,29 @@ teardown() { #---------- -@test "can list machines as regular user" { - rune -0 cscli machines list -} - @test "we have exactly one machine" { rune -0 cscli machines list -o json rune -0 jq -c '[. 
| length, .[0].machineId[0:32], .[0].isValidated]' <(output) assert_output '[1,"githubciXXXXXXXXXXXXXXXXXXXXXXXX",true]' } +@test "don't overwrite local credentials by default" { + rune -1 cscli machines add local -a -o json + rune -0 jq -r '.msg' <(stderr) + assert_output --partial 'already exists: please remove it, use "--force" or specify a different file with "-f"' + rune -0 cscli machines add local -a --force + assert_stderr --partial "Machine 'local' successfully added to the local API." +} + +@test "passwords have a size limit" { + rune -1 cscli machines add local --password "$(printf '%73s' '' | tr ' ' x)" + assert_stderr --partial "password too long (max 72 characters)" +} + @test "add a new machine and delete it" { rune -0 cscli machines add -a -f /dev/null CiTestMachine -o human assert_stderr --partial "Machine 'CiTestMachine' successfully added to the local API" - assert_stderr --partial "API credentials dumped to '/dev/null'" + assert_stderr --partial "API credentials written to '/dev/null'" # we now have two machines rune -0 cscli machines list -o json @@ -53,10 +62,42 @@ teardown() { assert_output 1 } +@test "delete non-existent machine" { + # this is not a fatal error, won't halt a script with -e + rune -0 cscli machines delete something + assert_stderr --partial "unable to delete machine: 'something' does not exist" + rune -0 cscli machines delete something --ignore-missing + refute_stderr +} + +@test "machines [delete|inspect] has autocompletion" { + rune -0 cscli machines add -a -f /dev/null foo1 + rune -0 cscli machines add -a -f /dev/null foo2 + rune -0 cscli machines add -a -f /dev/null bar + rune -0 cscli machines add -a -f /dev/null baz + rune -0 cscli __complete machines delete 'foo' + assert_line --index 0 'foo1' + assert_line --index 1 'foo2' + refute_line 'bar' + refute_line 'baz' + rune -0 cscli __complete machines inspect 'foo' + assert_line --index 0 'foo1' + assert_line --index 1 'foo2' + refute_line 'bar' + refute_line 'baz' +} + +@test "heartbeat is initially null" { + rune -0 cscli machines add foo --auto --file /dev/null + rune -0 cscli machines list -o json + rune -0 yq '.[] | select(.machineId == "foo") | .last_heartbeat' <(output) + assert_output null +} + @test "register, validate and then remove a machine" { rune -0 cscli lapi register --machine CiTestMachineRegister -f /dev/null -o human assert_stderr --partial "Successfully registered to Local API (LAPI)" - assert_stderr --partial "Local API credentials dumped to '/dev/null'" + assert_stderr --partial "Local API credentials written to '/dev/null'" # the machine is not validated yet rune -0 cscli machines list -o json @@ -81,3 +122,20 @@ teardown() { rune -0 jq '. | length' <(output) assert_output 1 } + +@test "cscli machines prune" { + rune -0 cscli metrics + + # if the fixture has been created some time ago, + # the machines may be old enough to trigger a user prompt. + # make sure the prune duration is high enough. + rune -0 cscli machines prune --duration 1000000h + assert_output 'No machines to prune.' + + rune -0 cscli machines list -o json + rune -0 jq -r '.[-1].machineId' <(output) + rune -0 cscli machines delete "$output" + + rune -0 cscli machines prune + assert_output 'No machines to prune.' 
+} diff --git a/test/bats/30_machines_tls.bats b/test/bats/30_machines_tls.bats index 121cdecdf1b..ef02d1b57c3 100644 --- a/test/bats/30_machines_tls.bats +++ b/test/bats/30_machines_tls.bats @@ -3,39 +3,119 @@ set -u +# root: root CA +# inter: intermediate CA +# inter_rev: intermediate CA revoked by root (CRL3) +# leaf: valid client cert +# leaf_rev1: client cert revoked by inter (CRL1) +# leaf_rev2: client cert revoked by inter (CRL2) +# leaf_rev3: client cert (indirectly) revoked by root +# +# CRL1: inter revokes leaf_rev1 +# CRL2: inter revokes leaf_rev2 +# CRL3: root revokes inter_rev +# CRL4: root revokes leaf, but is ignored + setup_file() { load "../lib/setup_file.sh" ./instance-data load - CONFIG_DIR=$(dirname "${CONFIG_YAML}") + CONFIG_DIR=$(dirname "$CONFIG_YAML") export CONFIG_DIR - tmpdir="${BATS_FILE_TMPDIR}" + tmpdir="$BATS_FILE_TMPDIR" export tmpdir - CFDIR="${BATS_TEST_DIRNAME}/testdata/cfssl" + CFDIR="$BATS_TEST_DIRNAME/testdata/cfssl" export CFDIR - #gen the CA - cfssl gencert --initca "${CFDIR}/ca.json" 2>/dev/null | cfssljson --bare "${tmpdir}/ca" - #gen an intermediate - cfssl gencert --initca "${CFDIR}/intermediate.json" 2>/dev/null | cfssljson --bare "${tmpdir}/inter" - cfssl sign -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config "${CFDIR}/profiles.json" -profile intermediate_ca "${tmpdir}/inter.csr" 2>/dev/null | cfssljson --bare "${tmpdir}/inter" - #gen server cert for crowdsec with the intermediate - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=server "${CFDIR}/server.json" 2>/dev/null | cfssljson --bare "${tmpdir}/server" - #gen client cert for the agent - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/agent.json" 2>/dev/null | cfssljson --bare "${tmpdir}/agent" - #gen client cert for the agent with an invalid OU - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/agent_invalid.json" 2>/dev/null | cfssljson --bare "${tmpdir}/agent_bad_ou" - #gen client cert for the agent directly signed by the CA, it should be refused by crowdsec as uses the intermediate - cfssl gencert -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/agent.json" 2>/dev/null | cfssljson --bare "${tmpdir}/agent_invalid" - - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/agent.json" 2>/dev/null | cfssljson --bare "${tmpdir}/agent_revoked" - serial="$(openssl x509 -noout -serial -in "${tmpdir}/agent_revoked.pem" | cut -d '=' -f2)" - echo "ibase=16; ${serial}" | bc >"${tmpdir}/serials.txt" - cfssl gencrl "${tmpdir}/serials.txt" "${tmpdir}/ca.pem" "${tmpdir}/ca-key.pem" | base64 -d | openssl crl -inform DER -out "${tmpdir}/crl.pem" - - cat "${tmpdir}/ca.pem" "${tmpdir}/inter.pem" > "${tmpdir}/bundle.pem" + # Root CA + cfssl gencert -loglevel 2 \ + --initca "$CFDIR/ca_root.json" \ + | cfssljson --bare "$tmpdir/root" + + # Intermediate CAs (valid or revoked) + for cert in "inter" "inter_rev"; do + cfssl gencert -loglevel 2 \ + --initca "$CFDIR/ca_intermediate.json" \ + | cfssljson --bare "$tmpdir/$cert" + + cfssl sign -loglevel 2 \ + -ca "$tmpdir/root.pem" -ca-key "$tmpdir/root-key.pem" \ + -config "$CFDIR/profiles.json" -profile intermediate_ca "$tmpdir/$cert.csr" \ + | cfssljson --bare "$tmpdir/$cert" + 
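+ # note: the second cfssljson call reuses the same basename, overwriting the self-signed certificate from --initca with the root-signed one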
done + + # Server cert for crowdsec with the intermediate + cfssl gencert -loglevel 2 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=server "$CFDIR/server.json" \ + | cfssljson --bare "$tmpdir/server" + + # Client certs (valid or revoked) + for cert in "leaf" "leaf_rev1" "leaf_rev2"; do + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/agent.json" \ + | cfssljson --bare "$tmpdir/$cert" + done + + # Client cert (by revoked inter) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter_rev.pem" -ca-key "$tmpdir/inter_rev-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/agent.json" \ + | cfssljson --bare "$tmpdir/leaf_rev3" + + # Bad client cert (invalid OU) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/agent_invalid.json" \ + | cfssljson --bare "$tmpdir/leaf_bad_ou" + + # Bad client cert (directly signed by the CA, it should be refused by crowdsec as it uses the intermediate) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/root.pem" -ca-key "$tmpdir/root-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/agent.json" \ + | cfssljson --bare "$tmpdir/leaf_invalid" + + truncate -s 0 "$tmpdir/crl.pem" + + # Revoke certs + { + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf_rev1.pem") \ + "$tmpdir/inter.pem" \ + "$tmpdir/inter-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf_rev2.pem") \ + "$tmpdir/inter.pem" \ + "$tmpdir/inter-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/inter_rev.pem") \ + "$tmpdir/root.pem" \ + "$tmpdir/root-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf.pem") \ + "$tmpdir/root.pem" \ + "$tmpdir/root-key.pem" + echo '-----END X509 CRL-----' + } >> "$tmpdir/crl.pem" + + cat "$tmpdir/root.pem" "$tmpdir/inter.pem" > "$tmpdir/bundle.pem" config_set ' .api.server.tls.cert_file=strenv(tmpdir) + "/server.pem" | @@ -48,7 +128,7 @@ setup_file() { # remove all machines for machine in $(cscli machines list -o json | jq -r '.[].machineId'); do - cscli machines delete "${machine}" >/dev/null 2>&1 + cscli machines delete "$machine" >/dev/null 2>&1 done config_disable_agent @@ -78,72 +158,141 @@ teardown() { @test "missing key_file" { config_set '.api.server.tls.key_file=""' - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "missing TLS key file" + rune -0 wait-for \ + --err "missing TLS key file" \ + "$CROWDSEC" } @test "missing cert_file" { config_set '.api.server.tls.cert_file=""' - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "missing TLS cert file" + rune -0 wait-for \ + --err "missing TLS cert file" \ + "$CROWDSEC" } @test "invalid OU for agent" { - config_set "${CONFIG_DIR}/local_api_credentials.yaml" ' + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | - .key_path=strenv(tmpdir) + "/agent_bad_ou-key.pem" | - .cert_path=strenv(tmpdir) + "/agent_bad_ou.pem" | + .key_path=strenv(tmpdir) + "/leaf_bad_ou-key.pem" | + .cert_path=strenv(tmpdir) + "/leaf_bad_ou.pem" | .url="https://127.0.0.1:8080" ' - config_set 
"${CONFIG_DIR}/local_api_credentials.yaml" 'del(.login,.password)' + config_set "$CONFIG_DIR/local_api_credentials.yaml" 'del(.login,.password)' ./instance-crowdsec start rune -0 cscli machines list -o json assert_output '[]' } @test "we have exactly one machine registered with TLS" { - config_set "${CONFIG_DIR}/local_api_credentials.yaml" ' + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | - .key_path=strenv(tmpdir) + "/agent-key.pem" | - .cert_path=strenv(tmpdir) + "/agent.pem" | + .key_path=strenv(tmpdir) + "/leaf-key.pem" | + .cert_path=strenv(tmpdir) + "/leaf.pem" | .url="https://127.0.0.1:8080" ' - config_set "${CONFIG_DIR}/local_api_credentials.yaml" 'del(.login,.password)' + config_set "$CONFIG_DIR/local_api_credentials.yaml" 'del(.login,.password)' ./instance-crowdsec start rune -0 cscli lapi status + # second connection, test the tls cache + rune -0 cscli lapi status rune -0 cscli machines list -o json rune -0 jq -c '[. | length, .[0].machineId[0:32], .[0].isValidated, .[0].ipAddress, .[0].auth_type]' <(output) assert_output '[1,"localhost@127.0.0.1",true,"127.0.0.1","tls"]' - cscli machines delete localhost@127.0.0.1 + rune -0 cscli machines delete localhost@127.0.0.1 +} + +@test "a machine can still connect with a unix socket, no TLS" { + sock=$(config_get '.api.server.listen_socket') + export sock + + # an agent is a machine too + config_disable_agent + ./instance-crowdsec start + + rune -0 cscli machines add with-socket --auto --force + rune -0 cscli lapi status + + rune -0 cscli machines list -o json + rune -0 jq -c '[. | length, .[0].machineId[0:32], .[0].isValidated, .[0].ipAddress, .[0].auth_type]' <(output) + assert_output '[1,"with-socket",true,"127.0.0.1","password"]' + + # TLS cannot be used with a unix socket + + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' + .ca_cert_path=strenv(tmpdir) + "/bundle.pem" + ' + + rune -1 cscli lapi status + assert_stderr --partial "loading api client: cannot use TLS with a unix socket" + + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' + del(.ca_cert_path) | + .key_path=strenv(tmpdir) + "/leaf-key.pem" + ' + + rune -1 cscli lapi status + assert_stderr --partial "loading api client: cannot use TLS with a unix socket" + + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' + del(.key_path) | + .cert_path=strenv(tmpdir) + "/leaf.pem" + ' + + rune -1 cscli lapi status + assert_stderr --partial "loading api client: cannot use TLS with a unix socket" + + rune -0 cscli machines delete with-socket } @test "invalid cert for agent" { - config_set "${CONFIG_DIR}/local_api_credentials.yaml" ' + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | - .key_path=strenv(tmpdir) + "/agent_invalid-key.pem" | - .cert_path=strenv(tmpdir) + "/agent_invalid.pem" | + .key_path=strenv(tmpdir) + "/leaf_invalid-key.pem" | + .cert_path=strenv(tmpdir) + "/leaf_invalid.pem" | .url="https://127.0.0.1:8080" ' - config_set "${CONFIG_DIR}/local_api_credentials.yaml" 'del(.login,.password)' + config_set "$CONFIG_DIR/local_api_credentials.yaml" 'del(.login,.password)' ./instance-crowdsec start + rune -1 cscli lapi status rune -0 cscli machines list -o json assert_output '[]' } @test "revoked cert for agent" { - config_set "${CONFIG_DIR}/local_api_credentials.yaml" ' - .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | - .key_path=strenv(tmpdir) + "/agent_revoked-key.pem" | - .cert_path=strenv(tmpdir) + "/agent_revoked.pem" | - .url="https://127.0.0.1:8080" - 
' + # we have two certificates revoked by different CRL blocks + # we connect twice to test the cache too + for cert in "leaf_rev1" "leaf_rev2" "leaf_rev1" "leaf_rev2"; do + truncate_log + cert="$cert" config_set "$CONFIG_DIR/local_api_credentials.yaml" ' + .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | + .key_path=strenv(tmpdir) + "/" + strenv(cert) + "-key.pem" | + .cert_path=strenv(tmpdir) + "/" + strenv(cert) + ".pem" | + .url="https://127.0.0.1:8080" + ' - config_set "${CONFIG_DIR}/local_api_credentials.yaml" 'del(.login,.password)' - ./instance-crowdsec start - rune -0 cscli machines list -o json - assert_output '[]' + config_set "$CONFIG_DIR/local_api_credentials.yaml" 'del(.login,.password)' + ./instance-crowdsec start + rune -1 cscli lapi status + assert_log --partial "certificate revoked by CRL" + rune -0 cscli machines list -o json + assert_output '[]' + ./instance-crowdsec stop + done } + +# vvv this test must be last, or it can break the ones that follow + +@test "allowed_ou can't contain an empty string" { + config_set ' + .common.log_media="stdout" | + .api.server.tls.agents_allowed_ou=["agent-ou", ""] + ' + rune -1 wait-for "$CROWDSEC" + assert_stderr --partial "allowed_ou configuration contains invalid empty string" +} + +# ^^^ this test must be last, or it can break the ones that follow diff --git a/test/bats/40_cold-logs.bats b/test/bats/40_cold-logs.bats index ad4d5233cc0..070a9eac5f1 100644 --- a/test/bats/40_cold-logs.bats +++ b/test/bats/40_cold-logs.bats @@ -11,9 +11,13 @@ fake_log() { setup_file() { load "../lib/setup_file.sh" - # we reset config and data, and only run the daemon once for all the tests in this file ./instance-data load + + cscli collections install crowdsecurity/sshd --error >/dev/null + cscli parsers install crowdsecurity/syslog-logs --error >/dev/null + cscli parsers install crowdsecurity/dateparse-enrich --error >/dev/null + ./instance-crowdsec start } @@ -28,14 +32,14 @@ setup() { #---------- @test "-type and -dsn are required together" { - rune -1 "${CROWDSEC}" -no-api -type syslog + rune -1 "$CROWDSEC" -no-api -type syslog assert_stderr --partial "-type requires a -dsn argument" - rune -1 "${CROWDSEC}" -no-api -dsn file:///dev/fd/0 + rune -1 "$CROWDSEC" -no-api -dsn file:///dev/fd/0 assert_stderr --partial "-dsn requires a -type argument" } @test "the one-shot mode works" { - rune -0 "${CROWDSEC}" -dsn file://<(fake_log) -type syslog -no-api + rune -0 "$CROWDSEC" -dsn file://<(fake_log) -type syslog -no-api refute_output assert_stderr --partial "single file mode : log_media=stdout daemonize=false" assert_stderr --regexp "Adding file .* to filelist" @@ -66,7 +70,7 @@ setup() { @test "1.1.1.172 has not been banned (range/NOT-contained: -r 1.1.2.0/24)" { rune -0 cscli decisions list -r 1.1.2.0/24 -o json - assert_output 'null' + assert_json '[]' } @test "1.1.1.172 has been banned (exact: -i 1.1.1.172)" { @@ -77,5 +81,5 @@ setup() { @test "1.1.1.173 has not been banned (exact: -i 1.1.1.173)" { rune -0 cscli decisions list -i 1.1.1.173 -o json - assert_output 'null' + assert_json '[]' } diff --git a/test/bats/40_live-ban.bats b/test/bats/40_live-ban.bats index c410cbce5a0..fb5fd1fd435 100644 --- a/test/bats/40_live-ban.bats +++ b/test/bats/40_live-ban.bats @@ -13,6 +13,10 @@ setup_file() { load "../lib/setup_file.sh" # we reset config and data, but run the daemon only in the tests that need it ./instance-data load + + cscli collections install crowdsecurity/sshd --error >/dev/null + cscli parsers install crowdsecurity/syslog-logs --error 
>/dev/null + cscli parsers install crowdsecurity/dateparse-enrich --error >/dev/null } teardown_file() { @@ -30,16 +34,29 @@ teardown() { #---------- @test "1.1.1.172 has been banned" { - tmpfile=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp) - touch "${tmpfile}" + tmpfile=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp) + touch "$tmpfile" ACQUIS_YAML=$(config_get '.crowdsec_service.acquisition_path') - echo -e "---\nfilename: ${tmpfile}\nlabels:\n type: syslog\n" >>"${ACQUIS_YAML}" + echo -e "---\nfilename: ${tmpfile}\nlabels:\n type: syslog\n" >>"$ACQUIS_YAML" ./instance-crowdsec start - fake_log >>"${tmpfile}" - sleep 2 - rm -f -- "${tmpfile}" - rune -0 cscli decisions list -o json - rune -0 jq -r '.[].decisions[0].value' <(output) - assert_output '1.1.1.172' + + sleep 0.2 + + fake_log >>"$tmpfile" + + sleep 0.2 + + rm -f -- "$tmpfile" + + found=0 + # this may take some time in CI + for _ in $(seq 1 10); do + if cscli decisions list -o json | jq -r '.[].decisions[0].value' | grep -q '1.1.1.172'; then + found=1 + break + fi + sleep 0.2 + done + assert_equal 1 "$found" } diff --git a/test/bats/50_simulation.bats b/test/bats/50_simulation.bats index 578dcf81a31..bffa50cbccc 100644 --- a/test/bats/50_simulation.bats +++ b/test/bats/50_simulation.bats @@ -12,6 +12,11 @@ fake_log() { setup_file() { load "../lib/setup_file.sh" ./instance-data load + + cscli collections install crowdsecurity/sshd --error >/dev/null + cscli parsers install crowdsecurity/syslog-logs --error >/dev/null + cscli parsers install crowdsecurity/dateparse-enrich --error >/dev/null + ./instance-crowdsec start } @@ -28,7 +33,7 @@ setup() { @test "we have one decision" { rune -0 cscli simulation disable --global - fake_log | "${CROWDSEC}" -dsn file:///dev/fd/0 -type syslog -no-api + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api rune -0 cscli decisions list -o json rune -0 jq '. 
| length' <(output) assert_output 1 @@ -36,7 +41,7 @@ setup() { @test "1.1.1.174 has been banned (exact)" { rune -0 cscli simulation disable --global - fake_log | "${CROWDSEC}" -dsn file:///dev/fd/0 -type syslog -no-api + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api rune -0 cscli decisions list -o json rune -0 jq -r '.[].decisions[0].value' <(output) assert_output '1.1.1.174' @@ -44,7 +49,7 @@ setup() { @test "decision has simulated == false (exact)" { rune -0 cscli simulation disable --global - fake_log | "${CROWDSEC}" -dsn file:///dev/fd/0 -type syslog -no-api + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api rune -0 cscli decisions list -o json rune -0 jq '.[].decisions[0].simulated' <(output) assert_output 'false' @@ -52,15 +57,28 @@ setup() { @test "simulated scenario, listing non-simulated: expect no decision" { rune -0 cscli simulation enable crowdsecurity/ssh-bf - fake_log | "${CROWDSEC}" -dsn file:///dev/fd/0 -type syslog -no-api + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api + rune -0 cscli decisions list --no-simu -o json + assert_json '[]' +} + +@test "simulated local scenario: expect no decision" { + CONFIG_DIR=$(dirname "$CONFIG_YAML") + HUB_DIR=$(config_get '.config_paths.hub_dir') + rune -0 mkdir -p "$CONFIG_DIR"/scenarios + # replace an installed scenario with a local version + rune -0 cp -r "$HUB_DIR"/scenarios/crowdsecurity/ssh-bf.yaml "$CONFIG_DIR"/scenarios/ssh-bf2.yaml + rune -0 cscli scenarios remove crowdsecurity/ssh-bf --force --purge + rune -0 cscli simulation enable crowdsecurity/ssh-bf + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api rune -0 cscli decisions list --no-simu -o json - assert_output 'null' + assert_json '[]' } @test "global simulation, listing non-simulated: expect no decision" { rune -0 cscli simulation disable crowdsecurity/ssh-bf rune -0 cscli simulation enable --global - fake_log | "${CROWDSEC}" -dsn file:///dev/fd/0 -type syslog -no-api + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api rune -0 cscli decisions list --no-simu -o json - assert_output 'null' + assert_json '[]' } diff --git a/test/bats/70_http_plugin.bats b/test/bats/70_plugin_http.bats similarity index 84% rename from test/bats/70_http_plugin.bats rename to test/bats/70_plugin_http.bats index a8b860aab83..462fc7c9406 100644 --- a/test/bats/70_http_plugin.bats +++ b/test/bats/70_plugin_http.bats @@ -15,7 +15,7 @@ setup_file() { export MOCK_URL PLUGIN_DIR=$(config_get '.config_paths.plugin_dir') # could have a trailing slash - PLUGIN_DIR=$(realpath "${PLUGIN_DIR}") + PLUGIN_DIR=$(realpath "$PLUGIN_DIR") export PLUGIN_DIR # https://mikefarah.gitbook.io/yq/operators/env-variable-operators @@ -35,10 +35,10 @@ setup_file() { .plugin_config.group="" ' - rm -f -- "${MOCK_OUT}" + rm -f -- "$MOCK_OUT" ./instance-crowdsec start - ./instance-mock-http start "${MOCK_PORT}" + ./instance-mock-http start "$MOCK_PORT" } teardown_file() { @@ -63,24 +63,24 @@ setup() { } @test "expected 1 log line from http server" { - rune -0 wc -l <"${MOCK_OUT}" + rune -0 wc -l <"$MOCK_OUT" # wc can pad with spaces on some platforms rune -0 tr -d ' ' < <(output) assert_output 1 } @test "expected to receive 2 alerts in the request body from plugin" { - rune -0 jq -r '.request_body' <"${MOCK_OUT}" + rune -0 jq -r '.request_body' <"$MOCK_OUT" rune -0 jq -r 'length' <(output) assert_output 2 } @test "expected to receive IP 1.2.3.4 as value of first decision" { - rune -0 jq -r '.request_body[0].decisions[0].value' 
<"${MOCK_OUT}" + rune -0 jq -r '.request_body[0].decisions[0].value' <"$MOCK_OUT" assert_output 1.2.3.4 } @test "expected to receive IP 1.2.3.5 as value of second decision" { - rune -0 jq -r '.request_body[1].decisions[0].value' <"${MOCK_OUT}" + rune -0 jq -r '.request_body[1].decisions[0].value' <"$MOCK_OUT" assert_output 1.2.3.5 } diff --git a/test/bats/71_dummy_plugin.bats b/test/bats/71_plugin_dummy.bats similarity index 83% rename from test/bats/71_dummy_plugin.bats rename to test/bats/71_plugin_dummy.bats index 78352c51459..c242d7ec4bc 100644 --- a/test/bats/71_dummy_plugin.bats +++ b/test/bats/71_plugin_dummy.bats @@ -5,19 +5,19 @@ set -u setup_file() { load "../lib/setup_file.sh" - [[ -n "${PACKAGE_TESTING}" ]] && return + is_package_testing && return ./instance-data load - tempfile=$(TMPDIR="${BATS_FILE_TMPDIR}" mktemp) + tempfile=$(TMPDIR="$BATS_FILE_TMPDIR" mktemp) export tempfile - tempfile2=$(TMPDIR="${BATS_FILE_TMPDIR}" mktemp) + tempfile2=$(TMPDIR="$BATS_FILE_TMPDIR" mktemp) export tempfile2 DUMMY_YAML="$(config_get '.config_paths.notification_dir')/dummy.yaml" - config_set "${DUMMY_YAML}" ' + config_set "$DUMMY_YAML" ' .group_wait="5s" | .group_threshold=2 | .output_file=strenv(tempfile) | @@ -51,7 +51,7 @@ teardown_file() { } setup() { - [[ -n "${PACKAGE_TESTING}" ]] && skip + is_package_testing && skip load "../lib/setup.sh" } @@ -67,12 +67,12 @@ setup() { } @test "expected 1 notification" { - rune -0 cat "${tempfile}" + rune -0 cat "$tempfile" assert_output --partial 1.2.3.4 assert_output --partial 1.2.3.5 } @test "second notification works too" { - rune -0 cat "${tempfile2}" + rune -0 cat "$tempfile2" assert_output --partial secondfile } diff --git a/test/bats/72_plugin_badconfig.bats b/test/bats/72_plugin_badconfig.bats index 9640e333073..7be16c6cf8e 100644 --- a/test/bats/72_plugin_badconfig.bats +++ b/test/bats/72_plugin_badconfig.bats @@ -8,7 +8,7 @@ setup_file() { PLUGIN_DIR=$(config_get '.config_paths.plugin_dir') # could have a trailing slash - PLUGIN_DIR=$(realpath "${PLUGIN_DIR}") + PLUGIN_DIR=$(realpath "$PLUGIN_DIR") export PLUGIN_DIR PROFILES_PATH=$(config_get '.api.server.profiles_path') @@ -26,45 +26,50 @@ setup() { teardown() { ./instance-crowdsec stop - rm -f "${PLUGIN_DIR}"/badname - chmod go-w "${PLUGIN_DIR}"/notification-http + rm -f "$PLUGIN_DIR"/badname + chmod go-w "$PLUGIN_DIR"/notification-http || true } #---------- @test "misconfigured plugin, only user is empty" { config_set '.plugin_config.user="" | .plugin_config.group="nogroup"' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "api server init: unable to run plugin broker: while loading plugin: while getting process attributes: both plugin user and group must be set" + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + rune -0 wait-for \ + --err "api server init: unable to run plugin broker: while loading plugin: while getting process attributes: both plugin user and group must be set" \ + "$CROWDSEC" } @test "misconfigured plugin, only group is empty" { config_set '(.plugin_config.user="nobody") | (.plugin_config.group="")' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "api server init: unable to run plugin broker: while loading plugin: while getting process attributes: both plugin user and group must be set" + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + rune -0 wait-for \ + --err "api server init: 
unable to run plugin broker: while loading plugin: while getting process attributes: both plugin user and group must be set" \ + "$CROWDSEC" } @test "misconfigured plugin, user does not exist" { config_set '(.plugin_config.user="userdoesnotexist") | (.plugin_config.group="groupdoesnotexist")' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "api server init: unable to run plugin broker: while loading plugin: while getting process attributes: user: unknown user userdoesnotexist" + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + rune -0 wait-for \ + --err "api server init: unable to run plugin broker: while loading plugin: while getting process attributes: user: unknown user userdoesnotexist" \ + "$CROWDSEC" } @test "misconfigured plugin, group does not exist" { config_set '(.plugin_config.user=strenv(USER)) | (.plugin_config.group="groupdoesnotexist")' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "api server init: unable to run plugin broker: while loading plugin: while getting process attributes: group: unknown group groupdoesnotexist" + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + rune -0 wait-for \ + --err "api server init: unable to run plugin broker: while loading plugin: while getting process attributes: group: unknown group groupdoesnotexist" \ + "$CROWDSEC" } @test "bad plugin name" { - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - cp "${PLUGIN_DIR}"/notification-http "${PLUGIN_DIR}"/badname - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "api server init: unable to run plugin broker: while loading plugin: plugin name ${PLUGIN_DIR}/badname is invalid. Name should be like {type-name}" + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + cp "$PLUGIN_DIR"/notification-http "$PLUGIN_DIR"/badname + rune -0 wait-for \ + --err "api server init: unable to run plugin broker: while loading plugin: plugin name ${PLUGIN_DIR}/badname is invalid. 
Name should be like {type-name}" \ + "$CROWDSEC" } @test "duplicate notification config" { @@ -72,53 +77,58 @@ teardown() { # email_default has two configurations rune -0 yq -i '.name="email_default"' "$CONFIG_DIR/notifications/http.yaml" # enable a notification, otherwise plugins are ignored - config_set "${PROFILES_PATH}" '.notifications=["slack_default"]' - # we want to check the logs + config_set "$PROFILES_PATH" '.notifications=["slack_default"]' + # the slack plugin may fail or not, but we just need the logs config_set '.common.log_media="stdout"' - # the command will fail because slack_deault is not working - run -1 --separate-stderr timeout 2s "${CROWDSEC}" - # but we have what we wanted - assert_stderr --partial "notification 'email_default' is defined multiple times" + rune wait-for \ + --err "notification 'email_default' is defined multiple times" \ + "$CROWDSEC" } @test "bad plugin permission (group writable)" { - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - chmod g+w "${PLUGIN_DIR}"/notification-http - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "api server init: unable to run plugin broker: while loading plugin: plugin at ${PLUGIN_DIR}/notification-http is group writable, group writable plugins are invalid" + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + chmod g+w "$PLUGIN_DIR"/notification-http + rune -0 wait-for \ + --err "api server init: unable to run plugin broker: while loading plugin: plugin at ${PLUGIN_DIR}/notification-http is group writable, group writable plugins are invalid" \ + "$CROWDSEC" } @test "bad plugin permission (world writable)" { - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - chmod o+w "${PLUGIN_DIR}"/notification-http - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "api server init: unable to run plugin broker: while loading plugin: plugin at ${PLUGIN_DIR}/notification-http is world writable, world writable plugins are invalid" + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + chmod o+w "$PLUGIN_DIR"/notification-http + rune -0 wait-for \ + --err "api server init: unable to run plugin broker: while loading plugin: plugin at ${PLUGIN_DIR}/notification-http is world writable, world writable plugins are invalid" \ + "$CROWDSEC" } @test "config.yaml: missing .plugin_config section" { config_set 'del(.plugin_config)' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "api server init: plugins are enabled, but the plugin_config section is missing in the configuration" + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + rune -0 wait-for \ + --err "api server init: plugins are enabled, but the plugin_config section is missing in the configuration" \ + "$CROWDSEC" } @test "config.yaml: missing config_paths.notification_dir" { config_set 'del(.config_paths.notification_dir)' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "api server init: plugins are enabled, but config_paths.notification_dir is not defined" + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + rune -0 wait-for \ + --err "api server init: plugins are enabled, but config_paths.notification_dir is not defined" \ + "$CROWDSEC" } @test "config.yaml: missing config_paths.plugin_dir" { config_set 'del(.config_paths.plugin_dir)' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - rune -1 timeout 2s 
"${CROWDSEC}" - assert_stderr --partial "api server init: plugins are enabled, but config_paths.plugin_dir is not defined" + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + rune -0 wait-for \ + --err "api server init: plugins are enabled, but config_paths.plugin_dir is not defined" \ + "$CROWDSEC" } @test "unable to run plugin broker: while reading plugin config" { config_set '.config_paths.notification_dir="/this/path/does/not/exist"' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - rune -1 timeout 2s "${CROWDSEC}" - assert_stderr --partial "api server init: unable to run plugin broker: while loading plugin config: open /this/path/does/not/exist: no such file or directory" + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + rune -0 wait-for \ + --err "api server init: unable to run plugin broker: while loading plugin config: open /this/path/does/not/exist: no such file or directory" \ + "$CROWDSEC" } diff --git a/test/bats/73_plugin_formatting.bats b/test/bats/73_plugin_formatting.bats new file mode 100644 index 00000000000..9ed64837403 --- /dev/null +++ b/test/bats/73_plugin_formatting.bats @@ -0,0 +1,65 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + is_package_testing && return + + ./instance-data load + + tempfile=$(TMPDIR="$BATS_FILE_TMPDIR" mktemp) + export tempfile + + DUMMY_YAML="$(config_get '.config_paths.notification_dir')/dummy.yaml" + + # we test the template that is suggested in the email notification + # the $alert is not a shell variable + # shellcheck disable=SC2016 + config_set "$DUMMY_YAML" ' + .group_wait="5s" | + .group_threshold=2 | + .output_file=strenv(tempfile) | + .format=" {{range . -}} {{$alert := . -}} {{range .Decisions -}}

{{.Value}} will get {{.Type}} for next {{.Duration}} for triggering {{.Scenario}} on machine {{$alert.MachineID}}.

CrowdSec CTI

{{end -}} {{end -}} " + ' + + config_set "$(config_get '.api.server.profiles_path')" ' + .notifications=["dummy_default"] | + .filters=["Alert.GetScope() == \"Ip\""] + ' + + config_set ' + .plugin_config.user="" | + .plugin_config.group="" + ' + + ./instance-crowdsec start +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + is_package_testing && skip + load "../lib/setup.sh" +} + +#---------- + +@test "add two bans" { + rune -0 cscli decisions add --ip 1.2.3.4 --duration 30s + assert_stderr --partial 'Decision successfully added' + + rune -0 cscli decisions add --ip 1.2.3.5 --duration 30s + assert_stderr --partial 'Decision successfully added' + sleep 2 +} + +@test "expected 1 notification" { + rune -0 cat "$tempfile" + assert_output - <<-EOT +

1.2.3.4 will get ban for next 30s for triggering manual 'ban' from 'githubciXXXXXXXXXXXXXXXXXXXXXXXX' on machine githubciXXXXXXXXXXXXXXXXXXXXXXXX.

CrowdSec CTI

1.2.3.5 will get ban for next 30s for triggering manual 'ban' from 'githubciXXXXXXXXXXXXXXXXXXXXXXXX' on machine githubciXXXXXXXXXXXXXXXXXXXXXXXX.

CrowdSec CTI

+ EOT +} diff --git a/test/bats/80_alerts.bats b/test/bats/80_alerts.bats index 0f70223b6bf..6d84c1a1fce 100644 --- a/test/bats/80_alerts.bats +++ b/test/bats/80_alerts.bats @@ -73,9 +73,9 @@ teardown() { rune -0 cscli alerts list -o raw <(output) rune -0 grep 10.20.30.40 <(output) rune -0 cut -d, -f1 <(output) - ALERT_ID="${output}" + ALERT_ID="$output" - rune -0 cscli alerts inspect "${ALERT_ID}" -o human + rune -0 cscli alerts inspect "$ALERT_ID" -o human rune -0 plaintext < <(output) assert_line --regexp '^#+$' assert_line --regexp "^ - ID *: ${ALERT_ID}$" @@ -93,10 +93,10 @@ teardown() { assert_line --regexp "^.* ID .* scope:value .* action .* expiration .* created_at .*$" assert_line --regexp "^.* Ip:10.20.30.40 .* ban .*$" - rune -0 cscli alerts inspect "${ALERT_ID}" -o human --details + rune -0 cscli alerts inspect "$ALERT_ID" -o human --details # XXX can we have something here? - rune -0 cscli alerts inspect "${ALERT_ID}" -o raw + rune -0 cscli alerts inspect "$ALERT_ID" -o raw assert_line --regexp "^ *capacity: 0$" assert_line --regexp "^ *id: ${ALERT_ID}$" assert_line --regexp "^ *origin: cscli$" @@ -106,11 +106,11 @@ teardown() { assert_line --regexp "^ *type: ban$" assert_line --regexp "^ *value: 10.20.30.40$" - rune -0 cscli alerts inspect "${ALERT_ID}" -o json + rune -0 cscli alerts inspect "$ALERT_ID" -o json alert=${output} - rune jq -c '.decisions[] | [.origin,.scenario,.scope,.simulated,.type,.value]' <<<"${alert}" + rune jq -c '.decisions[] | [.origin,.scenario,.scope,.simulated,.type,.value]' <<<"$alert" assert_output --regexp "\[\"cscli\",\"manual 'ban' from 'githubciXXXXXXXXXXXXXXXXXXXXXXXX.*'\",\"Ip\",false,\"ban\",\"10.20.30.40\"\]" - rune jq -c '.source' <<<"${alert}" + rune jq -c '.source' <<<"$alert" assert_json '{ip:"10.20.30.40",scope:"Ip",value:"10.20.30.40"}' } @@ -118,7 +118,7 @@ teardown() { rune -0 cscli alerts list --until 200d -o human assert_output "No active alerts" rune -0 cscli alerts list --until 200d -o json - assert_output "null" + assert_json "[]" rune -0 cscli alerts list --until 200d -o raw assert_output "id,scope,value,reason,country,as,decisions,created_at" rune -0 cscli alerts list --until 200d -o raw --machine @@ -172,7 +172,7 @@ teardown() { rune -0 cscli alerts delete -i 1.2.3.4 assert_stderr --partial 'alert(s) deleted' rune -0 cscli decisions list -o json - assert_output null + assert_json '[]' } @test "cscli alerts delete (must ignore the query limit)" { @@ -188,7 +188,7 @@ teardown() { rune -0 cscli decisions add -i 10.20.30.40 -t ban rune -9 cscli decisions list --ip 10.20.30.40 -o json rune -9 jq -r '.[].decisions[].id' <(output) - DECISION_ID="${output}" + DECISION_ID="$output" ./instance-crowdsec stop rune -0 ./instance-db exec_sql "UPDATE decisions SET ... 
WHERE id=${DECISION_ID}" diff --git a/test/bats/81_alert_context.bats b/test/bats/81_alert_context.bats index 5dbcc733462..69fb4158ffd 100644 --- a/test/bats/81_alert_context.bats +++ b/test/bats/81_alert_context.bats @@ -20,6 +20,9 @@ teardown_file() { setup() { load "../lib/setup.sh" ./instance-data load + cscli collections install crowdsecurity/sshd --error + cscli parsers install crowdsecurity/syslog-logs --error + cscli parsers install crowdsecurity/dateparse-enrich --error } teardown() { @@ -29,8 +32,8 @@ teardown() { #---------- @test "$FILE 1.1.1.172 has context" { - tmpfile=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp) - touch "${tmpfile}" + tmpfile=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp) + touch "$tmpfile" ACQUIS_YAML=$(config_get '.crowdsec_service.acquisition_path') @@ -40,7 +43,12 @@ teardown() { type: syslog EOT - CONTEXT_YAML=$(config_get '.crowdsec_service.console_context_path') + # we set the path here because the default is empty + CONFIG_DIR=$(dirname "$CONFIG_YAML") + CONTEXT_YAML="$CONFIG_DIR/console/context.yaml" + export CONTEXT_YAML + config_set '.crowdsec_service.console_context_path=strenv(CONTEXT_YAML)' + mkdir -p "$CONFIG_DIR/console" cat <<-EOT >"${CONTEXT_YAML}" target_user: @@ -53,9 +61,9 @@ teardown() { ./instance-crowdsec start sleep 2 - fake_log >>"${tmpfile}" + fake_log >>"$tmpfile" sleep 2 - rm -f -- "${tmpfile}" + rm -f -- "$tmpfile" rune -0 cscli alerts list -o json rune -0 jq '.[0].id' <(output) diff --git a/test/bats/90_decisions.bats b/test/bats/90_decisions.bats index bcb410de979..b892dc84015 100644 --- a/test/bats/90_decisions.bats +++ b/test/bats/90_decisions.bats @@ -16,7 +16,10 @@ teardown_file() { setup() { load "../lib/setup.sh" + load "../lib/bats-file/load.bash" ./instance-data load + LOGFILE=$(config_get '.common.log_dir')/crowdsec.log + export LOGFILE ./instance-crowdsec start } @@ -28,12 +31,11 @@ teardown() { @test "'decisions add' requires parameters" { rune -1 cscli decisions add - assert_line "Usage:" - assert_stderr --partial "Missing arguments, a value is required (--ip, --range or --scope and --value)" + assert_stderr --partial "missing arguments, a value is required (--ip, --range or --scope and --value)" rune -1 cscli decisions add -o json rune -0 jq -c '[ .level, .msg]' <(stderr | grep "^{") - assert_output '["fatal","Missing arguments, a value is required (--ip, --range or --scope and --value)"]' + assert_output '["fatal","missing arguments, a value is required (--ip, --range or --scope and --value)"]' } @test "cscli decisions list, with and without --machine" { @@ -106,12 +108,12 @@ teardown() { # invalid json rune -1 cscli decisions import -i - <<<'{"blah":"blah"}' --format json assert_stderr --partial 'Parsing json' - assert_stderr --partial 'json: cannot unmarshal object into Go value of type []main.decisionRaw' + assert_stderr --partial 'json: cannot unmarshal object into Go value of type []clidecision.decisionRaw' # json with extra data rune -1 cscli decisions import -i - <<<'{"values":"1.2.3.4","blah":"blah"}' --format json assert_stderr --partial 'Parsing json' - assert_stderr --partial 'json: cannot unmarshal object into Go value of type []main.decisionRaw' + assert_stderr --partial 'json: cannot unmarshal object into Go value of type []clidecision.decisionRaw' #---------- # CSV @@ -151,6 +153,7 @@ teardown() { assert_stderr --partial 'Parsing values' assert_stderr --partial 'Imported 3 decisions' + # leading or trailing spaces are ignored rune -0 cscli decisions import -i - --format values <<-EOT 10.2.3.4 10.2.3.5 @@ -159,11 
+162,38 @@ teardown() { assert_stderr --partial 'Parsing values' assert_stderr --partial 'Imported 3 decisions' - rune -1 cscli decisions import -i - --format values <<-EOT + # silently discarding (but logging) invalid decisions + + rune -0 cscli alerts delete --all + truncate -s 0 "$LOGFILE" + + rune -0 cscli decisions import -i - --format values <<-EOT whatever EOT assert_stderr --partial 'Parsing values' - assert_stderr --partial 'API error: unable to create alerts: whatever: invalid ip address / range' + assert_stderr --partial 'Imported 1 decisions' + assert_file_contains "$LOGFILE" "invalid addr/range 'whatever': invalid address" + + rune -0 cscli decisions list -a -o json + assert_json '[]' + + # discarding only some invalid decisions + + rune -0 cscli alerts delete --all + truncate -s 0 "$LOGFILE" + + rune -0 cscli decisions import -i - --format values <<-EOT + 1.2.3.4 + bad-apple + 1.2.3.5 + EOT + assert_stderr --partial 'Parsing values' + assert_stderr --partial 'Imported 3 decisions' + assert_file_contains "$LOGFILE" "invalid addr/range 'bad-apple': invalid address" + + rune -0 cscli decisions list -a -o json + rune -0 jq -r '.[0].decisions | length' <(output) + assert_output 2 #---------- # Batch diff --git a/test/bats/97_ipv4_single.bats b/test/bats/97_ipv4_single.bats index c42836071a6..b709930e2e5 100644 --- a/test/bats/97_ipv4_single.bats +++ b/test/bats/97_ipv4_single.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -22,22 +20,17 @@ setup() { if is_db_mysql; then sleep 0.3; fi } -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - #---------- @test "cli - first decisions list: must be empty" { # delete community pull rune -0 cscli decisions delete --all rune -0 cscli decisions list -o json - assert_output 'null' + assert_json '[]' } @test "API - first decisions list: must be empty" { - rune -0 api '/v1/decisions' + rune -0 curl-with-key '/v1/decisions' assert_output 'null' } @@ -53,7 +46,7 @@ api() { } @test "API - all decisions" { - rune -0 api '/v1/decisions' + rune -0 curl-with-key '/v1/decisions' rune -0 jq -c '[ .
| length, .[0].value ]' <(output) assert_output '[1,"1.2.3.4"]' } @@ -67,18 +60,18 @@ api() { } @test "API - decision for 1.2.3.4" { - rune -0 api '/v1/decisions?ip=1.2.3.4' + rune -0 curl-with-key '/v1/decisions?ip=1.2.3.4' rune -0 jq -r '.[0].value' <(output) assert_output '1.2.3.4' } @test "CLI - decision for 1.2.3.5" { rune -0 cscli decisions list -i '1.2.3.5' -o json - assert_output 'null' + assert_json '[]' } @test "API - decision for 1.2.3.5" { - rune -0 api '/v1/decisions?ip=1.2.3.5' + rune -0 curl-with-key '/v1/decisions?ip=1.2.3.5' assert_output 'null' } @@ -86,11 +79,11 @@ api() { @test "CLI - decision for 1.2.3.0/24" { rune -0 cscli decisions list -r '1.2.3.0/24' -o json - assert_output 'null' + assert_json '[]' } @test "API - decision for 1.2.3.0/24" { - rune -0 api '/v1/decisions?range=1.2.3.0/24' + rune -0 curl-with-key '/v1/decisions?range=1.2.3.0/24' assert_output 'null' } @@ -101,7 +94,7 @@ api() { } @test "API - decisions where IP in 1.2.3.0/24" { - rune -0 api '/v1/decisions?range=1.2.3.0/24&contains=false' + rune -0 curl-with-key '/v1/decisions?range=1.2.3.0/24&contains=false' rune -0 jq -r '.[0].value' <(output) assert_output '1.2.3.4' } diff --git a/test/bats/97_ipv6_single.bats b/test/bats/97_ipv6_single.bats index 41948fb2597..c7aea030f9c 100644 --- a/test/bats/97_ipv6_single.bats +++ b/test/bats/97_ipv6_single.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -19,12 +17,7 @@ teardown_file() { setup() { load "../lib/setup.sh" - if is_db_mysql; then sleep 0.3; fi -} - -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" + if is_db_mysql; then sleep 0.5; fi } #---------- @@ -33,7 +26,7 @@ api() { # delete community pull rune -0 cscli decisions delete --all rune -0 cscli decisions list -o json - assert_output 'null' + assert_json '[]' } @test "adding decision for ip 1111:2222:3333:4444:5555:6666:7777:8888" { @@ -48,7 +41,7 @@ api() { } @test "API - all decisions" { - rune -0 api "/v1/decisions" + rune -0 curl-with-key "/v1/decisions" rune -0 jq -r '.[].value' <(output) assert_output '1111:2222:3333:4444:5555:6666:7777:8888' } @@ -60,38 +53,38 @@ api() { } @test "API - decisions for ip 1111:2222:3333:4444:5555:6666:7777:888" { - rune -0 api '/v1/decisions?ip=1111:2222:3333:4444:5555:6666:7777:8888' + rune -0 curl-with-key '/v1/decisions?ip=1111:2222:3333:4444:5555:6666:7777:8888' rune -0 jq -r '.[].value' <(output) assert_output '1111:2222:3333:4444:5555:6666:7777:8888' } @test "CLI - decisions for ip 1211:2222:3333:4444:5555:6666:7777:8888" { rune -0 cscli decisions list -i '1211:2222:3333:4444:5555:6666:7777:8888' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for ip 1211:2222:3333:4444:5555:6666:7777:888" { - rune -0 api '/v1/decisions?ip=1211:2222:3333:4444:5555:6666:7777:8888' + rune -0 curl-with-key '/v1/decisions?ip=1211:2222:3333:4444:5555:6666:7777:8888' assert_output 'null' } @test "CLI - decisions for ip 1111:2222:3333:4444:5555:6666:7777:8887" { rune -0 cscli decisions list -i '1111:2222:3333:4444:5555:6666:7777:8887' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for ip 1111:2222:3333:4444:5555:6666:7777:8887" { - rune -0 api '/v1/decisions?ip=1111:2222:3333:4444:5555:6666:7777:8887' + rune -0 curl-with-key '/v1/decisions?ip=1111:2222:3333:4444:5555:6666:7777:8887' assert_output 'null' } @test "CLI - 
decisions for range 1111:2222:3333:4444:5555:6666:7777:8888/48" { rune -0 cscli decisions list -r '1111:2222:3333:4444:5555:6666:7777:8888/48' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for range 1111:2222:3333:4444:5555:6666:7777:8888/48" { - rune -0 api '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/48' + rune -0 curl-with-key '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/48' assert_output 'null' } @@ -102,18 +95,18 @@ api() { } @test "API - decisions for ip/range in 1111:2222:3333:4444:5555:6666:7777:8888/48" { - rune -0 api '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/48&&contains=false' + rune -0 curl-with-key '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/48&&contains=false' rune -0 jq -r '.[].value' <(output) assert_output '1111:2222:3333:4444:5555:6666:7777:8888' } @test "CLI - decisions for range 1111:2222:3333:4444:5555:6666:7777:8888/64" { rune -0 cscli decisions list -r '1111:2222:3333:4444:5555:6666:7777:8888/64' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for range 1111:2222:3333:4444:5555:6666:7777:8888/64" { - rune -0 api '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/64' + rune -0 curl-with-key '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/64' assert_output 'null' } @@ -124,7 +117,7 @@ api() { } @test "API - decisions for ip/range in 1111:2222:3333:4444:5555:6666:7777:8888/64" { - rune -0 api '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/64&&contains=false' + rune -0 curl-with-key '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/64&&contains=false' rune -0 jq -r '.[].value' <(output) assert_output '1111:2222:3333:4444:5555:6666:7777:8888' } @@ -141,7 +134,7 @@ api() { @test "CLI - decisions for ip 1111:2222:3333:4444:5555:6666:7777:8889 after delete" { rune -0 cscli decisions list -i '1111:2222:3333:4444:5555:6666:7777:8889' -o json - assert_output 'null' + assert_json '[]' } @test "deleting decision for range 1111:2222:3333:4444:5555:6666:7777:8888/64" { @@ -151,5 +144,5 @@ api() { @test "CLI - decisions for ip/range in 1111:2222:3333:4444:5555:6666:7777:8888/64 after delete" { rune -0 cscli decisions list -r '1111:2222:3333:4444:5555:6666:7777:8888/64' -o json --contained - assert_output 'null' + assert_json '[]' } diff --git a/test/bats/98_ipv4_range.bats b/test/bats/98_ipv4_range.bats index 1983225b910..c85e40267f3 100644 --- a/test/bats/98_ipv4_range.bats +++ b/test/bats/98_ipv4_range.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -22,18 +20,13 @@ setup() { if is_db_mysql; then sleep 0.3; fi } -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - #---------- @test "cli - first decisions list: must be empty" { # delete community pull rune -0 cscli decisions delete --all rune -0 cscli decisions list -o json - assert_output 'null' + assert_json '[]' } @test "adding decision for range 4.4.4.0/24" { @@ -48,7 +41,7 @@ api() { } @test "API - all decisions" { - rune -0 api '/v1/decisions' + rune -0 curl-with-key '/v1/decisions' rune -0 jq -r '.[0].value' <(output) assert_output '4.4.4.0/24' } @@ -62,38 +55,38 @@ api() { } @test "API - decisions for ip 4.4.4." 
{ - rune -0 api '/v1/decisions?ip=4.4.4.3' + rune -0 curl-with-key '/v1/decisions?ip=4.4.4.3' rune -0 jq -r '.[0].value' <(output) assert_output '4.4.4.0/24' } @test "CLI - decisions for ip contained in 4.4.4." { rune -0 cscli decisions list -i '4.4.4.4' -o json --contained - assert_output 'null' + assert_json '[]' } @test "API - decisions for ip contained in 4.4.4." { - rune -0 api '/v1/decisions?ip=4.4.4.4&contains=false' + rune -0 curl-with-key '/v1/decisions?ip=4.4.4.4&contains=false' assert_output 'null' } @test "CLI - decisions for ip 5.4.4." { rune -0 cscli decisions list -i '5.4.4.3' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for ip 5.4.4." { - rune -0 api '/v1/decisions?ip=5.4.4.3' + rune -0 curl-with-key '/v1/decisions?ip=5.4.4.3' assert_output 'null' } @test "CLI - decisions for range 4.4.0.0/1" { rune -0 cscli decisions list -r '4.4.0.0/16' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for range 4.4.0.0/1" { - rune -0 api '/v1/decisions?range=4.4.0.0/16' + rune -0 curl-with-key '/v1/decisions?range=4.4.0.0/16' assert_output 'null' } @@ -104,7 +97,7 @@ api() { } @test "API - decisions for ip/range in 4.4.0.0/1" { - rune -0 api '/v1/decisions?range=4.4.0.0/16&contains=false' + rune -0 curl-with-key '/v1/decisions?range=4.4.0.0/16&contains=false' rune -0 jq -r '.[0].value' <(output) assert_output '4.4.4.0/24' } @@ -118,17 +111,17 @@ api() { } @test "API - decisions for range 4.4.4.2/2" { - rune -0 api '/v1/decisions?range=4.4.4.2/28' + rune -0 curl-with-key '/v1/decisions?range=4.4.4.2/28' rune -0 jq -r '.[].value' <(output) assert_output '4.4.4.0/24' } @test "CLI - decisions for range 4.4.3.2/2" { rune -0 cscli decisions list -r '4.4.3.2/28' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for range 4.4.3.2/2" { - rune -0 api '/v1/decisions?range=4.4.3.2/28' + rune -0 curl-with-key '/v1/decisions?range=4.4.3.2/28' assert_output 'null' } diff --git a/test/bats/98_ipv6_range.bats b/test/bats/98_ipv6_range.bats index b85f0dfcde9..531122a5533 100644 --- a/test/bats/98_ipv6_range.bats +++ b/test/bats/98_ipv6_range.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -22,18 +20,13 @@ setup() { if is_db_mysql; then sleep 0.3; fi } -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - #---------- @test "cli - first decisions list: must be empty" { # delete community pull rune -0 cscli decisions delete --all rune -0 cscli decisions list -o json - assert_output 'null' + assert_json '[]' } @test "adding decision for range aaaa:2222:3333:4444::/64" { @@ -48,7 +41,7 @@ api() { } @test "API - all decisions (2)" { - rune -0 api '/v1/decisions' + rune -0 curl-with-key '/v1/decisions' rune -0 jq -r '.[].value' <(output) assert_output 'aaaa:2222:3333:4444::/64' } @@ -62,28 +55,28 @@ api() { } @test "API - decisions for ip aaaa:2222:3333:4444:5555:6666:7777:8888" { - rune -0 api '/v1/decisions?ip=aaaa:2222:3333:4444:5555:6666:7777:8888' + rune -0 curl-with-key '/v1/decisions?ip=aaaa:2222:3333:4444:5555:6666:7777:8888' rune -0 jq -r '.[].value' <(output) assert_output 'aaaa:2222:3333:4444::/64' } @test "CLI - decisions for ip aaaa:2222:3333:4445:5555:6666:7777:8888" { rune -0 cscli decisions list -i 'aaaa:2222:3333:4445:5555:6666:7777:8888' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions 
for ip aaaa:2222:3333:4445:5555:6666:7777:8888" { - rune -0 api '/v1/decisions?ip=aaaa:2222:3333:4445:5555:6666:7777:8888' + rune -0 curl-with-key '/v1/decisions?ip=aaaa:2222:3333:4445:5555:6666:7777:8888' assert_output 'null' } @test "CLI - decisions for ip aaa1:2222:3333:4444:5555:6666:7777:8887" { rune -0 cscli decisions list -i 'aaa1:2222:3333:4444:5555:6666:7777:8887' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for ip aaa1:2222:3333:4444:5555:6666:7777:8887" { - rune -0 api '/v1/decisions?ip=aaa1:2222:3333:4444:5555:6666:7777:8887' + rune -0 curl-with-key '/v1/decisions?ip=aaa1:2222:3333:4444:5555:6666:7777:8887' assert_output 'null' } @@ -96,29 +89,29 @@ api() { } @test "API - decisions for range aaaa:2222:3333:4444:5555::/80" { - rune -0 api '/v1/decisions?range=aaaa:2222:3333:4444:5555::/80' + rune -0 curl-with-key '/v1/decisions?range=aaaa:2222:3333:4444:5555::/80' rune -0 jq -r '.[].value' <(output) assert_output 'aaaa:2222:3333:4444::/64' } @test "CLI - decisions for range aaaa:2222:3333:4441:5555::/80" { rune -0 cscli decisions list -r 'aaaa:2222:3333:4441:5555::/80' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for range aaaa:2222:3333:4441:5555::/80" { - rune -0 api '/v1/decisions?range=aaaa:2222:3333:4441:5555::/80' + rune -0 curl-with-key '/v1/decisions?range=aaaa:2222:3333:4441:5555::/80' assert_output 'null' } @test "CLI - decisions for range aaa1:2222:3333:4444:5555::/80" { rune -0 cscli decisions list -r 'aaa1:2222:3333:4444:5555::/80' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for range aaa1:2222:3333:4444:5555::/80" { - rune -0 api '/v1/decisions?range=aaa1:2222:3333:4444:5555::/80' + rune -0 curl-with-key '/v1/decisions?range=aaa1:2222:3333:4444:5555::/80' assert_output 'null' } @@ -126,11 +119,11 @@ api() { @test "CLI - decisions for range aaaa:2222:3333:4444:5555:6666:7777:8888/48" { rune -0 cscli decisions list -r 'aaaa:2222:3333:4444:5555:6666:7777:8888/48' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for range aaaa:2222:3333:4444:5555:6666:7777:8888/48" { - rune -0 api '/v1/decisions?range=aaaa:2222:3333:4444:5555:6666:7777:8888/48' + rune -0 curl-with-key '/v1/decisions?range=aaaa:2222:3333:4444:5555:6666:7777:8888/48' assert_output 'null' } @@ -141,18 +134,18 @@ api() { } @test "API - decisions for ip/range in aaaa:2222:3333:4444:5555:6666:7777:8888/48" { - rune -0 api '/v1/decisions?range=aaaa:2222:3333:4444:5555:6666:7777:8888/48&contains=false' + rune -0 curl-with-key '/v1/decisions?range=aaaa:2222:3333:4444:5555:6666:7777:8888/48&contains=false' rune -0 jq -r '.[].value' <(output) assert_output 'aaaa:2222:3333:4444::/64' } @test "CLI - decisions for ip/range in aaaa:2222:3333:4445:5555:6666:7777:8888/48" { rune -0 cscli decisions list -r 'aaaa:2222:3333:4445:5555:6666:7777:8888/48' -o json - assert_output 'null' + assert_json '[]' } @test "API - decisions for ip/range in aaaa:2222:3333:4445:5555:6666:7777:8888/48" { - rune -0 api '/v1/decisions?range=aaaa:2222:3333:4445:5555:6666:7777:8888/48' + rune -0 curl-with-key '/v1/decisions?range=aaaa:2222:3333:4445:5555:6666:7777:8888/48' assert_output 'null' } @@ -170,18 +163,18 @@ api() { } @test "API - decisions for ip in bbbb:db8:0000:0000:0000:6fff:ffff:ffff" { - rune -0 api '/v1/decisions?ip=bbbb:db8:0000:0000:0000:6fff:ffff:ffff' + rune -0 curl-with-key '/v1/decisions?ip=bbbb:db8:0000:0000:0000:6fff:ffff:ffff' rune -0 jq -r '.[].value' <(output) assert_output 'bbbb:db8::/81' } @test "CLI 
- decisions for ip bbbb:db8:0000:0000:0000:8fff:ffff:ffff" { rune -0 cscli decisions list -o json -i 'bbbb:db8:0000:0000:0000:8fff:ffff:ffff' - assert_output 'null' + assert_json '[]' } @test "API - decisions for ip in bbbb:db8:0000:0000:0000:8fff:ffff:ffff" { - rune -0 api '/v1/decisions?ip=bbbb:db8:0000:0000:0000:8fff:ffff:ffff' + rune -0 curl-with-key '/v1/decisions?ip=bbbb:db8:0000:0000:0000:8fff:ffff:ffff' assert_output 'null' } @@ -192,7 +185,7 @@ api() { @test "CLI - decisions for range aaaa:2222:3333:4444::/64 after delete" { rune -0 cscli decisions list -o json -r 'aaaa:2222:3333:4444::/64' - assert_output 'null' + assert_json '[]' } @test "adding decision for ip bbbb:db8:0000:0000:0000:8fff:ffff:ffff" { diff --git a/test/bats/99_lapi-stream-mode-scenario.bats b/test/bats/99_lapi-stream-mode-scenario.bats index 9b4d562f3c9..32c346061d1 100644 --- a/test/bats/99_lapi-stream-mode-scenario.bats +++ b/test/bats/99_lapi-stream-mode-scenario.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -24,16 +22,10 @@ setup() { #---------- -api() { - URI="$1" - curl -s -H "X-Api-Key:${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - output_new_decisions() { jq -c '.new | map(select(.origin!="CAPI")) | .[] | del(.id) | (.. | .duration?) |= capture("(?[[:digit:]]+h[[:digit:]]+m)").d' <(output) | sort } - @test "adding decisions with different duration, scenario, origin" { # origin: test rune -0 cscli decisions add -i 127.0.0.1 -d 1h -R crowdsecurity/test @@ -62,7 +54,7 @@ output_new_decisions() { } @test "test startup" { - rune -0 api "/v1/decisions/stream?startup=true" + rune -0 curl-with-key "/v1/decisions/stream?startup=true" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"2h59m","origin":"test","scenario":"crowdsecurity/test","scope":"Ip","type":"ban","value":"127.0.0.2"} @@ -71,7 +63,7 @@ output_new_decisions() { } @test "test startup with scenarios containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"2h59m","origin":"another_origin","scenario":"crowdsecurity/ssh_bf","scope":"Ip","type":"ban","value":"127.0.0.1"} @@ -80,7 +72,7 @@ output_new_decisions() { } @test "test startup with multiple scenarios containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf,test" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf,test" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"2h59m","origin":"another_origin","scenario":"crowdsecurity/ssh_bf","scope":"Ip","type":"ban","value":"127.0.0.1"} @@ -89,12 +81,12 @@ output_new_decisions() { } @test "test startup with unknown scenarios containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_containing=unknown" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_containing=unknown" assert_output '{"deleted":null,"new":null}' } @test "test startup with scenarios containing and not containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_containing=test&scenarios_not_containing=ssh_bf" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_containing=test&scenarios_not_containing=ssh_bf" rune -0 output_new_decisions assert_output - <<-EOT 
{"duration":"2h59m","origin":"test","scenario":"crowdsecurity/test","scope":"Ip","type":"ban","value":"127.0.0.2"} @@ -103,7 +95,7 @@ output_new_decisions() { } @test "test startup with scenarios containing and not containing 2" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_containing=longest&scenarios_not_containing=ssh_bf,test" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_containing=longest&scenarios_not_containing=ssh_bf,test" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"4h59m","origin":"test","scenario":"crowdsecurity/longest","scope":"Ip","type":"ban","value":"127.0.0.1"} @@ -111,7 +103,7 @@ output_new_decisions() { } @test "test startup with scenarios not containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh_bf" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh_bf" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"2h59m","origin":"test","scenario":"crowdsecurity/test","scope":"Ip","type":"ban","value":"127.0.0.2"} @@ -120,7 +112,7 @@ output_new_decisions() { } @test "test startup with multiple scenarios not containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh_bf,test" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh_bf,test" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"4h59m","origin":"test","scenario":"crowdsecurity/longest","scope":"Ip","type":"ban","value":"127.0.0.1"} @@ -128,7 +120,7 @@ output_new_decisions() { } @test "test startup with origins parameter" { - rune -0 api "/v1/decisions/stream?startup=true&origins=another_origin" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&origins=another_origin" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"1h59m","origin":"another_origin","scenario":"crowdsecurity/test","scope":"Ip","type":"ban","value":"127.0.0.2"} @@ -137,7 +129,7 @@ output_new_decisions() { } @test "test startup with multiple origins parameter" { - rune -0 api "/v1/decisions/stream?startup=true&origins=another_origin,test" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&origins=another_origin,test" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"2h59m","origin":"test","scenario":"crowdsecurity/test","scope":"Ip","type":"ban","value":"127.0.0.2"} @@ -146,7 +138,7 @@ output_new_decisions() { } @test "test startup with unknown origins" { - rune -0 api "/v1/decisions/stream?startup=true&origins=unknown" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&origins=unknown" assert_output '{"deleted":null,"new":null}' } @@ -230,4 +222,3 @@ output_new_decisions() { # NewChecks: []DecisionCheck{}, # }, #} - diff --git a/test/bats/99_lapi-stream-mode-scopes.bats b/test/bats/99_lapi-stream-mode-scopes.bats index a1d01c489e6..67badebea0e 100644 --- a/test/bats/99_lapi-stream-mode-scopes.bats +++ b/test/bats/99_lapi-stream-mode-scopes.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -23,11 +21,6 @@ setup() { #---------- -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - @test "adding decisions for multiple scopes" { rune -0 cscli decisions add -i '1.2.3.6' assert_stderr --partial 'Decision successfully added' @@ -36,28 +29,28 @@ api() { } @test "stream start (implicit ip 
scope)" { - rune -0 api "/v1/decisions/stream?startup=true" + rune -0 curl-with-key "/v1/decisions/stream?startup=true" rune -0 jq -r '.new' <(output) assert_output --partial '1.2.3.6' refute_output --partial 'toto' } @test "stream start (explicit ip scope)" { - rune -0 api "/v1/decisions/stream?startup=true&scopes=ip" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scopes=ip" rune -0 jq -r '.new' <(output) assert_output --partial '1.2.3.6' refute_output --partial 'toto' } @test "stream start (user scope)" { - rune -0 api "/v1/decisions/stream?startup=true&scopes=user" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scopes=user" rune -0 jq -r '.new' <(output) refute_output --partial '1.2.3.6' assert_output --partial 'toto' } @test "stream start (user+ip scope)" { - rune -0 api "/v1/decisions/stream?startup=true&scopes=user,ip" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scopes=user,ip" rune -0 jq -r '.new' <(output) assert_output --partial '1.2.3.6' assert_output --partial 'toto' diff --git a/test/bats/99_lapi-stream-mode.bats b/test/bats/99_lapi-stream-mode.bats index 08ddde42c5f..b3ee8a434ff 100644 --- a/test/bats/99_lapi-stream-mode.bats +++ b/test/bats/99_lapi-stream-mode.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -23,11 +21,6 @@ setup() { #---------- -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - @test "adding decisions for multiple ips" { rune -0 cscli decisions add -i '1111:2222:3333:4444:5555:6666:7777:8888' assert_stderr --partial 'Decision successfully added' @@ -38,7 +31,7 @@ api() { } @test "stream start" { - rune -0 api "/v1/decisions/stream?startup=true" + rune -0 curl-with-key "/v1/decisions/stream?startup=true" if is_db_mysql; then sleep 3; fi rune -0 jq -r '.new' <(output) assert_output --partial '1111:2222:3333:4444:5555:6666:7777:8888' @@ -49,7 +42,7 @@ api() { @test "stream cont (add)" { rune -0 cscli decisions add -i '1.2.3.5' if is_db_mysql; then sleep 3; fi - rune -0 api "/v1/decisions/stream" + rune -0 curl-with-key "/v1/decisions/stream" rune -0 jq -r '.new' <(output) assert_output --partial '1.2.3.5' } @@ -57,13 +50,13 @@ api() { @test "stream cont (del)" { rune -0 cscli decisions delete -i '1.2.3.4' if is_db_mysql; then sleep 3; fi - rune -0 api "/v1/decisions/stream" + rune -0 curl-with-key "/v1/decisions/stream" rune -0 jq -r '.deleted' <(output) assert_output --partial '1.2.3.4' } @test "stream restart" { - rune -0 api "/v1/decisions/stream?startup=true" + rune -0 curl-with-key "/v1/decisions/stream?startup=true" api_out=${output} rune -0 jq -r '.deleted' <(output) assert_output --partial '1.2.3.4' diff --git a/test/bats/testdata/cfssl/agent.json b/test/bats/testdata/cfssl/agent.json index 693e3aa512b..47b342e5a40 100644 --- a/test/bats/testdata/cfssl/agent.json +++ b/test/bats/testdata/cfssl/agent.json @@ -1,10 +1,10 @@ { - "CN": "localhost", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,5 +12,5 @@ "OU": "agent-ou", "ST": "France" } - ] - } \ No newline at end of file + ] +} diff --git a/test/bats/testdata/cfssl/agent_invalid.json b/test/bats/testdata/cfssl/agent_invalid.json index c61d4dee677..eb7db8d96fb 100644 --- a/test/bats/testdata/cfssl/agent_invalid.json +++ 
b/test/bats/testdata/cfssl/agent_invalid.json @@ -1,10 +1,10 @@ { - "CN": "localhost", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,5 +12,5 @@ "OU": "this-is-not-the-ou-youre-looking-for", "ST": "France" } - ] - } \ No newline at end of file + ] +} diff --git a/test/bats/testdata/cfssl/bouncer.json b/test/bats/testdata/cfssl/bouncer.json index 9a07f576610..bf642c48ad8 100644 --- a/test/bats/testdata/cfssl/bouncer.json +++ b/test/bats/testdata/cfssl/bouncer.json @@ -1,10 +1,10 @@ { - "CN": "localhost", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,5 +12,5 @@ "OU": "bouncer-ou", "ST": "France" } - ] - } \ No newline at end of file + ] +} diff --git a/test/bats/testdata/cfssl/bouncer_invalid.json b/test/bats/testdata/cfssl/bouncer_invalid.json index c61d4dee677..eb7db8d96fb 100644 --- a/test/bats/testdata/cfssl/bouncer_invalid.json +++ b/test/bats/testdata/cfssl/bouncer_invalid.json @@ -1,10 +1,10 @@ { - "CN": "localhost", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,5 +12,5 @@ "OU": "this-is-not-the-ou-youre-looking-for", "ST": "France" } - ] - } \ No newline at end of file + ] +} diff --git a/test/bats/testdata/cfssl/ca.json b/test/bats/testdata/cfssl/ca.json deleted file mode 100644 index ed907e0375b..00000000000 --- a/test/bats/testdata/cfssl/ca.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "CN": "CrowdSec Test CA", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ - { - "C": "FR", - "L": "Paris", - "O": "Crowdsec", - "OU": "Crowdsec", - "ST": "France" - } - ] -} \ No newline at end of file diff --git a/test/bats/testdata/cfssl/intermediate.json b/test/bats/testdata/cfssl/ca_intermediate.json similarity index 53% rename from test/bats/testdata/cfssl/intermediate.json rename to test/bats/testdata/cfssl/ca_intermediate.json index 3996ce6e189..34f1583da06 100644 --- a/test/bats/testdata/cfssl/intermediate.json +++ b/test/bats/testdata/cfssl/ca_intermediate.json @@ -1,10 +1,10 @@ { - "CN": "CrowdSec Test CA Intermediate", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "CrowdSec Test CA Intermediate", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,8 +12,8 @@ "OU": "Crowdsec Intermediate", "ST": "France" } - ], - "ca": { + ], + "ca": { "expiry": "42720h" } - } \ No newline at end of file +} diff --git a/test/bats/testdata/cfssl/ca_root.json b/test/bats/testdata/cfssl/ca_root.json new file mode 100644 index 00000000000..a0d64796637 --- /dev/null +++ b/test/bats/testdata/cfssl/ca_root.json @@ -0,0 +1,16 @@ +{ + "CN": "CrowdSec Test CA", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ + { + "C": "FR", + "L": "Paris", + "O": "Crowdsec", + "OU": "Crowdsec", + "ST": "France" + } + ] +} diff --git a/test/bats/testdata/cfssl/profiles.json b/test/bats/testdata/cfssl/profiles.json index d0dfced4a47..47611beb64c 100644 --- a/test/bats/testdata/cfssl/profiles.json +++ b/test/bats/testdata/cfssl/profiles.json @@ -1,44 +1,37 @@ { - "signing": { - "default": { + "signing": { + "default": { + "expiry": "8760h" + }, + "profiles": { + "intermediate_ca": { + "usages": [ + "signing", + "key encipherment", + "cert sign", + "crl sign", + 
"server auth", + "client auth" + ], + "expiry": "8760h", + "ca_constraint": { + "is_ca": true, + "max_path_len": 0, + "max_path_len_zero": true + } + }, + "server": { + "usages": [ + "server auth" + ], "expiry": "8760h" }, - "profiles": { - "intermediate_ca": { - "usages": [ - "signing", - "digital signature", - "key encipherment", - "cert sign", - "crl sign", - "server auth", - "client auth" - ], - "expiry": "8760h", - "ca_constraint": { - "is_ca": true, - "max_path_len": 0, - "max_path_len_zero": true - } - }, - "server": { - "usages": [ - "signing", - "digital signing", - "key encipherment", - "server auth" - ], - "expiry": "8760h" - }, - "client": { - "usages": [ - "signing", - "digital signature", - "key encipherment", - "client auth" - ], - "expiry": "8760h" - } + "client": { + "usages": [ + "client auth" + ], + "expiry": "8760h" } } - } \ No newline at end of file + } +} diff --git a/test/bats/testdata/cfssl/server.json b/test/bats/testdata/cfssl/server.json index 37018259e95..cce97037ca7 100644 --- a/test/bats/testdata/cfssl/server.json +++ b/test/bats/testdata/cfssl/server.json @@ -1,10 +1,10 @@ { - "CN": "localhost", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,9 +12,9 @@ "OU": "Crowdsec Server", "ST": "France" } - ], - "hosts": [ - "127.0.0.1", - "localhost" - ] - } \ No newline at end of file + ], + "hosts": [ + "127.0.0.1", + "localhost" + ] +} diff --git a/test/bats/testdata/explain/explain-log.txt b/test/bats/testdata/explain/explain-log.txt index cf83570db6c..76247412c5c 100644 --- a/test/bats/testdata/explain/explain-log.txt +++ b/test/bats/testdata/explain/explain-log.txt @@ -2,14 +2,10 @@ line: Sep 19 18:33:22 scw-d95986 sshd[24347]: pam_unix(sshd:auth): authenticatio ├ s00-raw | └ đŸŸĸ crowdsecurity/syslog-logs (+12 ~9) ├ s01-parse - | └ đŸŸĸ crowdsecurity/sshd-logs (+8 ~1) - ├ s02-enrich - | ├ đŸŸĸ crowdsecurity/dateparse-enrich (+2 ~2) - | └ đŸŸĸ crowdsecurity/geoip-enrich (+10) + | └ đŸŸĸ crowdsecurity/sshd-logs (+8) ├-------- parser success đŸŸĸ ├ Scenarios ├ đŸŸĸ crowdsecurity/ssh-bf ├ đŸŸĸ crowdsecurity/ssh-bf_user-enum ├ đŸŸĸ crowdsecurity/ssh-slow-bf └ đŸŸĸ crowdsecurity/ssh-slow-bf_user-enum - diff --git a/test/bin/check-requirements b/test/bin/check-requirements index c5580c70237..7c85d365f95 100755 --- a/test/bin/check-requirements +++ b/test/bin/check-requirements @@ -36,12 +36,6 @@ check_jq() { fi } -check_nc() { - if ! command -v nc >/dev/null; then - die "missing required program 'nc' (package 'netcat-openbsd')" - fi -} - check_base64() { if ! 
command -v base64 >/dev/null; then die "missing required program 'base64'" @@ -66,7 +60,6 @@ check_bats_core check_curl check_daemonizer check_jq -check_nc check_base64 check_python3 check_pkill diff --git a/test/bin/decode-jwt b/test/bin/decode-jwt new file mode 100755 index 00000000000..2e36e9a47df --- /dev/null +++ b/test/bin/decode-jwt @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +import base64 +import json +import sys + + +def decode_base64url(data): + # Not the same as "bin/base64 -d": + # + -> - + # / -> _ + # = -> '' + pad = len(data) % 4 + if pad > 0: + data += '=' * (4 - pad) + return base64.urlsafe_b64decode(data) + + +def decode_jwt(token): + token = token.rstrip('\n') + header, payload, signature = token.split('.') + decoded_header = json.loads(decode_base64url(header)) + decoded_payload = json.loads(decode_base64url(payload)) + # the signature is binary, so we don't decode it + + return decoded_header, decoded_payload, signature + + +def main(): + header, payload, signature = decode_jwt(sys.stdin.read()) + out = { + 'header': header, + 'payload': payload, + 'signature': signature, + } + print(json.dumps(out, indent=4)) + + +if __name__ == '__main__': + main() diff --git a/test/bin/generate-hub-tests b/test/bin/generate-hub-tests index 21031285c1d..658cc33a79a 100755 --- a/test/bin/generate-hub-tests +++ b/test/bin/generate-hub-tests @@ -7,44 +7,13 @@ THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) # shellcheck disable=SC1091 . "${THIS_DIR}/../.environment.sh" -cscli() { - "${CSCLI}" "$@" -} - "${TEST_DIR}/instance-data" load hubdir="${LOCAL_DIR}/hub-tests" git clone --depth 1 https://github.com/crowdsecurity/hub.git "${hubdir}" >/dev/null 2>&1 || (cd "${hubdir}"; git pull) -HUBTESTS_BATS="${TEST_DIR}/dyn-bats/hub.bats" - -cat << EOT > "${HUBTESTS_BATS}" -set -u - -setup_file() { - load "../lib/setup_file.sh" -} - -teardown_file() { - load "../lib/teardown_file.sh" -} - -setup() { - load "../lib/setup.sh" -} - -EOT - echo "Generating hub tests..." 
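A quick usage sketch for the decode-jwt helper introduced above (hedged: it assumes the bats environment from test/lib/setup_file.sh is loaded, which puts test/bin on the PATH and provides the lp-get-token machine-login helper added later in this patch; .header.alg is a field present in any JWT):

    # print the signing algorithm from the token header
    TOKEN=$(lp-get-token)
    echo "$TOKEN" | decode-jwt | jq -r '.header.alg'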
-for testname in $("${CSCLI}" --crowdsec "${CROWDSEC}" --cscli "${CSCLI}" hubtest --hub "${hubdir}" list -o json | jq -r '.[] | .Name'); do - cat << EOT >> "${HUBTESTS_BATS}" - -@test "${testname}" { - run "\${CSCLI}" --crowdsec "\${CROWDSEC}" --cscli "\${CSCLI}" --hub "${hubdir}" hubtest run "${testname}" --clean - # in case of error, need to see what went wrong - echo "\$output" - assert_success -} -EOT -done +python3 "$THIS_DIR/generate-hub-tests.py" \ + <("${CSCLI}" --crowdsec "${CROWDSEC}" --cscli "${CSCLI}" hubtest --hub "${hubdir}" list -o json) \ + "${TEST_DIR}/dyn-bats/" diff --git a/test/bin/generate-hub-tests.py b/test/bin/generate-hub-tests.py new file mode 100644 index 00000000000..48f296776d7 --- /dev/null +++ b/test/bin/generate-hub-tests.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 + +import json +import pathlib +import os +import sys +import textwrap + +test_header = """ +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" +} +""" + + +def write_chunk(target_dir, n, chunk): + with open(target_dir / f"hub-{n}.bats", "w") as f: + f.write(test_header) + for test in chunk: + cscli = os.environ['CSCLI'] + crowdsec = os.environ['CROWDSEC'] + testname = test['Name'] + hubdir = os.environ['LOCAL_DIR'] + '/hub-tests' + f.write(textwrap.dedent(f""" + @test "{testname}" {{ + run "{cscli}" \\ + --crowdsec "{crowdsec}" \\ + --cscli "{cscli}" \\ + --hub "{hubdir}" \\ + hubtest run "{testname}" \\ + --clean + echo "$output" + assert_success + }} + """)) + + +def main(): + hubtests_json = sys.argv[1] + target_dir = sys.argv[2] + + with open(hubtests_json) as f: + j = json.load(f) + chunk_size = len(j) // 3 + 1 + n = 1 + for i in range(0, len(j), chunk_size): + chunk = j[i:i + chunk_size] + write_chunk(pathlib.Path(target_dir), n, chunk) + n += 1 + + +if __name__ == "__main__": + main() diff --git a/test/bin/mock-http.py b/test/bin/mock-http.py index 3f26271b400..d11a4ebf717 100644 --- a/test/bin/mock-http.py +++ b/test/bin/mock-http.py @@ -6,6 +6,7 @@ from http.server import HTTPServer, BaseHTTPRequestHandler + class RequestHandler(BaseHTTPRequestHandler): def do_POST(self): request_path = self.path @@ -18,7 +19,7 @@ def do_POST(self): } print(json.dumps(log)) self.send_response(200) - self.send_header('Content-type','application/json') + self.send_header('Content-type', 'application/json') self.end_headers() self.wfile.write(json.dumps({}).encode()) self.wfile.flush() @@ -27,6 +28,7 @@ def do_POST(self): def log_message(self, format, *args): return + def main(argv): try: port = int(argv[1]) @@ -42,6 +44,6 @@ def main(argv): return 0 -if __name__ == "__main__" : +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) sys.exit(main(sys.argv)) diff --git a/test/bin/preload-hub-items b/test/bin/preload-hub-items new file mode 100755 index 00000000000..79e20efbea2 --- /dev/null +++ b/test/bin/preload-hub-items @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +set -eu + +# shellcheck disable=SC1007 +THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) +# shellcheck disable=SC1091 +. "${THIS_DIR}/../.environment.sh" + +# pre-download everything but don't install anything + +echo "Pre-downloading Hub content..." 
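For reference, the download loop below relies on yq's env() operator to pick the right item type out of cscli's JSON listing; a standalone sketch of the same selector, with the item type hardcoded (the output names are illustrative):

    # prints one hub item name per line, e.g. crowdsecurity/linux
    "$CSCLI" collections list -a -o json | itemtype=collections yq '.[env(itemtype)][] | .name'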
+ +start=$(date +%s%N) + +types=$("$CSCLI" hub types -o raw) + +for itemtype in $types; do + ALL_ITEMS=$("$CSCLI" "$itemtype" list -a -o json | itemtype="$itemtype" yq '.[env(itemtype)][] | .name') + if [[ -n "${ALL_ITEMS}" ]]; then + #shellcheck disable=SC2086 + "$CSCLI" "$itemtype" install \ + $ALL_ITEMS \ + --download-only + fi +done + +elapsed=$((($(date +%s%N) - start)/1000000)) +# bash only does integer arithmetic, we could use bc or have some fun with sed +elapsed=$(echo "$elapsed" | sed -e 's/...$/.&/;t' -e 's/.$/.0&/') + +echo " done in $elapsed secs." diff --git a/test/bin/remove-all-hub-items b/test/bin/remove-all-hub-items new file mode 100755 index 00000000000..981602b775a --- /dev/null +++ b/test/bin/remove-all-hub-items @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +set -eu + +# shellcheck disable=SC1007 +THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) +# shellcheck disable=SC1091 +. "${THIS_DIR}/../.environment.sh" + +# remove all hub items, whether they were installed or only downloaded + +echo "Removing Hub content..." + +types=$("$CSCLI" hub types -o raw) + +for itemtype in $types; do + "$CSCLI" "$itemtype" remove --all --force +done + +echo " done." diff --git a/test/bin/wait-for b/test/bin/wait-for new file mode 100755 index 00000000000..b226783d44b --- /dev/null +++ b/test/bin/wait-for @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 + +import asyncio +import argparse +import os +import re +import signal +import sys + +DEFAULT_TIMEOUT = 30 + +# TODO: signal handler to terminate spawned process group when wait-for is killed + +# TODO: better return codes esp. when matches are found +# TODO: multiple patterns (multiple out, err, both) +# TODO: print unmatched patterns + + +async def terminate(p): + # Terminate the process group (shell, crowdsec plugins) + try: + os.killpg(os.getpgid(p.pid), signal.SIGTERM) + except ProcessLookupError: + pass + + +async def monitor(cmd, args, want_out, want_err, timeout): + """Monitor a process and terminate it if a pattern is matched in stdout or stderr. + + Args: + cmd: The command to run. + args: A list of arguments to pass to the command. + want_out: A regular expression pattern to search for in stdout. + want_err: A regular expression pattern to search for in stderr. + timeout: The maximum number of seconds to wait for the process to terminate. + + Returns: + The exit code of the process. + """ + + status = None + + async def read_stream(stream, outstream, pattern): + nonlocal status + if stream is None: + return + while True: + line = await stream.readline() + if line: + line = line.decode('utf-8') + outstream.write(line) + if pattern and pattern.search(line): + await terminate(process) + # this is nasty. + # if we timeout, we want to return a different exit code + # in case of a match, so that the caller can tell + # if the application was still running.
+ # XXX: still not good for match found, but return code != 0 + if timeout != DEFAULT_TIMEOUT: + status = 128 + else: + status = 0 + break + else: + break + + process = await asyncio.create_subprocess_exec( + cmd, + *args, + # capture stdout + stdout=asyncio.subprocess.PIPE, + # capture stderr + stderr=asyncio.subprocess.PIPE, + # disable buffering + bufsize=0, + # create a new process group + # (required to kill child processes when cmd is a shell) + preexec_fn=os.setsid) + + out_regex = re.compile(want_out) if want_out else None + err_regex = re.compile(want_err) if want_err else None + + # Apply a timeout + try: + await asyncio.wait_for( + asyncio.wait([ + asyncio.create_task(process.wait()), + asyncio.create_task(read_stream(process.stdout, sys.stdout, out_regex)), + asyncio.create_task(read_stream(process.stderr, sys.stderr, err_regex)) + ]), timeout) + if status is None: + status = process.returncode + except asyncio.TimeoutError: + await terminate(process) + status = 241 + + # Return the same exit code, stdout and stderr as the spawned process + return status + + +async def main(): + parser = argparse.ArgumentParser( + description='Monitor a process and terminate it if a pattern is matched in stdout or stderr.') + parser.add_argument('cmd', help='The command to run.') + parser.add_argument('args', nargs=argparse.REMAINDER, help='A list of arguments to pass to the command.') + parser.add_argument('--out', default='', help='A regular expression pattern to search for in stdout.') + parser.add_argument('--err', default='', help='A regular expression pattern to search for in stderr.') + parser.add_argument('--timeout', type=float, default=DEFAULT_TIMEOUT) + args = parser.parse_args() + + exit_code = await monitor(args.cmd, args.args, args.out, args.err, args.timeout) + + sys.exit(exit_code) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/test/bin/wait-for-port b/test/bin/wait-for-port index 4c6c55be7e7..72f26bf409c 100755 --- a/test/bin/wait-for-port +++ b/test/bin/wait-for-port @@ -1,42 +1,60 @@ -#!/usr/bin/env bash - -set -eu - -script_name=$0 - -die() { - echo >&2 "$@" - exit 1 -} - -about() { - die "usage: ${script_name} [-q] " -} - -[[ $# -lt 1 ]] && about - -QUIET= -if [[ "$1" == "-q" ]]; then - QUIET=quiet - shift -fi - -[[ $# -lt 1 ]] && about - -port_number=$1 - -# 4 seconds may seem long, but the tests must work on embedded, slow arm boxes too -for _ in $(seq 40); do - nc -z localhost "${port_number}" >/dev/null 2>&1 && exit 0 - sleep .1 -done - -# send to &3 if open -if { true >&3; } 2>/dev/null; then - [[ -z "${QUIET}" ]] && echo "Can't connect to port ${port_number}" >&3 -else - [[ -z "${QUIET}" ]] && echo "Can't connect to port ${port_number}" >&2 -fi - -exit 1 - +#!/usr/bin/env python3 + +import argparse +import os +import socket +import sys +import time + +initial_interval = 0.02 +max_interval = 0.5 + + +def is_fd_open(fd): + try: + os.fstat(fd) + return True + except OSError: + return False + + +# write to file descriptor 3 if it is open (during bats tests), otherwise stderr +def write_error(ex): + fd = 2 + if is_fd_open(3): + fd = 3 + os.write(fd, str(ex).encode()) + + +def wait(host, port, timeout): + t0 = time.perf_counter() + current_interval = initial_interval + while True: + try: + with socket.create_connection((host, port), timeout=timeout): + break + except OSError as ex: + if time.perf_counter() - t0 >= timeout: + raise TimeoutError(f'Timeout waiting for {host}:{port} after {timeout}s') from ex + time.sleep(current_interval) + current_interval 
= min(current_interval * 1.5, max_interval) + + +def main(argv): + parser = argparse.ArgumentParser(description="Check if a port is open.") + parser.add_argument("port", type=int, help="Port number to check") + parser.add_argument("--host", type=str, default="localhost", help="Host to check") + parser.add_argument("-t", "--timeout", type=float, default=10.0, help="Timeout duration in seconds") + parser.add_argument("-q", "--quiet", action="store_true", help="Enable quiet mode") + args = parser.parse_args(argv) + + try: + wait(args.host, args.port, args.timeout) + except TimeoutError as ex: + if not args.quiet: + write_error(ex) + sys.exit(1) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/test/disable-capi b/test/disable-capi index f19bef5314c..b847accae48 100755 --- a/test/disable-capi +++ b/test/disable-capi @@ -5,4 +5,4 @@ THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) # shellcheck disable=SC1091 . "${THIS_DIR}/.environment.sh" -yq e 'del(.api.server.online_client)' -i "${CONFIG_YAML}" +yq e 'del(.api.server.online_client)' -i "$CONFIG_YAML" diff --git a/test/enable-capi b/test/enable-capi index ddbf8764c44..59980e6a059 100755 --- a/test/enable-capi +++ b/test/enable-capi @@ -5,7 +5,7 @@ THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) # shellcheck disable=SC1091 . "${THIS_DIR}/.environment.sh" -online_api_credentials="$(dirname "${CONFIG_YAML}")/online_api_credentials.yaml" +online_api_credentials="$(dirname "$CONFIG_YAML")/online_api_credentials.yaml" export online_api_credentials -yq e '.api.server.online_client.credentials_path=strenv(online_api_credentials)' -i "${CONFIG_YAML}" +yq e '.api.server.online_client.credentials_path=strenv(online_api_credentials)' -i "$CONFIG_YAML" diff --git a/test/instance-crowdsec b/test/instance-crowdsec index d87145c3881..f0cef729693 100755 --- a/test/instance-crowdsec +++ b/test/instance-crowdsec @@ -2,15 +2,15 @@ #shellcheck disable=SC1007 THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) -cd "${THIS_DIR}" || exit 1 +cd "$THIS_DIR" || exit 1 # shellcheck disable=SC1091 . ./.environment.sh backend_script="./lib/init/crowdsec-${INIT_BACKEND}" -if [[ ! -x "${backend_script}" ]]; then +if [[ ! -x "$backend_script" ]]; then echo "unknown init system '${INIT_BACKEND}'" >&2 exit 1 fi -exec "${backend_script}" "$@" +exec "$backend_script" "$@" diff --git a/test/instance-data b/test/instance-data index 02742b4ec85..e7fd05a9e54 100755 --- a/test/instance-data +++ b/test/instance-data @@ -1,16 +1,26 @@ #!/usr/bin/env bash +set -eu + +die() { + echo >&2 "$@" + exit 1 +} + #shellcheck disable=SC1007 THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) -cd "${THIS_DIR}" || exit 1 +cd "$THIS_DIR" || exit 1 # shellcheck disable=SC1091 . ./.environment.sh +if [[ -f "$LOCAL_INIT_DIR/.lock" ]] && [[ "$1" != "unlock" ]]; then + die "init data is locked: are you doing some manual test? if so, please finish what you are doing, run 'instance-data unlock' and retry" +fi + backend_script="./lib/config/config-${CONFIG_BACKEND}" -if [[ ! -x "${backend_script}" ]]; then - echo "unknown config backend '${CONFIG_BACKEND}'" >&2 - exit 1 +if [[ ! 
-x "$backend_script" ]]; then + die "unknown config backend '${CONFIG_BACKEND}'" fi -exec "${backend_script}" "$@" +exec "$backend_script" "$@" diff --git a/test/instance-db b/test/instance-db index fbbc18dc433..de09465bc32 100755 --- a/test/instance-db +++ b/test/instance-db @@ -2,7 +2,7 @@ #shellcheck disable=SC1007 THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) -cd "${THIS_DIR}" || exit 1 +cd "$THIS_DIR" || exit 1 # shellcheck disable=SC1091 . ./.environment.sh @@ -10,9 +10,9 @@ cd "${THIS_DIR}" || exit 1 backend_script="./lib/db/instance-${DB_BACKEND}" -if [[ ! -x "${backend_script}" ]]; then +if [[ ! -x "$backend_script" ]]; then echo "unknown database '${DB_BACKEND}'" >&2 exit 1 fi -exec "${backend_script}" "$@" +exec "$backend_script" "$@" diff --git a/test/instance-mock-http b/test/instance-mock-http index cca19b79e3e..b5a56d3489d 100755 --- a/test/instance-mock-http +++ b/test/instance-mock-http @@ -13,7 +13,7 @@ about() { #shellcheck disable=SC1007 THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) -cd "${THIS_DIR}" +cd "$THIS_DIR" # shellcheck disable=SC1091 . ./.environment.sh @@ -31,7 +31,7 @@ DAEMON_PID=${PID_DIR}/mock-http.pid start_instance() { [[ $# -lt 1 ]] && about daemonize \ - -p "${DAEMON_PID}" \ + -p "$DAEMON_PID" \ -e "${LOG_DIR}/mock-http.err" \ -o "${LOG_DIR}/mock-http.out" \ /usr/bin/env python3 -u "${THIS_DIR}/bin/mock-http.py" "$1" @@ -40,10 +40,10 @@ start_instance() { } stop_instance() { - if [[ -f "${DAEMON_PID}" ]]; then + if [[ -f "$DAEMON_PID" ]]; then # terminate with extreme prejudice, all the application data will be thrown away anyway - kill -9 "$(cat "${DAEMON_PID}")" > /dev/null 2>&1 - rm -f -- "${DAEMON_PID}" + kill -9 "$(cat "$DAEMON_PID")" > /dev/null 2>&1 + rm -f -- "$DAEMON_PID" fi } diff --git a/test/lib/bats-assert b/test/lib/bats-assert index 397c735212b..44913ffe602 160000 --- a/test/lib/bats-assert +++ b/test/lib/bats-assert @@ -1 +1 @@ -Subproject commit 397c735212bf1a06cfdd0cb7806c5a6ea79582bf +Subproject commit 44913ffe6020d1561c4c4d1e26cda8e07a1f374f diff --git a/test/lib/bats-core b/test/lib/bats-core index 6636e2c2ef5..f7defb94362 160000 --- a/test/lib/bats-core +++ b/test/lib/bats-core @@ -1 +1 @@ -Subproject commit 6636e2c2ef5ffe361535cb45fc61682c5ef46b71 +Subproject commit f7defb94362f2053a3e73d13086a167448ea9133 diff --git a/test/lib/bats-file b/test/lib/bats-file index 17fa557f6fe..cb914cdc176 160000 --- a/test/lib/bats-file +++ b/test/lib/bats-file @@ -1 +1 @@ -Subproject commit 17fa557f6fe28a327933e3fa32efef1d211caa5a +Subproject commit cb914cdc176da00e321d3bc92f88383698c701d6 diff --git a/test/lib/bats-support b/test/lib/bats-support index d140a65044b..3c8fadc5097 160000 --- a/test/lib/bats-support +++ b/test/lib/bats-support @@ -1 +1 @@ -Subproject commit d140a65044b2d6810381935ae7f0c94c7023c8c3 +Subproject commit 3c8fadc5097c9acfc96d836dced2bb598e48b009 diff --git a/test/lib/color-formatter b/test/lib/color-formatter new file mode 100755 index 00000000000..aee8d750698 --- /dev/null +++ b/test/lib/color-formatter @@ -0,0 +1,355 @@ +#!/usr/bin/env bash + +# +# Taken from pretty formatter, minus the cursor movements. +# Used in gihtub workflows CI where color is allowed. +# + +set -e + +# shellcheck source=lib/bats-core/formatter.bash +source "$BATS_ROOT/lib/bats-core/formatter.bash" + +BASE_PATH=. 
+BATS_ENABLE_TIMING= + +while [[ "$#" -ne 0 ]]; do + case "$1" in + -T) + BATS_ENABLE_TIMING="-T" + ;; + --base-path) + shift + normalize_base_path BASE_PATH "$1" + ;; + esac + shift +done + +update_count_column_width() { + count_column_width=$((${#count} * 2 + 2)) + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + # additional space for ' in %s sec' + count_column_width=$((count_column_width + ${#SECONDS} + 8)) + fi + # also update dependent value + update_count_column_left +} + +update_screen_width() { + screen_width="$(tput cols)" + # also update dependent value + update_count_column_left +} + +update_count_column_left() { + count_column_left=$((screen_width - count_column_width)) +} + +# avoid unset variables +count=0 +screen_width=80 +update_count_column_width +#update_screen_width +test_result= + +#trap update_screen_width WINCH + +begin() { + test_result= # reset to avoid carrying over result state from previous test + line_backoff_count=0 + #go_to_column 0 + #update_count_column_width + #buffer_with_truncation $((count_column_left - 1)) ' %s' "$name" + #clear_to_end_of_line + #go_to_column $count_column_left + #if [[ -n "$BATS_ENABLE_TIMING" ]]; then + # buffer "%${#count}s/${count} in %s sec" "$index" "$SECONDS" + #else + # buffer "%${#count}s/${count}" "$index" + #fi + #go_to_column 1 + buffer "%${#count}s" "$index" +} + +finish_test() { + #move_up $line_backoff_count + #go_to_column 0 + buffer "$@" + if [[ -n "${TIMEOUT-}" ]]; then + set_color 2 + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + buffer ' [%s (timeout: %s)]' "$TIMING" "$TIMEOUT" + else + buffer ' [timeout: %s]' "$TIMEOUT" + fi + else + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + set_color 2 + buffer ' [%s]' "$TIMING" + fi + fi + advance + move_down $((line_backoff_count - 1)) +} + +pass() { + local TIMING="${1:-}" + finish_test ' ✓ %s' "$name" + test_result=pass +} + +skip() { + local reason="$1" TIMING="${2:-}" + if [[ -n "$reason" ]]; then + reason=": $reason" + fi + finish_test ' - %s (skipped%s)' "$name" "$reason" + test_result=skip +} + +fail() { + local TIMING="${1:-}" + set_color 1 bold + finish_test ' ✗ %s' "$name" + test_result=fail +} + +timeout() { + local TIMING="${1:-}" + set_color 3 bold + TIMEOUT="${2:-}" finish_test ' ✗ %s' "$name" + test_result=timeout +} + +log() { + case ${test_result} in + pass) + clear_color + ;; + fail) + set_color 1 + ;; + timeout) + set_color 3 + ;; + esac + buffer ' %s\n' "$1" + clear_color +} + +summary() { + if [ "$failures" -eq 0 ]; then + set_color 2 bold + else + set_color 1 bold + fi + + buffer '\n%d test' "$count" + if [[ "$count" -ne 1 ]]; then + buffer 's' + fi + + buffer ', %d failure' "$failures" + if [[ "$failures" -ne 1 ]]; then + buffer 's' + fi + + if [[ "$skipped" -gt 0 ]]; then + buffer ', %d skipped' "$skipped" + fi + + if ((timed_out > 0)); then + buffer ', %d timed out' "$timed_out" + fi + + not_run=$((count - passed - failures - skipped - timed_out)) + if [[ "$not_run" -gt 0 ]]; then + buffer ', %d not run' "$not_run" + fi + + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + buffer " in $SECONDS seconds" + fi + + buffer '\n' + clear_color +} + +buffer_with_truncation() { + local width="$1" + shift + local string + + # shellcheck disable=SC2059 + printf -v 'string' -- "$@" + + if [[ "${#string}" -gt "$width" ]]; then + buffer '%s...' 
"${string:0:$((width - 4))}" + else + buffer '%s' "$string" + fi +} + +move_up() { + if [[ $1 -gt 0 ]]; then # avoid moving if we got 0 + buffer '\x1B[%dA' "$1" + fi +} + +move_down() { + if [[ $1 -gt 0 ]]; then # avoid moving if we got 0 + buffer '\x1B[%dB' "$1" + fi +} + +go_to_column() { + local column="$1" + buffer '\x1B[%dG' $((column + 1)) +} + +clear_to_end_of_line() { + buffer '\x1B[K' +} + +advance() { + clear_to_end_of_line + buffer '\n' + clear_color +} + +set_color() { + local color="$1" + local weight=22 + + if [[ "${2:-}" == 'bold' ]]; then + weight=1 + fi + buffer '\x1B[%d;%dm' "$((30 + color))" "$weight" +} + +clear_color() { + buffer '\x1B[0m' +} + +_buffer= + +buffer() { + local content + # shellcheck disable=SC2059 + printf -v content -- "$@" + _buffer+="$content" +} + +prefix_buffer_with() { + local old_buffer="$_buffer" + _buffer='' + "$@" + _buffer="$_buffer$old_buffer" +} + +flush() { + printf '%s' "$_buffer" + _buffer= +} + +finish() { + flush + printf '\n' +} + +trap finish EXIT +trap '' INT + +bats_tap_stream_plan() { + count="$1" + index=0 + passed=0 + failures=0 + skipped=0 + timed_out=0 + name= + update_count_column_width +} + +bats_tap_stream_begin() { + index="$1" + name="$2" + begin + flush +} + +bats_tap_stream_ok() { + index="$1" + name="$2" + ((++passed)) + + pass "${BATS_FORMATTER_TEST_DURATION:-}" +} + +bats_tap_stream_skipped() { + index="$1" + name="$2" + ((++skipped)) + skip "$3" "${BATS_FORMATTER_TEST_DURATION:-}" +} + +bats_tap_stream_not_ok() { + index="$1" + name="$2" + + if [[ ${BATS_FORMATTER_TEST_TIMEOUT-x} != x ]]; then + timeout "${BATS_FORMATTER_TEST_DURATION:-}" "${BATS_FORMATTER_TEST_TIMEOUT}s" + ((++timed_out)) + else + fail "${BATS_FORMATTER_TEST_DURATION:-}" + ((++failures)) + fi + +} + +bats_tap_stream_comment() { # + local scope=$2 + # count the lines we printed after the begin text, + if [[ $line_backoff_count -eq 0 && $scope == begin ]]; then + # if this is the first line after begin, go down one line + buffer "\n" + ((++line_backoff_count)) # prefix-increment to avoid "error" due to returning 0 + fi + + ((++line_backoff_count)) + ((line_backoff_count += ${#1} / screen_width)) # account for linebreaks due to length + log "$1" +} + +bats_tap_stream_suite() { + #test_file="$1" + line_backoff_count=0 + index= + # indicate filename for failures + local file_name="${1#"$BASE_PATH"}" + name="File $file_name" + set_color 4 bold + buffer "%s\n" "$file_name" + clear_color +} + +line_backoff_count=0 +bats_tap_stream_unknown() { # + local scope=$2 + # count the lines we printed after the begin text, (or after suite, in case of syntax errors) + if [[ $line_backoff_count -eq 0 && ($scope == begin || $scope == suite) ]]; then + # if this is the first line after begin, go down one line + buffer "\n" + ((++line_backoff_count)) # prefix-increment to avoid "error" due to returning 0 + fi + + ((++line_backoff_count)) + ((line_backoff_count += ${#1} / screen_width)) # account for linebreaks due to length + buffer "%s\n" "$1" + flush +} + +bats_parse_internal_extended_tap + +summary diff --git a/test/lib/config/config-global b/test/lib/config/config-global index 592a927c2e8..9b2b71c1dd1 100755 --- a/test/lib/config/config-global +++ b/test/lib/config/config-global @@ -38,6 +38,8 @@ DATA_DIR="${LOCAL_DIR}/${REL_DATA_DIR}" export DATA_DIR CONFIG_DIR="${LOCAL_DIR}/${REL_CONFIG_DIR}" export CONFIG_DIR +HUB_DIR="${CONFIG_DIR}/hub" +export HUB_DIR if [[ $(uname) == "OpenBSD" ]]; then TAR=gtar @@ -52,18 +54,33 @@ remove_init_data() { # we need a separate 
function for initializing config when testing package # because we want to test the configuration as well +config_prepare() { + # remove trailing slash from CONFIG_DIR + # since it's assumed to be missing during the tests + yq e -i ' + .api.server.listen_socket="/run/crowdsec.sock" | + .config_paths.config_dir |= sub("/$", "") + ' "${CONFIG_DIR}/config.yaml" +} + make_init_data() { ./bin/assert-crowdsec-not-running || die "Cannot create fixture data." + config_prepare ./instance-db config-yaml ./instance-db setup + # preload some content and data files + "$CSCLI" collections install crowdsecurity/linux --download-only + # sub-items did not respect --download-only + ./bin/remove-all-hub-items + # when installed packages are always using sqlite, so no need to regenerate # local credz for sqlite - [[ "${DB_BACKEND}" == "sqlite" ]] || ${CSCLI} machines add --auto + [[ "${DB_BACKEND}" == "sqlite" ]] || ${CSCLI} machines add githubciXXXXXXXXXXXXXXXXXXXXXXXX --auto --force - mkdir -p "${LOCAL_INIT_DIR}" + mkdir -p "$LOCAL_INIT_DIR" ./instance-db dump "${LOCAL_INIT_DIR}/database" @@ -96,7 +113,6 @@ load_init_data() { ./instance-db restore "${LOCAL_INIT_DIR}/database" } - # --------------------------- [[ $# -lt 1 ]] && about diff --git a/test/lib/config/config-local b/test/lib/config/config-local index c4f97e7b5fd..3e3c806b616 100755 --- a/test/lib/config/config-local +++ b/test/lib/config/config-local @@ -9,7 +9,7 @@ die() { } about() { - die "usage: ${script_name} [make | load | clean]" + die "usage: ${script_name} [make | load | lock | unlock | clean]" } #shellcheck disable=SC1007 @@ -57,12 +57,9 @@ config_generate() { cp ../config/profiles.yaml \ ../config/simulation.yaml \ - ../config/local_api_credentials.yaml \ ../config/online_api_credentials.yaml \ "${CONFIG_DIR}/" - cp ../config/context.yaml "${CONFIG_DIR}/console/" - cp ../config/detect.yaml \ "${HUB_DIR}" @@ -76,14 +73,13 @@ config_generate() { type: syslog EOT - cp ../plugins/notifications/*/{http,email,slack,splunk,dummy}.yaml \ + cp ../cmd/notification-*/*.yaml \ "${CONFIG_DIR}/notifications/" yq e ' .common.daemonize=true | del(.common.pid_dir) | .common.log_level="info" | - .common.force_color_logs=true | .common.log_dir=strenv(LOG_DIR) | .config_paths.config_dir=strenv(CONFIG_DIR) | .config_paths.data_dir=strenv(DATA_DIR) | @@ -97,14 +93,13 @@ config_generate() { .db_config.db_path=strenv(DATA_DIR)+"/crowdsec.db" | .db_config.use_wal=true | .api.client.credentials_path=strenv(CONFIG_DIR)+"/local_api_credentials.yaml" | + .api.server.listen_socket=strenv(DATA_DIR)+"/crowdsec.sock" | .api.server.profiles_path=strenv(CONFIG_DIR)+"/profiles.yaml" | .api.server.console_path=strenv(CONFIG_DIR)+"/console.yaml" | - .crowdsec_service.console_context_path=strenv(CONFIG_DIR) + "/console/context.yaml" | del(.api.server.online_client) ' ../config/config.yaml >"${CONFIG_DIR}/config.yaml" } - make_init_data() { ./bin/assert-crowdsec-not-running || die "Cannot create fixture data." @@ -113,16 +108,21 @@ make_init_data() { mkdir -p "${CONFIG_DIR}/notifications" mkdir -p "${CONFIG_DIR}/hub" mkdir -p "${CONFIG_DIR}/patterns" - mkdir -p "${CONFIG_DIR}/console" cp -a "../config/patterns" "${CONFIG_DIR}/" config_generate # XXX errors from instance-db should be reported... 
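The new lock/unlock commands shown in the usage string above guard the fixture against accidental regeneration while a developer is testing by hand; a short usage sketch (instance-data resolves its own directory, so it can be invoked from the repository root):

    ./test/instance-data make      # build the fixture
    ./test/instance-data lock      # protect it during manual testing
    ./test/instance-data unlock    # allow make/load again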
./instance-db config-yaml ./instance-db setup - "$CSCLI" --warning machines add githubciXXXXXXXXXXXXXXXXXXXXXXXX --auto - "$CSCLI" --warning hub update - "$CSCLI" --warning collections install crowdsecurity/linux + "$CSCLI" --warning hub update --with-content + + # preload some content and data files + "$CSCLI" collections install crowdsecurity/linux --download-only + # sub-items did not respect --download-only + ./bin/remove-all-hub-items + + # force TCP, the default would be unix socket + "$CSCLI" --warning machines add githubciXXXXXXXXXXXXXXXXXXXXXXXX --url http://127.0.0.1:8080 --auto --force mkdir -p "$LOCAL_INIT_DIR" @@ -137,7 +137,16 @@ make_init_data() { remove_init_data } +lock_init_data() { + touch "${LOCAL_INIT_DIR}/.lock" +} + +unlock_init_data() { + rm -f "${LOCAL_INIT_DIR}/.lock" +} + load_init_data() { + [[ -f "${LOCAL_INIT_DIR}/.lock" ]] && die "init data is locked" ./bin/assert-crowdsec-not-running || die "Cannot load fixture data." if [[ ! -f "${LOCAL_INIT_DIR}/init-config-data.tar" ]]; then @@ -156,7 +165,6 @@ load_init_data() { ./instance-db restore "${LOCAL_INIT_DIR}/database" } - # --------------------------- [[ $# -lt 1 ]] && about @@ -168,6 +176,12 @@ case "$1" in load) load_init_data ;; + lock) + lock_init_data + ;; + unlock) + unlock_init_data + ;; clean) remove_init_data ;; diff --git a/test/lib/db/instance-mysql b/test/lib/db/instance-mysql index 6b40c84acba..df38f09761f 100755 --- a/test/lib/db/instance-mysql +++ b/test/lib/db/instance-mysql @@ -21,7 +21,7 @@ about() { check_requirements() { if ! command -v mysql >/dev/null; then - die "missing required program 'mysql' as a mysql client (package mariadb-client-core-10.6 on debian like system)" + die "missing required program 'mysql' as a mysql client (package mariadb-client on debian like system)" fi } diff --git a/test/lib/setup_file.sh b/test/lib/setup_file.sh index a4231c98edb..39a084596e2 100755 --- a/test/lib/setup_file.sh +++ b/test/lib/setup_file.sh @@ -20,6 +20,7 @@ eval "$(debug)" # Allow tests to use relative paths for helper scripts. # shellcheck disable=SC2164 cd "${TEST_DIR}" +export PATH="${TEST_DIR}/bin:${PATH}" # complain if there's a crowdsec running system-wide or leftover from a previous test ./bin/assert-crowdsec-not-running @@ -67,7 +68,9 @@ config_set() { export -f config_set config_disable_agent() { - config_set 'del(.crowdsec_service)' + config_set '.crowdsec_service.enable=false' + # this should be equivalent to: + # config_set 'del(.crowdsec_service)' } export -f config_disable_agent @@ -77,7 +80,9 @@ config_log_stderr() { export -f config_log_stderr config_disable_lapi() { - config_set 'del(.api.server)' + config_set '.api.server.enable=false' + # this should be equivalent to: + # config_set 'del(.api.server)' } export -f config_disable_lapi @@ -112,18 +117,23 @@ output() { } export -f output +is_package_testing() { + [[ "$PACKAGE_TESTING" != "" ]] +} +export -f is_package_testing + is_db_postgres() { - [[ "${DB_BACKEND}" =~ ^postgres|pgx$ ]] + [[ "$DB_BACKEND" =~ ^postgres|pgx$ ]] } export -f is_db_postgres is_db_mysql() { - [[ "${DB_BACKEND}" == "mysql" ]] + [[ "$DB_BACKEND" == "mysql" ]] } export -f is_db_mysql is_db_sqlite() { - [[ "${DB_BACKEND}" == "sqlite" ]] + [[ "$DB_BACKEND" == "sqlite" ]] } export -f is_db_sqlite @@ -145,6 +155,11 @@ assert_log() { } export -f assert_log +cert_serial_number() { + cfssl certinfo -cert "$1" | jq -r '.serial_number' +} +export -f cert_serial_number + # Compare ignoring the key order, and allow "expected" without quoted identifiers. 
# Preserve the output variable in case the following commands require it. assert_json() { @@ -229,6 +244,32 @@ assert_stderr_line() { } export -f assert_stderr_line +# remove all installed items and data +hub_purge_all() { + local CONFIG_DIR + local itemtype + CONFIG_DIR=$(dirname "$CONFIG_YAML") + for itemtype in $(cscli hub types -o raw); do + rm -rf "$CONFIG_DIR"/"${itemtype:?}"/* "$CONFIG_DIR"/hub/"${itemtype:?}"/* + done + local DATA_DIR + DATA_DIR=$(config_get .config_paths.data_dir) + # should remove everything except the db (find $DATA_DIR -not -name "crowdsec.db*" -delete), + # but don't play with fire if there is a misconfiguration + rm -rfv "$DATA_DIR"/GeoLite* +} +export -f hub_purge_all + +# remove unused data from the index, to make sure we don't rely on it in any way +hub_strip_index() { + local INDEX + INDEX=$(config_get .config_paths.index_path) + local hub_min + hub_min=$(jq <"$INDEX" 'del(..|.long_description?) | del(..|.deprecated?) | del (..|.labels?)') + echo "$hub_min" >"$INDEX" +} +export -f hub_strip_index + # remove color and style sequences from stdin plaintext() { sed -E 's/\x1B\[[0-9;]*[JKmsu]//g' @@ -240,3 +281,62 @@ rune() { run --separate-stderr "$@" } export -f rune + +# call the lapi through unix socket +# the path (and query string) must be the first parameter, the others will be passed to curl +curl-socket() { + [[ -z "$1" ]] && { fail "${FUNCNAME[0]}: missing path"; } + local path=$1 + shift + local socket + socket=$(config_get '.api.server.listen_socket') + [[ -z "$socket" ]] && { fail "${FUNCNAME[0]}: missing .api.server.listen_socket"; } + # curl needs a fake hostname when using a unix socket + curl --unix-socket "$socket" "http://lapi$path" "$@" +} +export -f curl-socket + +# call the lapi through tcp +# the path (and query string) must be the first parameter, the others will be passed to curl +curl-tcp() { + [[ -z "$1" ]] && { fail "${FUNCNAME[0]}: missing path"; } + local path=$1 + shift + local cred + cred=$(config_get .api.client.credentials_path) + local base_url + base_url="$(yq '.url' < "$cred")" + curl "$base_url$path" "$@" +} +export -f curl-tcp + +# call the lapi through tcp with an API_KEY (authenticates as a bouncer) +# after $1, pass through extra arguments to curl +curl-with-key() { + [[ -z "$API_KEY" ]] && { fail "${FUNCNAME[0]}: missing API_KEY"; } + curl-tcp "$@" -sS --fail-with-body -H "X-Api-Key: $API_KEY" +} +export -f curl-with-key + +# call the lapi through tcp with a TOKEN (authenticates as a machine) +# after $1, pass through extra arguments to curl +curl-with-token() { + [[ -z "$TOKEN" ]] && { fail "${FUNCNAME[0]}: missing TOKEN"; } + # same as curl-with-key, but the Authorization header carries a machine JWT + curl-tcp "$@" -sS --fail-with-body -H "Authorization: Bearer $TOKEN" +} +export -f curl-with-token + +# as a log processor, connect to lapi and get a token +lp-get-token() { + local cred + cred=$(config_get .api.client.credentials_path) + local resp + resp=$(yq -oj -I0 '{"machine_id":.login,"password":.password}' < "$cred" | curl-socket '/v1/watchers/login' -s -X POST --data-binary @-) + if [[ "$(yq -e '.code' <<<"$resp")" != 200 ]]; then + echo "lp-get-token: failed to login" >&3 + return 1 + fi + echo "$resp" | yq -r '.token' +} +export -f lp-get-token diff --git a/test/localstack/docker-compose.yml b/test/localstack/docker-compose.yml index 66a820da3a5..f58f3c7f263 100644 --- a/test/localstack/docker-compose.yml +++ b/test/localstack/docker-compose.yml @@ -3,7 +3,7 @@ version: "3.8" services: localstack: container_name:
localstack_main - image: localstack/localstack:1.3.0 + image: localstack/localstack:3.0 network_mode: bridge ports: - "127.0.0.1:53:53" # only required for Pro (DNS) @@ -14,21 +14,18 @@ services: environment: AWS_HOST: localstack DEBUG: "" - LAMBDA_EXECUTOR: "" KINESYS_ERROR_PROBABILITY: "" DOCKER_HOST: "unix://var/run/docker.sock" - KINESIS_INITIALIZE_STREAMS: "stream-1-shard:1,stream-2-shards:2" - HOSTNAME_EXTERNAL: "localstack" - AWS_ACCESS_KEY_ID: "AKIAIOSFODNN7EXAMPLE" - AWS_SECRET_ACCESS_KEY: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + LOCALSTACK_HOST: "localstack" AWS_REGION: "us-east-1" volumes: - "${TMPDIR:-/tmp}/localstack:/var/lib/localstack" - "/var/run/docker.sock:/var/run/docker.sock" + - "./scripts/init_script.sh:/etc/localstack/init/ready.d/init_script.sh" zoo1: - image: confluentinc/cp-zookeeper:7.3.0 + image: confluentinc/cp-zookeeper:7.4.3 ports: - "2181:2181" environment: @@ -77,3 +74,8 @@ interval: 10s retries: 5 timeout: 10s + + loki: + image: grafana/loki:2.9.1 + ports: + - "127.0.0.1:3100:3100" diff --git a/test/localstack/scripts/init_script.sh b/test/localstack/scripts/init_script.sh new file mode 100755 index 00000000000..808ae4eb0a2 --- /dev/null +++ b/test/localstack/scripts/init_script.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +# Create Kinesis streams +aws --endpoint-url=http://localstack:4566 --region us-east-1 kinesis create-stream --stream-name stream-1-shard --shard-count 1 +aws --endpoint-url=http://localstack:4566 --region us-east-1 kinesis create-stream --stream-name stream-2-shards --shard-count 2 + diff --git a/test/run-tests b/test/run-tests index 21b7a7320c5..957eb663b9c 100755 --- a/test/run-tests +++ b/test/run-tests @@ -10,35 +10,37 @@ die() { # shellcheck disable=SC1007 TEST_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) # shellcheck source=./.environment.sh -. "${TEST_DIR}/.environment.sh" +. "$TEST_DIR/.environment.sh" -"${TEST_DIR}/bin/check-requirements" +"$TEST_DIR/bin/check-requirements" echo "Running tests..." -echo "DB_BACKEND: ${DB_BACKEND}" -if [[ -z "${TEST_COVERAGE}" ]]; then +echo "DB_BACKEND: $DB_BACKEND" +if [[ -z "$TEST_COVERAGE" ]]; then echo "Coverage report: no" else echo "Coverage report: yes" fi -dump_backend="$(cat "${LOCAL_INIT_DIR}/.backend")" -if [[ "${DB_BACKEND}" != "${dump_backend}" ]]; then - die "Can't run with backend '${DB_BACKEND}' because the test data was build with '${dump_backend}'" +[[ -f "$LOCAL_INIT_DIR/.lock" ]] && die "init data is locked: are you doing some manual test? if so, please finish what you are doing, run 'instance-data unlock' and retry" + +dump_backend="$(cat "$LOCAL_INIT_DIR/.backend")" +if [[ "$DB_BACKEND" != "$dump_backend" ]]; then + die "Can't run with backend '$DB_BACKEND' because the test data was built with '$dump_backend'" fi if [[ $# -ge 1 ]]; then echo "test files: $*" - "${TEST_DIR}/lib/bats-core/bin/bats" \ + "$TEST_DIR/lib/bats-core/bin/bats" \ --jobs 1 \ --timing \ --print-output-on-failure \ "$@" else - echo "test files: ${TEST_DIR}/bats ${TEST_DIR}/dyn-bats" - "${TEST_DIR}/lib/bats-core/bin/bats" \ + echo "test files: $TEST_DIR/bats $TEST_DIR/dyn-bats" + "$TEST_DIR/lib/bats-core/bin/bats" \ --jobs 1 \ --timing \ --print-output-on-failure \ - "${TEST_DIR}/bats" "${TEST_DIR}/dyn-bats" + "$TEST_DIR/bats" "$TEST_DIR/dyn-bats" fi diff --git a/windows/installer/product.wxs b/windows/installer/product.wxs index b43cd6de322..a0c1ea11e9f 100644 --- a/windows/installer/product.wxs +++ b/windows/installer/product.wxs @@ -87,16 +87,19 @@
- + - + - + - + + + + @@ -120,10 +123,11 @@ - - - - + + + + + @@ -139,7 +143,7 @@ - + @@ -186,4 +190,4 @@ - \ No newline at end of file + diff --git a/wizard.sh b/wizard.sh index e331db2fb5e..6e215365f6c 100755 --- a/wizard.sh +++ b/wizard.sh @@ -18,7 +18,6 @@ NC='\033[0m' SILENT="false" DOCKER_MODE="false" -CROWDSEC_RUN_DIR="/var/run" CROWDSEC_LIB_DIR="/var/lib/crowdsec" CROWDSEC_USR_DIR="/usr/local/lib/crowdsec" CROWDSEC_DATA_DIR="${CROWDSEC_LIB_DIR}/data" @@ -77,48 +76,53 @@ smb " -HTTP_PLUGIN_BINARY="./plugins/notifications/http/notification-http" -SLACK_PLUGIN_BINARY="./plugins/notifications/slack/notification-slack" -SPLUNK_PLUGIN_BINARY="./plugins/notifications/splunk/notification-splunk" -EMAIL_PLUGIN_BINARY="./plugins/notifications/email/notification-email" +HTTP_PLUGIN_BINARY="./cmd/notification-http/notification-http" +SLACK_PLUGIN_BINARY="./cmd/notification-slack/notification-slack" +SPLUNK_PLUGIN_BINARY="./cmd/notification-splunk/notification-splunk" +EMAIL_PLUGIN_BINARY="./cmd/notification-email/notification-email" +SENTINEL_PLUGIN_BINARY="./cmd/notification-sentinel/notification-sentinel" +FILE_PLUGIN_BINARY="./cmd/notification-file/notification-file" + +HTTP_PLUGIN_CONFIG="./cmd/notification-http/http.yaml" +SLACK_PLUGIN_CONFIG="./cmd/notification-slack/slack.yaml" +SPLUNK_PLUGIN_CONFIG="./cmd/notification-splunk/splunk.yaml" +EMAIL_PLUGIN_CONFIG="./cmd/notification-email/email.yaml" +SENTINEL_PLUGIN_CONFIG="./cmd/notification-sentinel/sentinel.yaml" +FILE_PLUGIN_CONFIG="./cmd/notification-file/file.yaml" -HTTP_PLUGIN_CONFIG="./plugins/notifications/http/http.yaml" -SLACK_PLUGIN_CONFIG="./plugins/notifications/slack/slack.yaml" -SPLUNK_PLUGIN_CONFIG="./plugins/notifications/splunk/splunk.yaml" -EMAIL_PLUGIN_CONFIG="./plugins/notifications/email/email.yaml" BACKUP_DIR=$(mktemp -d) rm -rf -- "$BACKUP_DIR" log_info() { msg=$1 - date=$(date +%x:%X) + date=$(date "+%Y-%m-%d %H:%M:%S") echo -e "${BLUE}INFO${NC}[${date}] crowdsec_wizard: ${msg}" } log_fatal() { msg=$1 - date=$(date +%x:%X) - echo -e "${RED}FATA${NC}[${date}] crowdsec_wizard: ${msg}" 1>&2 + date=$(date "+%Y-%m-%d %H:%M:%S") + echo -e "${RED}FATA${NC}[${date}] crowdsec_wizard: ${msg}" 1>&2 exit 1 } log_warn() { msg=$1 - date=$(date +%x:%X) + date=$(date "+%Y-%m-%d %H:%M:%S") echo -e "${ORANGE}WARN${NC}[${date}] crowdsec_wizard: ${msg}" } log_err() { msg=$1 - date=$(date +%x:%X) + date=$(date "+%Y-%m-%d %H:%M:%S") echo -e "${RED}ERR${NC}[${date}] crowdsec_wizard: ${msg}" 1>&2 } log_dbg() { if [[ ${DEBUG_MODE} == "true" ]]; then msg=$1 - date=$(date +%x:%X) + date=$(date "+%Y-%m-%d %H:%M:%S") echo -e "[${date}][${YELLOW}DBG${NC}] crowdsec_wizard: ${msg}" 1>&2 fi } @@ -126,16 +130,16 @@ log_dbg() { detect_services () { DETECTED_SERVICES=() HMENU=() - #list systemd services - SYSTEMD_SERVICES=`systemctl --state=enabled list-unit-files '*.service' | cut -d ' ' -f1` - #raw ps - PSAX=`ps ax -o comm=` + # list systemd services + SYSTEMD_SERVICES=$(systemctl --state=enabled list-unit-files '*.service' | cut -d ' ' -f1) + # raw ps + PSAX=$(ps ax -o comm=) for SVC in ${SUPPORTED_SERVICES} ; do log_dbg "Checking if service '${SVC}' is running (ps+systemd)" for SRC in "${SYSTEMD_SERVICES}" "${PSAX}" ; do echo ${SRC} | grep ${SVC} >/dev/null if [ $? 
-eq 0 ]; then - #on centos, apache2 is named httpd + # on centos, apache2 is named httpd if [[ ${SVC} == "httpd" ]] ; then SVC="apache2"; fi @@ -149,12 +153,12 @@ detect_services () { if [[ ${OSTYPE} == "linux-gnu" ]] || [[ ${OSTYPE} == "linux-gnueabihf" ]]; then DETECTED_SERVICES+=("linux") HMENU+=("linux" "on") - else + else log_info "NOT A LINUX" fi; if [[ ${SILENT} == "false" ]]; then - #we put whiptail results in an array, notice the dark magic fd redirection + # we put whiptail results in an array, notice the dark magic fd redirection DETECTED_SERVICES=($(whiptail --separate-output --noitem --ok-button Continue --title "Services to monitor" --checklist "Detected services, uncheck to ignore. Ignored services won't be monitored." 18 70 10 ${HMENU[@]} 3>&1 1>&2 2>&3)) if [ $? -eq 1 ]; then log_err "user bailed out at services selection" @@ -186,28 +190,27 @@ log_locations[mysql]='/var/log/mysql/error.log' log_locations[smb]='/var/log/samba*.log' log_locations[linux]='/var/log/syslog,/var/log/kern.log,/var/log/messages' -#$1 is service name, such those in SUPPORTED_SERVICES +# $1 is a service name, such as those in SUPPORTED_SERVICES find_logs_for() { - ret="" x=${1} - #we have trailing and starting quotes because of whiptail + # we have trailing and starting quotes because of whiptail SVC="${x%\"}" SVC="${SVC#\"}" DETECTED_LOGFILES=() HMENU=() - #log_info "Searching logs for ${SVC} : ${log_locations[${SVC}]}" + # log_info "Searching logs for ${SVC} : ${log_locations[${SVC}]}" - #split the line into an array with ',' separator + # split the line into an array with ',' separator OIFS=${IFS} IFS=',' read -r -a a <<< "${log_locations[${SVC}]}," IFS=${OIFS} - #readarray -td, a <<<"${log_locations[${SVC}]},"; unset 'a[-1]'; + # readarray -td, a <<<"${log_locations[${SVC}]},"; unset 'a[-1]'; for poss_path in "${a[@]}"; do - #Split /var/log/nginx/*.log into '/var/log/nginx' and '*.log' so we can use find + # Split /var/log/nginx/*.log into '/var/log/nginx' and '*.log' so we can use find path=${poss_path%/*} fname=${poss_path##*/} - candidates=`find "${path}" -type f -mtime -5 -ctime -5 -name "$fname"` - #We have some candidates, add them + candidates=$(find "${path}" -type f -mtime -5 -ctime -5 -name "$fname" 2>/dev/null) + # We have some candidates, add them for final_file in ${candidates} ; do log_dbg "Found logs file for '${SVC}': ${final_file}" DETECTED_LOGFILES+=(${final_file}) @@ -246,12 +249,12 @@ install_collection() { in_array $collection "${DETECTED_SERVICES[@]}" if [[ $?
        if [[ $? == 0 ]]; then
             HMENU+=("${collection}" "${description}" "ON")
-            #in case we're not in interactive mode, assume defaults
+            # in case we're not in interactive mode, assume defaults
             COLLECTION_TO_INSTALL+=(${collection})
         else
             if [[ ${collection} == "linux" ]]; then
                 HMENU+=("${collection}" "${description}" "ON")
-                #in case we're not in interactive mode, assume defaults
+                # in case we're not in interactive mode, assume defaults
                 COLLECTION_TO_INSTALL+=(${collection})
             else
                 HMENU+=("${collection}" "${description}" "OFF")
@@ -269,10 +272,10 @@ install_collection() {
 
     for collection in "${COLLECTION_TO_INSTALL[@]}"; do
         log_info "Installing collection '${collection}'"
-        ${CSCLI_BIN_INSTALLED} collections install "${collection}" > /dev/null 2>&1 || log_err "fail to install collection ${collection}"
+        ${CSCLI_BIN_INSTALLED} collections install "${collection}" --error
     done
 
-    ${CSCLI_BIN_INSTALLED} parsers install "crowdsecurity/whitelists" > /dev/null 2>&1 || log_err "fail to install collection crowdsec/whitelists"
+    ${CSCLI_BIN_INSTALLED} parsers install "crowdsecurity/whitelists" --error
     if [[ ${SILENT} == "false" ]]; then
         whiptail --msgbox "Out of safety, I installed a parser called 'crowdsecurity/whitelists'. This one will prevent private IP addresses from being banned, feel free to remove it any time." 20 50
     fi
@@ -282,14 +285,14 @@ install_collection() {
     fi
 }
 
-#$1 is the service name, $... is the list of candidate logs (from find_logs_for)
+# $1 is the service name, $... is the list of candidate logs (from find_logs_for)
 genyamllog() {
     local service="${1}"
     shift
     local files=("${@}")
-    
+
     echo "#Generated acquisition file - wizard.sh (service: ${service}) / files : ${files[@]}" >> ${TMP_ACQUIS_FILE}
-    
+
     echo "filenames:" >> ${TMP_ACQUIS_FILE}
     for fd in ${files[@]}; do
         echo "  - ${fd}" >> ${TMP_ACQUIS_FILE}
@@ -303,9 +306,9 @@ genyamllog() {
 genyamljournal() {
     local service="${1}"
     shift
-    
+
     echo "#Generated acquisition file - wizard.sh (service: ${service}) / files : ${files[@]}" >> ${TMP_ACQUIS_FILE}
-    
+
     echo "journalctl_filter:" >> ${TMP_ACQUIS_FILE}
     echo " - _SYSTEMD_UNIT="${service}".service" >> ${TMP_ACQUIS_FILE}
     echo "labels:" >> ${TMP_ACQUIS_FILE}
@@ -315,7 +318,7 @@
 }
 
 genacquisition() {
-    if skip_tmp_acquis; then 
+    if skip_tmp_acquis; then
         TMP_ACQUIS_FILE="${ACQUIS_TARGET}"
         ACQUIS_FILE_MSG="acquisition file generated to: ${TMP_ACQUIS_FILE}"
     else
@@ -333,7 +336,7 @@ genacquisition() {
             log_info "using journald for '${PSVG}'"
             genyamljournal ${PSVG}
         fi;
-    done 
+    done
 }
 
 detect_cs_install () {
@@ -368,7 +371,7 @@ check_cs_version () {
         fi
     elif [[ $NEW_MINOR_VERSION -gt $CURRENT_MINOR_VERSION ]] ; then
         log_warn "new version ($NEW_CS_VERSION) is a minor upgrade !"
-        if [[ $ACTION != "upgrade" ]] ; then 
+        if [[ $ACTION != "upgrade" ]] ; then
             if [[ ${FORCE_MODE} == "false" ]]; then
                 echo ""
                 echo "We recommend to upgrade with : sudo ./wizard.sh --upgrade "
@@ -380,7 +383,7 @@ check_cs_version () {
         fi
     elif [[ $NEW_PATCH_VERSION -gt $CURRENT_PATCH_VERSION ]] ; then
         log_warn "new version ($NEW_CS_VERSION) is a patch !"
-        if [[ $ACTION != "binupgrade" ]] ; then 
+        if [[ $ACTION != "binupgrade" ]] ; then
            if [[ ${FORCE_MODE} == "false" ]]; then
                echo ""
                echo "We recommend to upgrade binaries only : sudo ./wizard.sh --binupgrade "
@@ -403,17 +406,21 @@ check_cs_version () {
     fi
 }
 
-#install crowdsec and cscli
+# install crowdsec and cscli
 install_crowdsec() {
     mkdir -p "${CROWDSEC_DATA_DIR}"
     (cd config && find patterns -type f -exec install -Dm 644 "{}" "${CROWDSEC_CONFIG_PATH}/{}" \; && cd ../) || exit
 
+    mkdir -p "${CROWDSEC_CONFIG_PATH}/acquis.d" || exit
     mkdir -p "${CROWDSEC_CONFIG_PATH}/scenarios" || exit
     mkdir -p "${CROWDSEC_CONFIG_PATH}/postoverflows" || exit
     mkdir -p "${CROWDSEC_CONFIG_PATH}/collections" || exit
     mkdir -p "${CROWDSEC_CONFIG_PATH}/patterns" || exit
+    mkdir -p "${CROWDSEC_CONFIG_PATH}/appsec-configs" || exit
+    mkdir -p "${CROWDSEC_CONFIG_PATH}/appsec-rules" || exit
+    mkdir -p "${CROWDSEC_CONFIG_PATH}/contexts" || exit
     mkdir -p "${CROWDSEC_CONSOLE_DIR}" || exit
 
-    #tmp
+    # tmp
     mkdir -p /tmp/data
     mkdir -p /etc/crowdsec/hub/
     install -v -m 600 -D "./config/${CLIENT_SECRETS}" "${CROWDSEC_CONFIG_PATH}" 1> /dev/null || exit
@@ -485,7 +492,7 @@ install_bins() {
     install -v -m 755 -D "${CSCLI_BIN}" "${CSCLI_BIN_INSTALLED}" 1> /dev/null || exit
     which systemctl && systemctl is-active --quiet crowdsec
     if [ $? -eq 0 ]; then
-        systemctl stop crowdsec 
+        systemctl stop crowdsec
     fi
     install_plugins
     symlink_bins
@@ -503,7 +510,7 @@ symlink_bins() {
 delete_bins() {
     log_info "Removing crowdsec binaries"
     rm -f ${CROWDSEC_BIN_INSTALLED}
-    rm -f ${CSCLI_BIN_INSTALLED}   
+    rm -f ${CSCLI_BIN_INSTALLED}
 }
 
 delete_plugins() {
@@ -518,17 +525,21 @@ install_plugins(){
     cp ${SPLUNK_PLUGIN_BINARY} ${CROWDSEC_PLUGIN_DIR}
     cp ${HTTP_PLUGIN_BINARY} ${CROWDSEC_PLUGIN_DIR}
     cp ${EMAIL_PLUGIN_BINARY} ${CROWDSEC_PLUGIN_DIR}
+    cp ${SENTINEL_PLUGIN_BINARY} ${CROWDSEC_PLUGIN_DIR}
+    cp ${FILE_PLUGIN_BINARY} ${CROWDSEC_PLUGIN_DIR}
 
     if [[ ${DOCKER_MODE} == "false" ]]; then
         cp -n ${SLACK_PLUGIN_CONFIG} /etc/crowdsec/notifications/
         cp -n ${SPLUNK_PLUGIN_CONFIG} /etc/crowdsec/notifications/
         cp -n ${HTTP_PLUGIN_CONFIG} /etc/crowdsec/notifications/
         cp -n ${EMAIL_PLUGIN_CONFIG} /etc/crowdsec/notifications/
+        cp -n ${SENTINEL_PLUGIN_CONFIG} /etc/crowdsec/notifications/
+        cp -n ${FILE_PLUGIN_CONFIG} /etc/crowdsec/notifications/
     fi
 }
 
 check_running_bouncers() {
-    #when uninstalling, check if user still has bouncers
+    # when uninstalling, check if user still has bouncers
     BOUNCERS_COUNT=$(${CSCLI_BIN} bouncers list -o=raw | tail -n +2 | wc -l)
     if [[ ${BOUNCERS_COUNT} -gt 0 ]] ; then
         if [[ ${FORCE_MODE} == "false" ]]; then
@@ -639,7 +650,7 @@ main() {
     then
         return
     fi
-    
+
     if [[ "$1" == "uninstall" ]];
     then
         if ! [ $(id -u) = 0 ]; then
@@ -678,11 +689,11 @@ main() {
        log_info "installing crowdsec"
        install_crowdsec
        log_dbg "configuring ${CSCLI_BIN_INSTALLED}"
-        ${CSCLI_BIN_INSTALLED} hub update > /dev/null 2>&1 || (log_err "fail to update crowdsec hub. exiting" && exit 1)
+        ${CSCLI_BIN_INSTALLED} hub update --error || (log_err "fail to update crowdsec hub. exiting" && exit 1)
 
        # detect running services
        detect_services
-        if ! [ ${#DETECTED_SERVICES[@]} -gt 0 ] ; then 
+        if ! [ ${#DETECTED_SERVICES[@]} -gt 0 ] ; then
            log_err "No detected or selected services, stopping."
           exit 1
        fi;
@@ -704,11 +715,10 @@ main() {
 
     # api register
     ${CSCLI_BIN_INSTALLED} machines add --force "$(cat /etc/machine-id)" -a -f "${CROWDSEC_CONFIG_PATH}/${CLIENT_SECRETS}" || log_fatal "unable to add machine to the local API"
-    log_dbg "Crowdsec LAPI registered" 
-
-    ${CSCLI_BIN_INSTALLED} capi register || log_fatal "unable to register to the Central API"
-    log_dbg "Crowdsec CAPI registered" 
-    
+    log_dbg "Crowdsec LAPI registered"
+
+    ${CSCLI_BIN_INSTALLED} capi register --error || log_fatal "unable to register to the Central API"
+
     systemctl enable -q crowdsec >/dev/null || log_fatal "unable to enable crowdsec"
     systemctl start crowdsec >/dev/null || log_fatal "unable to start crowdsec"
     log_info "enabling and starting crowdsec daemon"
@@ -722,7 +732,7 @@ main() {
         rm -f "${TMP_ACQUIS_FILE}"
     fi
     detect_services
-    if [[ ${DETECTED_SERVICES} == "" ]] ; then 
+    if [[ ${DETECTED_SERVICES} == "" ]] ; then
        log_err "No detected or selected services, stopping."
        exit
    fi;
@@ -750,7 +760,7 @@ usage() {
      echo "    ./wizard.sh --docker-mode            Will install crowdsec without systemd and generate random machine-id"
      echo "    ./wizard.sh -n|--noop                Do nothing"
 
-      exit 0  
+      exit 0
 }
 
 if [[ $# -eq 0 ]]; then
@@ -763,15 +773,15 @@ do
    case ${key} in
    --uninstall)
        ACTION="uninstall"
-        shift #past argument
+        shift # past argument
        ;;
    --binupgrade)
        ACTION="binupgrade"
-        shift #past argument
+        shift # past argument
        ;;
    --upgrade)
        ACTION="upgrade"
-        shift #past argument
+        shift # past argument
        ;;
    -i|--install)
        ACTION="install"
@@ -806,11 +816,11 @@ do
    -f|--force)
        FORCE_MODE="true"
        shift
-        ;;  
+        ;;
    -v|--verbose)
        DEBUG_MODE="true"
        shift
-        ;;  
+        ;;
    -h|--help)
        usage
        exit 0