diff --git a/cmd/omni/main.go b/cmd/omni/main.go index a0e4c0e4..cb1eb83c 100644 --- a/cmd/omni/main.go +++ b/cmd/omni/main.go @@ -543,4 +543,11 @@ func init() { config.Config.EnableBreakGlassConfigs, "Allows downloading admin Talos and Kubernetes configs.", ) + + rootCmd.Flags().StringVar( + &config.Config.AuditLogDir, + "audit-log-dir", + config.Config.AuditLogDir, + "Directory for audit log storage", + ) } diff --git a/go.mod b/go.mod index 682f9282..c85c44cd 100644 --- a/go.mod +++ b/go.mod @@ -54,6 +54,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/mattn/go-shellwords v1.0.12 github.com/prometheus/client_golang v1.19.1 + github.com/prometheus/common v0.55.0 github.com/siderolabs/crypto v0.4.4 github.com/siderolabs/discovery-api v0.1.4 github.com/siderolabs/discovery-client v0.1.9 @@ -199,7 +200,6 @@ require ( github.com/planetscale/vtprotobuf v0.6.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_model v0.6.1 // indirect - github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/rs/cors v1.11.0 // indirect github.com/russellhaering/goxmldsig v1.4.0 // indirect diff --git a/go.sum b/go.sum index cb255167..709b9568 100644 --- a/go.sum +++ b/go.sum @@ -276,6 +276,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jsimonetti/rtnetlink/v2 v2.0.2 h1:ZKlbCujrIpp4/u3V2Ka0oxlf4BCkt6ojkvpy3nZoCBY= github.com/jsimonetti/rtnetlink/v2 v2.0.2/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -334,6 +336,8 @@ github.com/muhlemmer/httpforwarded v0.1.0 h1:x4DLrzXdliq8mprgUMR0olDvHGkou5BJsK/ github.com/muhlemmer/httpforwarded v0.1.0/go.mod h1:yo9czKedo2pdZhoXe+yDkGVbU0TJ0q9oQ90BVoDEtw0= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f h1:KUppIJq7/+SVif2QVs3tOP0zanoHgBEVAwHxUSIzRqU= +github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/onsi/ginkgo/v2 v2.19.0 h1:9Cnnf7UHo57Hy3k6/m5k3dRfGTMXGvxhHFvkDTCTpvA= github.com/onsi/ginkgo/v2 v2.19.0/go.mod h1:rlwLi9PilAFJ8jCg9UE1QP6VBpd6/xj3SRC0d6TU0To= github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= diff --git a/hack/compose/docker-compose.yml b/hack/compose/docker-compose.yml index ca865ad5..8cc09e92 100644 --- a/hack/compose/docker-compose.yml +++ b/hack/compose/docker-compose.yml @@ -18,6 +18,7 @@ services: - logs:/_out/logs - secondary-storage:/_out/secondary-storage - etcd-backup:/tmp/omni-data/etcd-backup + - audit-logs:/tmp/omni-data/audit-logs - ../generate-certs/certs:/etc/ssl/omni-certs:ro container_name: local-omni restart: on-failure @@ -127,3 +128,4 @@ volumes: minio: secondary-storage: etcd-backup: + audit-logs: diff --git a/hack/generate-certs/main.go b/hack/generate-certs/main.go index 5c9e4169..55da5a2f 100644 --- a/hack/generate-certs/main.go +++ b/hack/generate-certs/main.go @@ -274,6 +274,7 @@ services: --debug --etcd-embedded-unsafe-fsync=true --etcd-backup-s3 + --audit-log-dir /tmp/omni-data/audit-logs {{- range $key, $value := .RegistryMirrors }} --registry-mirror {{ $key }}={{ $value }} {{- end }} diff --git a/internal/backend/grpc/auth.go b/internal/backend/grpc/auth.go index 08e2042a..62c588c5 100644 --- a/internal/backend/grpc/auth.go +++ b/internal/backend/grpc/auth.go @@ -29,11 +29,13 @@ import ( "github.com/siderolabs/omni/client/api/omni/specs" "github.com/siderolabs/omni/client/pkg/omni/resources" authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth" + "github.com/siderolabs/omni/internal/backend/runtime/omni/audit" "github.com/siderolabs/omni/internal/backend/runtime/omni/controllers/omni" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/auth/role" "github.com/siderolabs/omni/internal/pkg/config" + "github.com/siderolabs/omni/internal/pkg/ctxstore" ) const ( @@ -143,6 +145,17 @@ func (s *authServer) RegisterPublicKey(ctx context.Context, request *authpb.Regi newPubKey := authres.NewPublicKey(resources.DefaultNamespace, pubKey.id) + auditData, ok := ctxstore.Value[*audit.Data](ctx) + if !ok { + return nil, errors.New("audit data not found") + } + + auditData.UserID = userID + auditData.Fingerprint = pubKey.id + auditData.PublicKeyExpiration = pubKey.expiration.Unix() + auditData.Role = pubKeyRole + auditData.Email = email + _, err = safe.StateGet[*authres.PublicKey](ctx, s.state, newPubKey.Metadata()) if state.IsNotFoundError(err) { setPubKeyAttributes(newPubKey) @@ -236,6 +249,16 @@ func (s *authServer) ConfirmPublicKey(ctx context.Context, request *authpb.Confi return nil, errors.New("public key <> id mismatch") } + auditData, ok := ctxstore.Value[*audit.Data](ctx) + if !ok { + return nil, errors.New("audit data not found") + } + + auditData.UserID = userID + auditData.Fingerprint = pubKey.Metadata().ID() + auditData.PublicKeyExpiration = pubKey.TypedSpec().Value.Expiration.Seconds + auditData.Role = role.Role(pubKey.TypedSpec().Value.GetRole()) + _, err = safe.StateUpdateWithConflicts(ctx, s.state, pubKey.Metadata(), func(pk *authres.PublicKey) error { pk.TypedSpec().Value.Confirmed = true diff --git a/internal/backend/runtime/omni/audit/audit.go b/internal/backend/runtime/omni/audit/audit.go index a7de4b09..bd2a66e9 100644 --- a/internal/backend/runtime/omni/audit/audit.go +++ b/internal/backend/runtime/omni/audit/audit.go @@ -10,12 +10,21 @@ import ( "github.com/siderolabs/omni/internal/pkg/auth/role" ) +const ( + // Auth0 is auth0 confirmation type. + Auth0 = "auth0" + // SAML is SAML confirmation type. + SAML = "saml" +) + // Data contains the audit data. type Data struct { - UserAgent string `json:"user_agent,omitempty"` - IPAddress string `json:"ip_address,omitempty"` - UserID string `json:"user_id,omitempty"` - Identity string `json:"identity,omitempty"` - Role role.Role `json:"role,omitempty"` - Email string `json:"email,omitempty"` + UserAgent string `json:"user_agent,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + UserID string `json:"user_id,omitempty"` + Role role.Role `json:"role,omitempty"` + Email string `json:"email,omitempty"` + ConfirmationType string `json:"confirmation_type,omitempty"` + Fingerprint string `json:"fingerprint,omitempty"` + PublicKeyExpiration int64 `json:"public_key_expiration,omitempty"` } diff --git a/internal/backend/runtime/omni/audit/gate.go b/internal/backend/runtime/omni/audit/gate.go new file mode 100644 index 00000000..1064df50 --- /dev/null +++ b/internal/backend/runtime/omni/audit/gate.go @@ -0,0 +1,143 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +package audit + +import ( + "context" + "sync" + + "github.com/cosi-project/runtime/pkg/resource" + "github.com/siderolabs/gen/pair" +) + +// Check is a function that checks if the event is allowed. +type Check = func(ctx context.Context, eventType EventType, args ...any) bool + +// Gate is a gate that checks if the event is allowed. +// +//nolint:govet +type Gate struct { + mu sync.RWMutex + fns [10]map[resource.Type]Check +} + +// Check checks if the event is allowed. +func (g *Gate) Check(ctx context.Context, eventType EventType, typ resource.Type, args ...any) bool { + fn := g.check(eventType, typ) + if fn == nil { + return false + } + + return fn(ctx, eventType, args...) +} + +func (g *Gate) check(eventType EventType, typ resource.Type) Check { + g.mu.RLock() + defer g.mu.RUnlock() + + if g.fns[0] == nil { + return nil + } + + for i, e := range allEvents { + if eventType == e.typ { + return g.fns[i][typ] + } + } + + return nil +} + +// AddChecks adds checks for the event types. It's allowed to pass several at once using bitwise OR. +func (g *Gate) AddChecks(eventTypes EventType, pairs []pair.Pair[resource.Type, Check]) { + g.mu.Lock() + defer g.mu.Unlock() + + if g.fns[0] == nil { + for i := range g.fns { + g.fns[i] = map[resource.Type]Check{} + } + } + + for _, p := range pairs { + g.addCheck(eventTypes, p) + } +} + +func (g *Gate) addCheck(eventTypes EventType, p pair.Pair[resource.Type, Check]) { + for i, e := range allEvents { + if e.typ&eventTypes != 0 { + if _, ok := g.fns[i][p.F1]; ok { + panic("duplicate check") + } + + g.fns[i][p.F1] = p.F2 + } + } +} + +// AllowAll is a check that allows all events for certain event type. +func AllowAll(context.Context, EventType, ...any) bool { + return true +} + +const ( + // EventGet is the get event type. + EventGet EventType = 1 << iota + // EventList is the list event type. + EventList + // EventCreate is the create event type. + EventCreate + // EventUpdate is the update event type. + EventUpdate + // EventDestroy is the destroy event type. + EventDestroy + // EventWatch is the watch event type. + EventWatch + // EventWatchKind is the watch kind event type. + EventWatchKind + // EventWatchKindAggregated is the watch kind aggregated event type. + EventWatchKindAggregated + // EventUpdateWithConflicts is the update with conflicts event type. + EventUpdateWithConflicts + // EventWatchFor is the watch for event type. + EventWatchFor +) + +// EventType represents the type of event. +type EventType int + +// MarshalJSON marshals the event type to JSON. +func (e *EventType) MarshalJSON() ([]byte, error) { + return []byte(`"` + e.String() + `"`), nil +} + +// String returns the string representation of the event type. +func (e *EventType) String() string { + for _, ev := range allEvents { + if *e == ev.typ { + return ev.str + } + } + + return "" +} + +var allEvents = []struct { + str string + typ EventType +}{ + {"get", EventGet}, + {"list", EventList}, + {"create", EventCreate}, + {"update", EventUpdate}, + {"destroy", EventDestroy}, + {"watch", EventWatch}, + {"watch_kind", EventWatchKind}, + {"watch_kind_aggregated", EventWatchKindAggregated}, + {"update_with_conflicts", EventUpdateWithConflicts}, + {"watch_for", EventWatchFor}, +} diff --git a/internal/backend/runtime/omni/audit/log.go b/internal/backend/runtime/omni/audit/log.go new file mode 100644 index 00000000..06d39120 --- /dev/null +++ b/internal/backend/runtime/omni/audit/log.go @@ -0,0 +1,76 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +package audit + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/cosi-project/runtime/pkg/resource" + "github.com/siderolabs/gen/pair" + "go.uber.org/zap" + + "github.com/siderolabs/omni/internal/pkg/ctxstore" +) + +// NewLogger creates a new audit logger. +func NewLogger(auditLogDir string, logger *zap.Logger) (*Logger, error) { + err := os.MkdirAll(auditLogDir, 0o755) + if err != nil { + return nil, fmt.Errorf("failed to create audit logger: %w", err) + } + + return &Logger{ + logFile: NewLogFile(auditLogDir), + logger: logger, + }, nil +} + +// Logger logs audit events. +type Logger struct { + gate Gate + logFile *LogFile + logger *zap.Logger +} + +// LogEvent logs an audit event. +func (l *Logger) LogEvent(ctx context.Context, eventType EventType, resType resource.Type, args ...any) { + if !l.gate.Check(ctx, eventType, resType, args...) { + return + } + + value, ok := ctxstore.Value[*Data](ctx) + if !ok { + return + } + + err := l.logFile.Dump(&event{ + Type: eventType, + ResourceType: resType, + Time: time.Now().UnixMilli(), + Data: value, + }) + if err == nil { + return + } + + l.logger.Error("failed to dump audit log", zap.Error(err)) +} + +// ShoudLog adds checks that allow event type to be logged. +func (l *Logger) ShoudLog(eventType EventType, p ...pair.Pair[resource.Type, Check]) { + l.gate.AddChecks(eventType, p) +} + +//nolint:govet +type event struct { + Type EventType `json:"event_type,omitempty"` + ResourceType resource.Type `json:"resource_type,omitempty"` + Time int64 `json:"event_ts,omitempty"` + Data *Data `json:"event_data,omitempty"` +} diff --git a/internal/backend/runtime/omni/audit/log_file_test.go b/internal/backend/runtime/omni/audit/log_file_test.go index 493e4da9..ab574f79 100644 --- a/internal/backend/runtime/omni/audit/log_file_test.go +++ b/internal/backend/runtime/omni/audit/log_file_test.go @@ -26,9 +26,7 @@ import ( var currentDay embed.FS func TestLogFile_CurrentDay(t *testing.T) { - dir := must.Value(os.MkdirTemp("", "log_file_test"))(t) - - t.Cleanup(func() { os.RemoveAll(dir) }) //nolint:errcheck + dir := t.TempDir() entries := []entry{ {shift: time.Second, data: audit.Data{UserAgent: "Mozilla/5.0", IPAddress: "10.10.0.1", Email: "random_email1@example.com"}}, @@ -46,16 +44,19 @@ func TestLogFile_CurrentDay(t *testing.T) { require.NoError(t, file.DumpAt(e.data, now)) } - checkFiles(t, basicLoader(dir), fsSub(t, currentDay, "currentday")) + equalDirs( + t, + fsSub(t, currentDay, "currentday"), + os.DirFS(dir).(subFS), //nolint:forcetypeassert + defaultCmp, + ) } //go:embed testdata/nextday var nextDay embed.FS func TestLogFile_CurrentAndNewDay(t *testing.T) { - dir := must.Value(os.MkdirTemp("", "log_file_test"))(t) - - t.Cleanup(func() { os.RemoveAll(dir) }) //nolint:errcheck + dir := t.TempDir() entries := []entry{ {shift: 0, data: audit.Data{UserAgent: "Mozilla/5.0", IPAddress: "10.10.0.1", Email: "random_email1@example.com"}}, @@ -73,16 +74,19 @@ func TestLogFile_CurrentAndNewDay(t *testing.T) { require.NoError(t, file.DumpAt(e.data, now)) } - checkFiles(t, basicLoader(dir), fsSub(t, nextDay, "nextday")) + equalDirs( + t, + fsSub(t, nextDay, "nextday"), + os.DirFS(dir).(subFS), //nolint:forcetypeassert + defaultCmp, + ) } //go:embed testdata/concurrent var concurrent embed.FS func TestLogFile_CurrentDayConcurrent(t *testing.T) { - dir := must.Value(os.MkdirTemp("", "log_file_test"))(t) - - t.Cleanup(func() { os.RemoveAll(dir) }) //nolint:errcheck + dir := t.TempDir() entries := make([]entry, 0, 250) @@ -110,7 +114,14 @@ func TestLogFile_CurrentDayConcurrent(t *testing.T) { } }) - checkFiles(t, sortedLoader(basicLoader(dir)), fsSub(t, concurrent, "concurrent")) + equalDirs( + t, + fsSub(t, concurrent, "concurrent"), + &sortedFileFS{ + subFS: os.DirFS(dir).(subFS), //nolint:forcetypeassert + }, + defaultCmp, + ) } //nolint:govet @@ -124,40 +135,48 @@ type subFS interface { fs.ReadDirFS } -func checkFiles(t *testing.T, loader fileLoader, expectedFS subFS) { - expectedFiles := must.Value(expectedFS.ReadDir("."))(t) +func fsSub(t *testing.T, subFs subFS, folder string) subFS { + return must.Value(fs.Sub(subFs, filepath.Join("testdata", folder)))(t).(subFS) //nolint:forcetypeassert +} + +func equalDirs(t *testing.T, expected, actual subFS, cmpFn func(t *testing.T, expected, actual string)) { + expectedFiles := must.Value(expected.ReadDir("."))(t) + actualFiles := must.Value(actual.ReadDir("."))(t) + + if len(expectedFiles) != len(actualFiles) { + t.Fatalf("expected %v files, got %v", expectedFiles, actualFiles) + } - for _, expectedFile := range expectedFiles { - if expectedFile.IsDir() { - t.Fatal("unexpected directory", expectedFile.Name()) + for _, actualFile := range actualFiles { + if actualFile.IsDir() { + t.Fatal("unexpected directory", actualFile.Name()) } - expectedData := string(must.Value(expectedFS.ReadFile(expectedFile.Name()))(t)) - actualData := loader(t, expectedFile.Name()) + name := actualFile.Name() - require.Equal(t, expectedData, actualData, "file %s", expectedFile.Name()) - } -} + expectedContent := must.Value(expected.ReadFile(name))(t) + actualContent := must.Value(actual.ReadFile(name))(t) -func fsSub(t *testing.T, subFs subFS, folder string) subFS { - return must.Value(fs.Sub(subFs, filepath.Join("testdata", folder)))(t).(subFS) //nolint:forcetypeassert + cmpFn(t, string(expectedContent), string(actualContent)) + } } -type fileLoader func(t *testing.T, filename string) string +type sortedFileFS struct{ subFS } -func basicLoader(dir string) func(t *testing.T, filename string) string { - return func(t *testing.T, filename string) string { - return string(must.Value(os.ReadFile(filepath.Join(dir, filename)))(t)) +func (s *sortedFileFS) ReadFile(name string) ([]byte, error) { + b, err := s.subFS.ReadFile(name) + if err != nil { + return nil, err } -} -func sortedLoader(loader fileLoader) fileLoader { - return func(t *testing.T, filename string) string { - data := strings.TrimRight(loader(t, filename), "\n") - slc := strings.Split(data, "\n") + data := strings.TrimRight(string(b), "\n") + slc := strings.Split(data, "\n") - slices.Sort(slc) + slices.Sort(slc) - return strings.Join(slc, "\n") + "\n" - } + return []byte(strings.Join(slc, "\n") + "\n"), nil +} + +func defaultCmp(t *testing.T, expected string, actual string) { + require.Equal(t, expected, actual) } diff --git a/internal/backend/runtime/omni/audit/log_test.go b/internal/backend/runtime/omni/audit/log_test.go new file mode 100644 index 00000000..9236d95e --- /dev/null +++ b/internal/backend/runtime/omni/audit/log_test.go @@ -0,0 +1,116 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +package audit_test + +import ( + "context" + "embed" + "encoding/json" + "errors" + "io" + "os" + "strings" + "testing" + + "github.com/cosi-project/runtime/pkg/resource" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/siderolabs/gen/pair" + "github.com/siderolabs/gen/xtesting/must" + "go.uber.org/zap/zaptest" + + "github.com/siderolabs/omni/client/pkg/omni/resources/auth" + "github.com/siderolabs/omni/internal/backend/runtime/omni/audit" + "github.com/siderolabs/omni/internal/pkg/ctxstore" +) + +//go:embed testdata/log +var logDir embed.FS + +func TestLog(t *testing.T) { + tempDir := t.TempDir() + logger := must.Value(audit.NewLogger(tempDir, zaptest.NewLogger(t)))(t) + + logger.ShoudLog(audit.EventCreate|audit.EventUpdate|audit.EventUpdateWithConflicts, + pair.MakePair(auth.PublicKeyType, audit.AllowAll), + ) + + events := []pair.Triple[audit.EventType, resource.Type, *audit.Data]{ + pair.MakeTriple(audit.EventCreate, auth.PublicKeyType, &audit.Data{UserAgent: "Mozilla/5.0", IPAddress: "10.10.0.1", Email: "random_email1@example.com"}), + pair.MakeTriple(audit.EventUpdate, auth.PublicKeyType, &audit.Data{UserAgent: "Mozilla/5.0", IPAddress: "10.10.0.2", Email: "random_email2@example.com"}), + pair.MakeTriple(audit.EventUpdateWithConflicts, auth.PublicKeyType, &audit.Data{UserAgent: "Mozilla/5.0", IPAddress: "10.10.0.3", Email: "random_email3@example.com"}), + pair.MakeTriple(audit.EventDestroy, auth.PublicKeyType, &audit.Data{UserAgent: "Mozilla/5.0", IPAddress: "10.10.0.4", Email: "random_email4@example.com"}), + pair.MakeTriple(audit.EventCreate, auth.PublicKeyType, (*audit.Data)(nil)), + pair.MakeTriple(audit.EventCreate, auth.AuthConfigType, &audit.Data{UserAgent: "Mozilla/5.0", IPAddress: "10.10.0.5", Email: "random_email5@example.com"}), + } + + for _, event := range events { + ctx := context.Background() + + if event.V3 != nil { + ctx = ctxstore.WithValue(ctx, event.V3) + } + + logger.LogEvent(ctx, event.V1, event.V2, 100) + } + + equalDirs( + t, + &wrapFS{ + subFS: fsSub(t, logDir, "log"), + File: "2012-01-01.jsonlog", + }, + os.DirFS(tempDir).(subFS), //nolint:forcetypeassert + cmpIgnoreTime, + ) +} + +type wrapFS struct { + subFS + File string +} + +func (w *wrapFS) ReadFile(string) ([]byte, error) { + return w.subFS.ReadFile(w.File) +} + +func cmpIgnoreTime(t *testing.T, expected string, actual string) { + expectedEvents := loadEvents(t, expected) + actualEvents := loadEvents(t, actual) + + diff := cmp.Diff(expectedEvents, actualEvents, cmpopts.IgnoreMapEntries(func(k string, v any) bool { + _, ok := v.(json.Number) + + return ok && k == "event_ts" + })) + if diff != "" { + t.Fatalf("events mismatch (-want +got):\n%s", diff) + } +} + +func loadEvents(t *testing.T, expected string) []any { + var result []any + + decoder := json.NewDecoder(strings.NewReader(expected)) + decoder.UseNumber() + + for { + var event any + + err := decoder.Decode(&event) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + t.Fatalf("failed to decode event: %v", err) + } + + result = append(result, event) + } + + return result +} diff --git a/internal/backend/runtime/omni/audit/state.go b/internal/backend/runtime/omni/audit/state.go new file mode 100644 index 00000000..694ffe23 --- /dev/null +++ b/internal/backend/runtime/omni/audit/state.go @@ -0,0 +1,102 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +package audit + +import ( + "context" + + "github.com/cosi-project/runtime/pkg/resource" + "github.com/cosi-project/runtime/pkg/state" +) + +// WrapState wraps the given state with audit log state. +func WrapState(s state.State, l *Logger) state.State { + return &auditState{ + state: s, + logger: l, + } +} + +type auditState struct { + state state.State + logger *Logger +} + +func (a *auditState) Get(ctx context.Context, ptr resource.Pointer, option ...state.GetOption) (resource.Resource, error) { + a.logger.LogEvent(ctx, EventGet, ptr.Type(), option) + + return a.state.Get(ctx, ptr, option...) +} + +func (a *auditState) List(ctx context.Context, kind resource.Kind, option ...state.ListOption) (resource.List, error) { + a.logger.LogEvent(ctx, EventList, kind.Type(), option) + + return a.state.List(ctx, kind, option...) +} + +func (a *auditState) Create(ctx context.Context, res resource.Resource, option ...state.CreateOption) error { + a.logger.LogEvent(ctx, EventCreate, res.Metadata().Type(), option) + + return a.state.Create(ctx, res, option...) +} + +func (a *auditState) Update(ctx context.Context, newRes resource.Resource, opts ...state.UpdateOption) error { + a.logger.LogEvent(ctx, EventUpdate, newRes.Metadata().Type(), opts) + + return a.state.Update(ctx, newRes, opts...) +} + +func (a *auditState) Destroy(ctx context.Context, ptr resource.Pointer, option ...state.DestroyOption) error { + a.logger.LogEvent(ctx, EventDestroy, ptr.Type(), option) + + return a.state.Destroy(ctx, ptr, option...) +} + +func (a *auditState) Watch(ctx context.Context, ptr resource.Pointer, events chan<- state.Event, option ...state.WatchOption) error { + a.logger.LogEvent(ctx, EventWatch, ptr.Type(), option) + + return a.state.Watch(ctx, ptr, events, option...) +} + +func (a *auditState) WatchKind(ctx context.Context, kind resource.Kind, events chan<- state.Event, option ...state.WatchKindOption) error { + a.logger.LogEvent(ctx, EventWatchKind, kind.Type(), option) + + return a.state.WatchKind(ctx, kind, events, option...) +} + +func (a *auditState) WatchKindAggregated(ctx context.Context, kind resource.Kind, c chan<- []state.Event, option ...state.WatchKindOption) error { + a.logger.LogEvent(ctx, EventWatchKindAggregated, kind.Type(), option) + + return a.state.WatchKindAggregated(ctx, kind, c, option...) +} + +func (a *auditState) UpdateWithConflicts(ctx context.Context, ptr resource.Pointer, updaterFunc state.UpdaterFunc, option ...state.UpdateOption) (resource.Resource, error) { + a.logger.LogEvent(ctx, EventUpdateWithConflicts, ptr.Type(), option) + + return a.state.UpdateWithConflicts(ctx, ptr, updaterFunc, option...) +} + +func (a *auditState) WatchFor(ctx context.Context, pointer resource.Pointer, conditionFunc ...state.WatchForConditionFunc) (resource.Resource, error) { + a.logger.LogEvent(ctx, EventWatchFor, pointer.Type(), conditionFunc) + + return a.state.WatchFor(ctx, pointer, conditionFunc...) +} + +func (a *auditState) Teardown(ctx context.Context, pointer resource.Pointer, option ...state.TeardownOption) (bool, error) { + return a.state.Teardown(ctx, pointer, option...) +} + +func (a *auditState) AddFinalizer(ctx context.Context, pointer resource.Pointer, finalizer ...resource.Finalizer) error { + return a.state.AddFinalizer(ctx, pointer, finalizer...) +} + +func (a *auditState) RemoveFinalizer(ctx context.Context, pointer resource.Pointer, finalizer ...resource.Finalizer) error { + return a.state.RemoveFinalizer(ctx, pointer, finalizer...) +} + +func (a *auditState) ContextWithTeardown(ctx context.Context, pointer resource.Pointer) (context.Context, error) { + return a.state.ContextWithTeardown(ctx, pointer) +} diff --git a/internal/backend/runtime/omni/audit/testdata/log/2012-01-01.jsonlog b/internal/backend/runtime/omni/audit/testdata/log/2012-01-01.jsonlog new file mode 100644 index 00000000..c8e7628c --- /dev/null +++ b/internal/backend/runtime/omni/audit/testdata/log/2012-01-01.jsonlog @@ -0,0 +1,3 @@ +{"event_type":"create","resource_type":"PublicKeys.omni.sidero.dev","event_data":{"user_agent":"Mozilla/5.0","ip_address":"10.10.0.1","email":"random_email1@example.com"}} +{"event_type":"update","resource_type":"PublicKeys.omni.sidero.dev","event_data":{"user_agent":"Mozilla/5.0","ip_address":"10.10.0.2","email":"random_email2@example.com"}} +{"event_type":"update_with_conflicts","resource_type":"PublicKeys.omni.sidero.dev","event_data":{"user_agent":"Mozilla/5.0","ip_address":"10.10.0.3","email":"random_email3@example.com"}} diff --git a/internal/backend/runtime/omni/state.go b/internal/backend/runtime/omni/state.go index 92dbb57f..0f3aeea9 100644 --- a/internal/backend/runtime/omni/state.go +++ b/internal/backend/runtime/omni/state.go @@ -17,13 +17,16 @@ import ( "github.com/cosi-project/runtime/pkg/state/impl/namespaced" "github.com/cosi-project/runtime/pkg/state/registry" "github.com/prometheus/client_golang/prometheus" + "github.com/siderolabs/gen/pair" "go.etcd.io/bbolt" "go.uber.org/zap" "github.com/siderolabs/omni/client/pkg/omni/resources" + "github.com/siderolabs/omni/client/pkg/omni/resources/auth" resourceregistry "github.com/siderolabs/omni/client/pkg/omni/resources/registry" "github.com/siderolabs/omni/client/pkg/omni/resources/system" "github.com/siderolabs/omni/internal/backend/logging" + "github.com/siderolabs/omni/internal/backend/runtime/omni/audit" "github.com/siderolabs/omni/internal/backend/runtime/omni/cloudprovider" "github.com/siderolabs/omni/internal/backend/runtime/omni/controllers/omni/etcdbackup/store" "github.com/siderolabs/omni/internal/backend/runtime/omni/external" @@ -35,7 +38,7 @@ import ( // NewState creates a production Omni state. // -//nolint:cyclop +//nolint:cyclop,gocognit func NewState(ctx context.Context, params *config.Params, logger *zap.Logger, metricsRegistry prometheus.Registerer, f func(context.Context, state.State, *virtual.State) error) error { stateFunc := func(ctx context.Context, persistentStateBuilder namespaced.StateBuilder) error { primaryStorageCoreState := persistentStateBuilder(resources.DefaultNamespace) @@ -138,7 +141,16 @@ func NewState(ctx context.Context, params *config.Params, logger *zap.Logger, me return err } - return f(ctx, resourceState, virtualState) + resourceState, fileErr := auditState(resourceState, params, logger) + if fileErr != nil { + return fileErr + } + + return f( + ctx, + resourceState, + virtualState, + ) } switch params.Storage.Kind { @@ -150,3 +162,24 @@ func NewState(ctx context.Context, params *config.Params, logger *zap.Logger, me return fmt.Errorf("unknown storage kind %q", params.Storage.Kind) } } + +func auditState(resState state.State, params *config.Params, logger *zap.Logger) (state.State, error) { + if params.AuditLogDir == "" { + logger.Info("audit log disabled") + + return resState, nil + } + + logger.Info("audit log enabled", zap.String("dir", params.AuditLogDir)) + + l, err := audit.NewLogger(params.AuditLogDir, logger) + if err != nil { + return nil, err + } + + l.ShoudLog(audit.EventCreate|audit.EventUpdate|audit.EventUpdateWithConflicts, + pair.MakePair(auth.PublicKeyType, audit.AllowAll), + ) + + return audit.WrapState(resState, l), nil +} diff --git a/internal/backend/server.go b/internal/backend/server.go index dd1688db..d971055e 100644 --- a/internal/backend/server.go +++ b/internal/backend/server.go @@ -316,6 +316,7 @@ func (s *Server) buildServerOptions() ([]grpc.ServerOption, error) { grpc_zap.UnaryServerInterceptor(s.logger, grpc_zap.WithMessageProducer(messageProducer)), grpcutil.SetUserAgent(), grpcutil.SetRealPeerAddress(), + grpcutil.SetAuditData(), grpcutil.InterceptBodyToTags( grpcutil.NewHook( grpcutil.NewRewriter(resourceServerCreate), @@ -335,6 +336,7 @@ func (s *Server) buildServerOptions() ([]grpc.ServerOption, error) { grpc_zap.StreamServerInterceptor(s.logger, grpc_zap.WithMessageProducer(messageProducer)), grpcutil.StreamSetUserAgent(), grpcutil.StreamSetRealPeerAddress(), + grpcutil.StreamSetAuditData(), grpcutil.StreamIntercept( grpcutil.StreamHooks{ RecvMsg: grpcutil.StreamInterceptRequestBodyToTags( diff --git a/internal/pkg/auth/interceptor/jwt.go b/internal/pkg/auth/interceptor/jwt.go index fbae7f84..d79e3f9a 100644 --- a/internal/pkg/auth/interceptor/jwt.go +++ b/internal/pkg/auth/interceptor/jwt.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/siderolabs/omni/internal/backend/runtime/omni/audit" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/auth0" "github.com/siderolabs/omni/internal/pkg/ctxstore" @@ -91,6 +92,14 @@ func (i *JWT) intercept(ctx context.Context) (context.Context, error) { return nil, errGRPCInvalidJWT } + auditData, ok := ctxstore.Value[*audit.Data](ctx) + if !ok { + return nil, status.Error(codes.Internal, "missing or invalid audit data") + } + + auditData.Email = claims.VerifiedEmail + auditData.ConfirmationType = audit.Auth0 + ctx = ctxstore.WithValue(ctx, auth.VerifiedEmailContextKey{Email: claims.VerifiedEmail}) return ctx, nil diff --git a/internal/pkg/auth/interceptor/saml.go b/internal/pkg/auth/interceptor/saml.go index 3188f0fc..0eed3d7c 100644 --- a/internal/pkg/auth/interceptor/saml.go +++ b/internal/pkg/auth/interceptor/saml.go @@ -22,6 +22,7 @@ import ( "github.com/siderolabs/omni/client/pkg/omni/resources" authres "github.com/siderolabs/omni/client/pkg/omni/resources/auth" + "github.com/siderolabs/omni/internal/backend/runtime/omni/audit" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/auth/actor" "github.com/siderolabs/omni/internal/pkg/ctxstore" @@ -86,6 +87,14 @@ func (i *SAML) intercept(ctx context.Context) (context.Context, error) { return nil, errGRPCInvalidSAML } + auditData, ok := ctxstore.Value[*audit.Data](ctx) + if !ok { + return nil, status.Error(codes.Internal, "missing or invalid audit data") + } + + auditData.Email = session.TypedSpec().Value.Email + auditData.ConfirmationType = audit.SAML + ctx = ctxstore.WithValue(ctx, auth.VerifiedEmailContextKey{Email: session.TypedSpec().Value.Email}) return ctx, nil diff --git a/internal/pkg/auth/interceptor/signature.go b/internal/pkg/auth/interceptor/signature.go index ec896e43..61e5932a 100644 --- a/internal/pkg/auth/interceptor/signature.go +++ b/internal/pkg/auth/interceptor/signature.go @@ -17,8 +17,10 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/siderolabs/omni/internal/backend/runtime/omni/audit" "github.com/siderolabs/omni/internal/pkg/auth" "github.com/siderolabs/omni/internal/pkg/ctxstore" + "github.com/siderolabs/omni/internal/pkg/grpcutil" ) var errGRPCInvalidSignature = status.Error(codes.Unauthenticated, "invalid signature") @@ -65,6 +67,13 @@ func (i *Signature) Stream() grpc.StreamServerInterceptor { } func (i *Signature) intercept(ctx context.Context) (context.Context, error) { + auditData, ok := ctxstore.Value[*audit.Data](ctx) + if !ok { + // This is allowed because signature interceptor can be called independently of others. + ctx = grpcutil.SetAuditInCtx(ctx) + auditData, _ = ctxstore.Value[*audit.Data](ctx) + } + msgVal, ok := ctxstore.Value[auth.GRPCMessageContextKey](ctx) if !ok { return nil, status.Error(codes.Internal, "missing or invalid message in context") @@ -108,6 +117,9 @@ func (i *Signature) intercept(ctx context.Context) (context.Context, error) { return nil, errGRPCInvalidSignature } + auditData.UserID = authenticator.UserID + auditData.Role = authenticator.Role + grpc_ctxtags.Extract(ctx). Set("authenticator.user_id", authenticator.UserID). Set("authenticator.identity", authenticator.Identity). diff --git a/internal/pkg/config/config.go b/internal/pkg/config/config.go index d24d079c..5e7e1d75 100644 --- a/internal/pkg/config/config.go +++ b/internal/pkg/config/config.go @@ -92,6 +92,8 @@ type Params struct { EmbeddedDiscoveryService EmbeddedDiscoveryServiceParams `yaml:"embeddedDiscoveryService"` EnableBreakGlassConfigs bool `yaml:"enableBreakGlassConfigs"` + + AuditLogDir string `yaml:"auditLogDir"` } // EmbeddedDiscoveryServiceParams defines embedded discovery service configs. diff --git a/internal/pkg/grpcutil/audit.go b/internal/pkg/grpcutil/audit.go new file mode 100644 index 00000000..8f04e83a --- /dev/null +++ b/internal/pkg/grpcutil/audit.go @@ -0,0 +1,39 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +package grpcutil + +import ( + "context" + + grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" + + "github.com/siderolabs/omni/internal/backend/runtime/omni/audit" + "github.com/siderolabs/omni/internal/pkg/ctxstore" +) + +// SetAuditInCtx sets audit data in the context. +func SetAuditInCtx(ctx context.Context) context.Context { + m := grpc_ctxtags.Extract(ctx).Values() + + return ctxstore.WithValue(ctx, &audit.Data{ + UserAgent: valOrEmpty[string](m, "user_agent"), + IPAddress: valOrEmpty[string](m, "peer.address"), + }) +} + +func valOrEmpty[T any](m map[string]any, key string) T { + v, ok := m[key] + if !ok { + return *new(T) + } + + result, ok := v.(T) + if !ok { + return *new(T) + } + + return result +} diff --git a/internal/pkg/grpcutil/logger.go b/internal/pkg/grpcutil/logger.go index 3f9aad82..5183149f 100644 --- a/internal/pkg/grpcutil/logger.go +++ b/internal/pkg/grpcutil/logger.go @@ -245,3 +245,10 @@ func LogLevelInterceptors() (grpc.UnaryServerInterceptor, grpc.StreamServerInter return unary, stream } + +// SetAuditData returns a new unary server interceptor that adds audit data to the context. +func SetAuditData() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + return handler(SetAuditInCtx(ctx), req) + } +} diff --git a/internal/pkg/grpcutil/stream_interceptors.go b/internal/pkg/grpcutil/stream_interceptors.go index 824fd210..3fffc62a 100644 --- a/internal/pkg/grpcutil/stream_interceptors.go +++ b/internal/pkg/grpcutil/stream_interceptors.go @@ -8,6 +8,7 @@ package grpcutil import ( "context" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "google.golang.org/grpc" "google.golang.org/protobuf/proto" @@ -87,3 +88,13 @@ func StreamInterceptRequestBodyToTags(hook Hook, bodyLimit int) RecvMsgHook { return result } } + +// StreamSetAuditData returns a new stream server interceptor that adds audit data to the context. +func StreamSetAuditData() grpc.StreamServerInterceptor { + return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + return handler(srv, &grpc_middleware.WrappedServerStream{ + ServerStream: ss, + WrappedContext: SetAuditInCtx(ss.Context()), + }) + } +}