Skip to content

Commit

Permalink
sql/opt: Generate synthetic check constraint to enforce RLS policies …
Browse files Browse the repository at this point in the history
…for new rows

With row-level security, policies include a WITH CHECK expression to
enforce constraints on new rows. This commit begins adding support for
enforcing these policies by modifying the optbuilder to construct the
check constraint, evaluate the expression, and pass the result to the
execution engine. A future commit will integrate the execution engine to
fully enforce these policies.

Since the expression for the synthetic check constraint is determined at
INSERT or UPDATE time, a placeholder check constraint is added when
building the optimizer table catalog. The check constraint is then
finalized in the mutationBuilder.

Because the check constraint is constructed late in the process, a
function is needed to look up the column ordinal for a given column ID.
To facilitate this, the previously internal function lookupColumnOrdinal
has been made external as LookupColumnOrdinal.

Epic: CRDB-45203
Release note: None
Informs: #136704
  • Loading branch information
spilchen committed Feb 24, 2025
1 parent 34e34c2 commit 5136c0c
Show file tree
Hide file tree
Showing 11 changed files with 777 additions and 161 deletions.
6 changes: 6 additions & 0 deletions pkg/sql/opt/cat/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,16 @@ type Policy struct {
// read operations. If the policy does not define a USING expression, this is
// an empty string.
UsingExpr string
// UsingColumnIDs is a set of column IDs that are referenced in the USING
// expression.
UsingColumnIDs descpb.ColumnIDs
// WithCheckExpr is the optional validation expression applied to new rows
// during write operations. If the policy does not define a WITH CHECK expression,
// this is an empty string.
WithCheckExpr string
// WithCheckColumnIDs is a set of column IDs that are referenced in the WITH
// CHECK expression.
WithCheckColumnIDs descpb.ColumnIDs
// Command is the command that the policy was defined for.
Command catpb.PolicyCommand
// roles are the roles the applies to. If the policy applies to all roles (aka
Expand Down
15 changes: 9 additions & 6 deletions pkg/sql/opt/cat/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ type Table interface {
// that they cannot be mutated.
IsMaterializedView() bool

// LookupColumnOrdinal returns the ordinal of the column with the given ID.
LookupColumnOrdinal(colID descpb.ColumnID) (int, error)

// ColumnCount returns the number of columns in the table. This includes
// public columns, write-only columns, etc.
ColumnCount() int
Expand Down Expand Up @@ -185,12 +188,8 @@ type Table interface {
// IsRowLevelSecurityEnabled is true if policies should be applied during the query.
IsRowLevelSecurityEnabled() bool

// PolicyCount returns the number of policies in the table for the given type.
PolicyCount(polType tree.PolicyType) int

// Policy retrieves the policy of the specified type at the given index (i),
// where i < PolicyCount for the specified type.
Policy(polType tree.PolicyType, i int) Policy
// Policies returns all the policies defined for this table.
Policies() *Policies
}

// CheckConstraint represents a check constraint on a table. Check constraints
Expand All @@ -212,6 +211,10 @@ type CheckConstraint interface {
// ColumnOrdinal returns the table column ordinal of the ith column in this
// constraint.
ColumnOrdinal(i int) int

// IsRLSConstraint is true if this is a constraint used to enforce
// row-level security policies.
IsRLSConstraint() bool
}

// TableStatistic is an interface to a table statistic. Each statistic is
Expand Down
5 changes: 5 additions & 0 deletions pkg/sql/opt/cat/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ func FormatTable(
}

for i := 0; i < tab.CheckCount(); i++ {
// We only show constraints that are constant and known when the catalog is
// built. For this reason, skip the one we add for row-level security.
if tab.Check(i).IsRLSConstraint() {
continue
}
child.Childf("CHECK (%s)", MaybeMarkRedactable(tab.Check(i).Constraint(), redactableValues))
}

Expand Down
6 changes: 4 additions & 2 deletions pkg/sql/opt/exec/explain/emit.go
Original file line number Diff line number Diff line change
Expand Up @@ -1366,8 +1366,10 @@ func (e *emitter) emitPolicies(
ob.AddField("policies", "row-level security enabled, no policies applied.")
} else {
var sb strings.Builder
for i := 0; i < table.PolicyCount(tree.PolicyTypePermissive); i++ {
policy := table.Policy(tree.PolicyTypePermissive, i)
policies := table.Policies()
// TODO(136742): Add support for restrictive policies.
for i := range policies.Permissive {
policy := policies.Permissive[i]
if applied.Policies.Contains(policy.ID) {
if sb.Len() > 0 {
sb.WriteString(", ")
Expand Down
13 changes: 6 additions & 7 deletions pkg/sql/opt/exec/explain/plan_gist_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,10 @@ func (u *unknownTable) IsMaterializedView() bool {
return false
}

func (u *unknownTable) LookupColumnOrdinal(descpb.ColumnID) (int, error) {
panic(errors.AssertionFailedf("not implemented"))
}

func (u *unknownTable) ColumnCount() int {
return 0
}
Expand Down Expand Up @@ -662,13 +666,8 @@ func (u *unknownTable) Trigger(i int) cat.Trigger {
// IsRowLevelSecurityEnabled is part of the cat.Table interface
func (u *unknownTable) IsRowLevelSecurityEnabled() bool { return false }

// PolicyCount is part of the cat.Table interface
func (u *unknownTable) PolicyCount(polType tree.PolicyType) int { return 0 }

// Policy is part of the cat.Table interface
func (u *unknownTable) Policy(polType tree.PolicyType, i int) cat.Policy {
panic(errors.AssertionFailedf("not implemented"))
}
// Policies is part of the cat.Table interface.
func (u *unknownTable) Policies() *cat.Policies { return nil }

var _ cat.Table = &unknownTable{}

Expand Down
40 changes: 34 additions & 6 deletions pkg/sql/opt/optbuilder/mutation_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,23 @@ func (mb *mutationBuilder) addCheckConstraintCols(isUpdate bool) {

for i, n := 0, mb.tab.CheckCount(); i < n; i++ {
check := mb.tab.Check(i)

// For tables with RLS enabled, we create a synthetic check constraint
// to enforce the policies. Since this check varies based on the role
// and command used, it must be generated each time it is needed rather
// than being included with the table's actual check constraints.
if check.IsRLSConstraint() {
chkBuilder := optRLSConstraintBuilder{
tab: mb.tab,
md: mb.md,
tabMeta: mb.md.TableMeta(mb.tabID),
oc: mb.b.catalog,
user: mb.b.checkPrivilegeUser,
isUpdate: isUpdate,
}
check = chkBuilder.Build(mb.b.ctx)
}

expr, err := parser.ParseExpr(check.Constraint())
if err != nil {
panic(err)
Expand All @@ -881,19 +898,30 @@ func (mb *mutationBuilder) addCheckConstraintCols(isUpdate bool) {

// Use an anonymous name because the column cannot be referenced
// in other expressions.
colName := scopeColName("").WithMetadataName(fmt.Sprintf("check%d", i+1))
colName := scopeColName("")
if check.IsRLSConstraint() {
colName = colName.WithMetadataName("rls")
} else {
colName = colName.WithMetadataName(fmt.Sprintf("check%d", i+1))
}
scopeCol := projectionsScope.addColumn(colName, texpr)

// TODO(ridwanmsharif): Maybe we can avoid building constraints here
// and instead use the constraints stored in the table metadata.
referencedCols := &opt.ColSet{}
mb.b.buildScalar(texpr, mb.outScope, projectionsScope, scopeCol, referencedCols)

// If the mutation is not an UPDATE, track the synthesized check
// columns in checkColIDS. If the mutation is an UPDATE, only track
// the check columns if the columns referenced in the check
// expression are being mutated.
if !isUpdate || referencedCols.Intersects(mutationCols) {
// For non-UPDATE mutations, track the synthesized check columns in
// checkColIDs. For UPDATE mutations, track the check columns in two
// scenarios:
// - If the check expression is a real check constraint and the columns
// referenced in the check expression are being mutated.
// - If the check expression is a synthetic one used for row-level
// security (RLS). Since it's not a real check expression, different
// expressions can exist for read and write operations. This means it's
// possible to read a row whose column values would violate the write
// expression.
if !isUpdate || check.IsRLSConstraint() || referencedCols.Intersects(mutationCols) {
mb.checkColIDs[i] = scopeCol.id

// TODO(michae2): Under weaker isolation levels we need to use shared
Expand Down
156 changes: 152 additions & 4 deletions pkg/sql/opt/optbuilder/row_level_security.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
package optbuilder

import (
"context"
"fmt"
"strings"

"github.com/cockroachdb/cockroach/pkg/security/username"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/catpb"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
"github.com/cockroachdb/cockroach/pkg/sql/opt"
"github.com/cockroachdb/cockroach/pkg/sql/opt/cat"
"github.com/cockroachdb/cockroach/pkg/sql/opt/memo"
"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/intsets"
"github.com/cockroachdb/errors"
)

Expand Down Expand Up @@ -48,9 +54,8 @@ func (b *Builder) buildRowLevelSecurityUsingExpression(
tabMeta *opt.TableMeta, tableScope *scope, cmdScope cat.PolicyCommandScope,
) opt.ScalarExpr {
var policiesUsed opt.PolicyIDSet
for i := 0; i < tabMeta.Table.PolicyCount(tree.PolicyTypePermissive); i++ {
policy := tabMeta.Table.Policy(tree.PolicyTypePermissive, i)

policies := tabMeta.Table.Policies()
for _, policy := range policies.Permissive {
if !policy.AppliesToRole(b.checkPrivilegeUser) || !b.policyAppliesToCommandScope(policy, cmdScope) {
continue
}
Expand Down Expand Up @@ -102,3 +107,146 @@ func (b *Builder) policyAppliesToCommandScope(
panic(errors.AssertionFailedf("unknown policy command %v", cmd))
}
}

// optRLSConstraintBuilder is used synthesize a check constraint to enforce the
// RLS policies for new rows.
type optRLSConstraintBuilder struct {
tab cat.Table
md *opt.Metadata
tabMeta *opt.TableMeta
oc cat.Catalog
user username.SQLUsername
isUpdate bool
}

// Build will construct a CheckConstraint to enforce the policies for the
// current user and command.
func (r *optRLSConstraintBuilder) Build(ctx context.Context) cat.CheckConstraint {
expr, colIDs := r.genExpression(ctx)
if expr == "" {
panic(fmt.Sprintf("must return some expression but empty string returned for user: %v", r.user))
}
return &rlsCheckConstraint{
constraint: expr,
columnCount: len(colIDs),
lookupColumnOrdinal: func(i int) (int, error) {
return r.tab.LookupColumnOrdinal(descpb.ColumnID(colIDs[i]))
},
}
}

// genExpression builds the expression that will be used within the check
// constraint built for RLS.
func (r *optRLSConstraintBuilder) genExpression(ctx context.Context) (string, []int) {
var sb strings.Builder

// colIDs tracks the column IDs referenced in all the policy expressions
// that are applied. We use a set as we need to combine the columns used
// for multiple policies.
var colIDs intsets.Fast

// Admin users are exempt from any RLS policies.
isAdmin, err := r.oc.UserHasAdminRole(ctx, r.user)
if err != nil {
panic(err)
}
r.md.SetRLSEnabled(r.user, isAdmin, r.tabMeta.MetaID)
if isAdmin {
// Return a constraint check that always passes.
return "true", nil
}

var policiesUsed opt.PolicyIDSet
for i := range r.tab.Policies().Permissive {
p := &r.tab.Policies().Permissive[i]

if !p.AppliesToRole(r.user) || !r.policyAppliesToCommand(p, r.isUpdate) {
continue
}
policiesUsed.Add(p.ID)
var expr string
// If the WITH CHECK expression is missing, we default to the USING
// expression. If both are missing, then this policy doesn't apply and can
// be skipped.
if p.WithCheckExpr == "" {
if p.UsingExpr == "" {
continue
}
expr = p.UsingExpr
for _, id := range p.UsingColumnIDs {
colIDs.Add(int(id))
}
} else {
expr = p.WithCheckExpr
for _, id := range p.WithCheckColumnIDs {
colIDs.Add(int(id))
}
}
if sb.Len() != 0 {
sb.WriteString(" OR ")
}
sb.WriteString("(")
sb.WriteString(expr)
sb.WriteString(")")
// TODO(136742): Add support for multiple policies.
r.md.GetRLSMeta().AddPoliciesUsed(r.tabMeta.MetaID, policiesUsed)
break
}

// TODO(136742): Add support for restrictive policies.

// If no policies apply, then we will add a false check as nothing is allowed
// to be written.
if sb.Len() == 0 {
r.md.GetRLSMeta().NoPoliciesApplied = true
return "false", nil
}

return sb.String(), colIDs.Ordered()
}

// policyAppliesToCommand will return true iff the command set in the policy
// applies to the current mutation action.
func (r *optRLSConstraintBuilder) policyAppliesToCommand(policy *cat.Policy, isUpdate bool) bool {
switch policy.Command {
case catpb.PolicyCommand_ALL:
return true
case catpb.PolicyCommand_SELECT, catpb.PolicyCommand_DELETE:
return false
case catpb.PolicyCommand_INSERT:
return !isUpdate
case catpb.PolicyCommand_UPDATE:
return isUpdate
default:
panic(errors.AssertionFailedf("unknown policy command %v", policy.Command))
}
}

// rlsCheckConstraint is an implementation of cat.CheckConstraint for the
// check constraint built to enforce the RLS policies on write.
type rlsCheckConstraint struct {
constraint string
columnCount int
lookupColumnOrdinal func(i int) (int, error)
}

// Constraint implements the cat.CheckConstraint interface.
func (r *rlsCheckConstraint) Constraint() string { return r.constraint }

// Validated implements the cat.CheckConstraint interface.
func (r *rlsCheckConstraint) Validated() bool { return true }

// ColumnCount implements the cat.CheckConstraint interface.
func (r *rlsCheckConstraint) ColumnCount() int { return r.columnCount }

// ColumnOrdinal implements the cat.CheckConstraint interface.
func (r *rlsCheckConstraint) ColumnOrdinal(i int) int {
ord, err := r.lookupColumnOrdinal(i)
if err != nil {
panic(err)
}
return ord
}

// IsRLSConstraint implements the cat.CheckConstraint interface.
func (r *rlsCheckConstraint) IsRLSConstraint() bool { return true }
Loading

0 comments on commit 5136c0c

Please sign in to comment.