Modularize scanning engine (#2887)
* POC: Modularize scanning engine.

* fix typo

* update interface name

* fix tests

* update test

* fix moar tests

* fix bug

* fixes.

* fix merge

* add detector verification overrides

* handle --no-verification flag

* support fp

* add test

* update name

* filter

* update test

* explicit use of detector

* updates
ahrav authored Jun 13, 2024
1 parent 4addd81 commit cb07260
Showing 6 changed files with 761 additions and 551 deletions.
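
The core of the change: the CLI-local scanConfig struct and the stack of engine.With* functional options (engine.WithConcurrency, engine.WithDetectors, and so on in the removed code) are replaced by a single engine.Config passed to engine.NewEngine, followed by an explicit eng.Start(ctx). The sketch below illustrates the new call pattern; field and constructor names come from the diff, while the values and the printer variable are illustrative placeholders rather than the CLI's defaults.

```go
// A minimal sketch of the new initialization flow. Field and constructor
// names are taken from the diff below; the values are placeholders.
package main

import (
	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
	"github.com/trufflesecurity/trufflehog/v3/pkg/engine"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)

func main() {
	ctx := context.Background()

	// Placeholder: whichever engine.Printer implementation the CLI selects
	// (plain text, JSON, ...) would be wired in here.
	var printer engine.Printer

	cfg := engine.Config{
		Concurrency:   8,                                    // placeholder for *concurrency
		Verify:        true,                                 // inverse of --no-verification
		Dispatcher:    engine.NewPrinterDispatcher(printer), // results flow to the printer
		SourceManager: sources.NewManager(),                 // the CLI adds options; see runSingleScan
	}

	// One constructor replaces the previous stack of engine.With* options,
	// and the engine no longer starts implicitly at construction time.
	eng, err := engine.NewEngine(ctx, &cfg)
	if err != nil {
		ctx.Logger().Error(err, "error initializing engine")
		return
	}
	eng.Start(ctx)
}
```

Because everything lives in one struct, compareScans (further down in the diff) can rerun the same configuration by flipping a single field, cfg.ShouldScanEntireChunk, instead of threading an extra boolean through runSingleScan.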
277 changes: 80 additions & 197 deletions main.go
@@ -1,6 +1,7 @@
package main

import (
"encoding/json"
"fmt"
"io"
"net/http"
@@ -24,8 +25,6 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/config"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/decoders"
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors"
"github.com/trufflesecurity/trufflehog/v3/pkg/engine"
"github.com/trufflesecurity/trufflehog/v3/pkg/handlers"
"github.com/trufflesecurity/trufflehog/v3/pkg/log"
@@ -352,88 +351,6 @@ func run(state overseer.State) {
handlers.SetArchiveMaxTimeout(*archiveTimeout)
}

// Build include and exclude detector sets for filtering on engine initialization.
// Exit if there was an error to inform the user of the misconfiguration.
var includeDetectorSet, excludeDetectorSet map[config.DetectorID]struct{}
var detectorsWithCustomVerifierEndpoints map[config.DetectorID][]string
{
includeList, err := config.ParseDetectors(*includeDetectors)
if err != nil {
logFatal(err, "invalid include list detector configuration")
}
excludeList, err := config.ParseDetectors(*excludeDetectors)
if err != nil {
logFatal(err, "invalid exclude list detector configuration")
}
detectorsWithCustomVerifierEndpoints, err = config.ParseVerifierEndpoints(*verifiers)
if err != nil {
logFatal(err, "invalid verifier detector configuration")
}
includeDetectorSet = detectorTypeToSet(includeList)
excludeDetectorSet = detectorTypeToSet(excludeList)
}

// Verify that all the user-provided detectors support the optional
// detector features.
{
if id, err := verifyDetectorsAreVersioner(includeDetectorSet); err != nil {
logFatal(err, "invalid include list detector configuration", "detector", id)
}
if id, err := verifyDetectorsAreVersioner(excludeDetectorSet); err != nil {
logFatal(err, "invalid exclude list detector configuration", "detector", id)
}
if id, err := verifyDetectorsAreVersioner(detectorsWithCustomVerifierEndpoints); err != nil {
logFatal(err, "invalid verifier detector configuration", "detector", id)
}
// Extra check for endpoint customization.
isEndpointCustomizer := engine.DefaultDetectorTypesImplementing[detectors.EndpointCustomizer]()
for id := range detectorsWithCustomVerifierEndpoints {
if _, ok := isEndpointCustomizer[id.ID]; !ok {
logFatal(
fmt.Errorf("endpoint provided but detector does not support endpoint customization"),
"invalid custom verifier endpoint detector configuration",
"detector", id,
)
}
}
}

includeFilter := func(d detectors.Detector) bool {
_, ok := getWithDetectorID(d, includeDetectorSet)
return ok
}
excludeFilter := func(d detectors.Detector) bool {
_, ok := getWithDetectorID(d, excludeDetectorSet)
return !ok
}
// Abuse filter to cause a side-effect.
endpointCustomizer := func(d detectors.Detector) bool {
urls, ok := getWithDetectorID(d, detectorsWithCustomVerifierEndpoints)
if !ok {
return true
}
id := config.GetDetectorID(d)
customizer, ok := d.(detectors.EndpointCustomizer)
if !ok {
// NOTE: We should never reach here due to validation above.
logFatal(
fmt.Errorf("failed to configure a detector endpoint"),
"the provided detector does not support endpoint configuration",
"detector", id,
)
}
if !*customVerifiersOnly || len(urls) == 0 {
urls = append(urls, customizer.DefaultEndpoint())
}
if err := customizer.SetEndpoints(urls...); err != nil {
logFatal(err, "failed configuring custom endpoint for detector", "detector", id)
}
logger.Info("configured detector with verification urls",
"detector", id, "urls", urls,
)
return true
}

// Set how the engine will print its results.
var printer engine.Printer
switch {
@@ -451,11 +368,6 @@ func run(state overseer.State) {
fmt.Fprintf(os.Stderr, "🐷🔑🐷 TruffleHog. Unearth your secrets. 🐷🔑🐷\n\n")
}

var jobReportWriter io.WriteCloser
if *jobReportFile != nil {
jobReportWriter = *jobReportFile
}

// Parse --results flag.
if *onlyVerified {
r := "verified"
@@ -466,34 +378,31 @@ func run(state overseer.State) {
logFatal(err, "failed to configure results flag")
}

scanConfig := scanConfig{
Command: cmd,
Concurrency: *concurrency,
Decoders: decoders.DefaultDecoders(),
Conf: conf,
IncludeFilter: includeFilter,
ExcludeFilter: excludeFilter,
EndpointCustomizer: endpointCustomizer,
NoVerification: *noVerification,
PrintAvgDetectorTime: *printAvgDetectorTime,
FilterUnverified: *filterUnverified,
FilterEntropy: *filterEntropy,
ScanEntireChunk: *scanEntireChunk,
JobReportWriter: jobReportWriter,
AllowVerificationOverlap: *allowVerificationOverlap,
ParsedResults: parsedResults,
Printer: printer,
engConf := engine.Config{
Concurrency: *concurrency,
Detectors: conf.Detectors,
Verify: !*noVerification,
IncludeDetectors: *includeDetectors,
ExcludeDetectors: *excludeDetectors,
CustomVerifiersOnly: *customVerifiersOnly,
VerifierEndpoints: *verifiers,
Dispatcher: engine.NewPrinterDispatcher(printer),
FilterUnverified: *filterUnverified,
FilterEntropy: *filterEntropy,
VerificationOverlap: *allowVerificationOverlap,
Results: parsedResults,
PrintAvgDetectorTime: *printAvgDetectorTime,
ShouldScanEntireChunk: *scanEntireChunk,
}

if *compareDetectionStrategies {
err := compareScans(ctx, scanConfig)
if err != nil {
if err := compareScans(ctx, cmd, engConf); err != nil {
logFatal(err, "error comparing detection strategies")
}
return
}

metrics, err := runSingleScan(ctx, scanConfig, *scanEntireChunk)
metrics, err := runSingleScan(ctx, cmd, engConf)
if err != nil {
logFatal(err, "error running scan")
}
@@ -514,26 +423,7 @@ func run(state overseer.State) {
}
}

type scanConfig struct {
Command string
Concurrency int
Decoders []decoders.Decoder
Conf *config.Config
IncludeFilter func(detectors.Detector) bool
ExcludeFilter func(detectors.Detector) bool
EndpointCustomizer func(detectors.Detector) bool
NoVerification bool
PrintAvgDetectorTime bool
FilterUnverified bool
FilterEntropy float64
ScanEntireChunk bool
JobReportWriter io.WriteCloser
AllowVerificationOverlap bool
ParsedResults map[string]struct{}
Printer engine.Printer
}

func compareScans(ctx context.Context, cfg scanConfig) error {
func compareScans(ctx context.Context, cmd string, cfg engine.Config) error {
var (
entireMetrics metrics
maxLengthMetrics metrics
@@ -546,14 +436,15 @@ func compareScans(ctx context.Context, cfg scanConfig) error {
go func() {
defer wg.Done()
// Run scan with entire chunk span calculator.
entireMetrics, err = runSingleScan(ctx, cfg, true)
cfg.ShouldScanEntireChunk = true
entireMetrics, err = runSingleScan(ctx, cmd, cfg)
if err != nil {
ctx.Logger().Error(err, "error running scan with entire chunk span calculator")
}
}()

// Run scan with max-length span calculator.
maxLengthMetrics, err = runSingleScan(ctx, cfg, false)
maxLengthMetrics, err = runSingleScan(ctx, cmd, cfg)
if err != nil {
return fmt.Errorf("error running scan with custom span calculator: %v", err)
}
@@ -585,27 +476,64 @@ type metrics struct {
hasFoundResults bool
}

func runSingleScan(ctx context.Context, cfg scanConfig, scanEntireChunk bool) (metrics, error) {
eng, err := engine.Start(ctx,
engine.WithConcurrency(cfg.Concurrency),
engine.WithDecoders(cfg.Decoders...),
engine.WithDetectors(engine.DefaultDetectors()...),
engine.WithDetectors(cfg.Conf.Detectors...),
engine.WithVerify(!cfg.NoVerification),
engine.WithFilterDetectors(cfg.IncludeFilter),
engine.WithFilterDetectors(cfg.ExcludeFilter),
engine.WithFilterDetectors(cfg.EndpointCustomizer),
engine.WithFilterUnverified(cfg.FilterUnverified),
engine.WithResults(cfg.ParsedResults),
engine.WithPrintAvgDetectorTime(cfg.PrintAvgDetectorTime),
engine.WithPrinter(cfg.Printer),
engine.WithFilterEntropy(cfg.FilterEntropy),
engine.WithVerificationOverlap(cfg.AllowVerificationOverlap),
engine.WithEntireChunkScan(scanEntireChunk),
)
func runSingleScan(ctx context.Context, cmd string, cfg engine.Config) (metrics, error) {
var scanMetrics metrics

// Setup job report writer if provided
var jobReportWriter io.WriteCloser
if *jobReportFile != nil {
jobReportWriter = *jobReportFile
}

handleFinishedMetrics := func(ctx context.Context, finishedMetrics <-chan sources.UnitMetrics, jobReportWriter io.WriteCloser) {
go func() {
defer func() {
jobReportWriter.Close()
if namer, ok := jobReportWriter.(interface{ Name() string }); ok {
ctx.Logger().Info("report written", "path", namer.Name())
} else {
ctx.Logger().Info("report written")
}
}()

for metrics := range finishedMetrics {
metrics.Errors = common.ExportErrors(metrics.Errors...)
details, err := json.Marshal(map[string]any{
"version": 1,
"data": metrics,
})
if err != nil {
ctx.Logger().Error(err, "error marshalling job details")
continue
}
if _, err := jobReportWriter.Write(append(details, '\n')); err != nil {
ctx.Logger().Error(err, "error writing to file")
}
}
}()
}

const defaultOutputBufferSize = 64
opts := []func(*sources.SourceManager){
sources.WithConcurrentSources(cfg.Concurrency),
sources.WithConcurrentUnits(cfg.Concurrency),
sources.WithSourceUnits(),
sources.WithBufferedOutput(defaultOutputBufferSize),
}

if jobReportWriter != nil {
unitHook, finishedMetrics := sources.NewUnitHook(ctx)
opts = append(opts, sources.WithReportHook(unitHook))
handleFinishedMetrics(ctx, finishedMetrics, jobReportWriter)
}

cfg.SourceManager = sources.NewManager(opts...)

eng, err := engine.NewEngine(ctx, &cfg)
if err != nil {
return metrics{}, fmt.Errorf("error initializing engine: %v", err)
return scanMetrics, fmt.Errorf("error initializing engine: %v", err)
}
eng.Start(ctx)

defer func() {
// Clean up temporary artifacts.
@@ -614,8 +542,7 @@ func runSingleScan(ctx context.Context, cfg scanConfig, scanEntireChunk bool) (m
}
}()

var scanMetrics metrics
switch cfg.Command {
switch cmd {
case gitScan.FullCommand():
gitCfg := sources.GitConfig{
URI: *gitScanURI,
@@ -812,7 +739,7 @@ func runSingleScan(ctx context.Context, cfg scanConfig, scanEntireChunk bool) (m
return scanMetrics, fmt.Errorf("failed to scan Jenkins: %v", err)
}
default:
return scanMetrics, fmt.Errorf("invalid command: %s", cfg.Command)
return scanMetrics, fmt.Errorf("invalid command: %s", cmd)
}

// Wait for all workers to finish.
@@ -887,47 +814,3 @@ func printAverageDetectorTime(e *engine.Engine) {
fmt.Fprintf(os.Stderr, "%s: %s\n", detectorName, duration)
}
}

// detectorTypeToSet is a helper function to convert a slice of detector IDs into a set.
func detectorTypeToSet(detectors []config.DetectorID) map[config.DetectorID]struct{} {
out := make(map[config.DetectorID]struct{}, len(detectors))
for _, d := range detectors {
out[d] = struct{}{}
}
return out
}

// getWithDetectorID is a helper function to get a value from a map using a
// detector's ID. This function behaves like a normal map lookup, with an extra
// step of checking for the non-specific version of a detector.
func getWithDetectorID[T any](d detectors.Detector, data map[config.DetectorID]T) (T, bool) {
key := config.GetDetectorID(d)
// Check if the specific ID is provided.
if t, ok := data[key]; ok || key.Version == 0 {
return t, ok
}
// Check if the generic type is provided without a version.
// This means "all" versions of a type.
key.Version = 0
t, ok := data[key]
return t, ok
}

// verifyDetectorsAreVersioner checks all keys in a provided map to verify the
// provided type is actually a Versioner.
func verifyDetectorsAreVersioner[T any](data map[config.DetectorID]T) (config.DetectorID, error) {
isVersioner := engine.DefaultDetectorTypesImplementing[detectors.Versioner]()
for id := range data {
if id.Version == 0 {
// Version not provided.
continue
}
if _, ok := isVersioner[id.ID]; ok {
// Version provided for a Versioner detector.
continue
}
// Version provided on a non-Versioner detector.
return id, fmt.Errorf("version provided but detector does not have a version")
}
return config.DetectorID{}, nil
}
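
In the new layout, the job-report plumbing that the removed scanConfig used to carry is attached to the SourceManager instead: sources.NewUnitHook streams per-unit metrics on a channel, and a goroutine writes each entry as one JSON line to the report writer. The sketch below isolates that wiring under the assumption that the sources and common functions behave exactly as shown in the diff; the os.File target and the concurrency values are stand-ins for the CLI's flags.

```go
// Sketch of the unit-hook report wiring, isolated from runSingleScan.
// Only names that appear in the diff are assumed to exist; the file path
// and concurrency values are placeholders.
package main

import (
	"encoding/json"
	"os"

	"github.com/trufflesecurity/trufflehog/v3/pkg/common"
	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)

func main() {
	ctx := context.Background()

	// Stand-in for the CLI's job-report file; any io.WriteCloser works.
	report, err := os.Create("job-report.jsonl")
	if err != nil {
		ctx.Logger().Error(err, "error creating report file")
		return
	}

	// The hook emits per-unit metrics on finishedMetrics as units complete.
	unitHook, finishedMetrics := sources.NewUnitHook(ctx)

	mgr := sources.NewManager(
		sources.WithConcurrentSources(4),
		sources.WithConcurrentUnits(4),
		sources.WithSourceUnits(),
		sources.WithBufferedOutput(64),
		sources.WithReportHook(unitHook),
	)
	_ = mgr // in the CLI this becomes engine.Config.SourceManager

	// Drain the channel and write one {"version": 1, "data": ...} object per line.
	go func() {
		defer report.Close()
		for m := range finishedMetrics {
			m.Errors = common.ExportErrors(m.Errors...)
			line, err := json.Marshal(map[string]any{"version": 1, "data": m})
			if err != nil {
				ctx.Logger().Error(err, "error marshalling job details")
				continue
			}
			if _, err := report.Write(append(line, '\n')); err != nil {
				ctx.Logger().Error(err, "error writing to file")
			}
		}
	}()

	// In the CLI the scan runs here; finishedMetrics closes once every unit has reported.
}
```

Emitting one JSON object per line keeps the report append-only, so an interrupted scan still leaves a parseable prefix.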
9 changes: 9 additions & 0 deletions pkg/detectors/detectors.go
@@ -121,10 +121,17 @@ func unwrapToLast(err error) error {
}

type ResultWithMetadata struct {
// IsWordlistFalsePositive indicates whether this secret was flagged as a false positive based on a wordlist check
IsWordlistFalsePositive bool
// SourceMetadata contains source-specific contextual information.
SourceMetadata *source_metadatapb.MetaData
// SourceID is the ID of the source that the API uses to map secrets to specific sources.
SourceID sources.SourceID
// JobID is the ID of the job that the API uses to map secrets to specific jobs.
JobID sources.JobID
// SecretID is the ID of the secret, if it exists.
// Only secrets that are being reverified will have a SecretID.
SecretID int64
// SourceType is the type of Source.
SourceType sourcespb.SourceType
// SourceName is the name of the Source.
@@ -139,6 +146,8 @@ func CopyMetadata(chunk *sources.Chunk, result Result) ResultWithMetadata {
return ResultWithMetadata{
SourceMetadata: chunk.SourceMetadata,
SourceID: chunk.SourceID,
JobID: chunk.JobID,
SecretID: chunk.SecretID,
SourceType: chunk.SourceType,
SourceName: chunk.SourceName,
Result: result,