Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql/opt: Generate synthetic check constraint to enforce RLS policies for new rows #141614

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pkg/sql/opt/cat/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,16 @@ type Policy struct {
// read operations. If the policy does not define a USING expression, this is
// an empty string.
UsingExpr string
// UsingColumnIDs is a set of column IDs that are referenced in the USING
// expression.
UsingColumnIDs descpb.ColumnIDs
// WithCheckExpr is the optional validation expression applied to new rows
// during write operations. If the policy does not define a WITH CHECK expression,
// this is an empty string.
WithCheckExpr string
// WithCheckColumnIDs is a set of column IDs that are referenced in the WITH
// CHECK expression.
WithCheckColumnIDs descpb.ColumnIDs
// Command is the command that the policy was defined for.
Command catpb.PolicyCommand
// roles are the roles the applies to. If the policy applies to all roles (aka
Expand Down
15 changes: 9 additions & 6 deletions pkg/sql/opt/cat/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ type Table interface {
// that they cannot be mutated.
IsMaterializedView() bool

// LookupColumnOrdinal returns the ordinal of the column with the given ID.
LookupColumnOrdinal(colID descpb.ColumnID) (int, error)

// ColumnCount returns the number of columns in the table. This includes
// public columns, write-only columns, etc.
ColumnCount() int
Expand Down Expand Up @@ -185,12 +188,8 @@ type Table interface {
// IsRowLevelSecurityEnabled is true if policies should be applied during the query.
IsRowLevelSecurityEnabled() bool

// PolicyCount returns the number of policies in the table for the given type.
PolicyCount(polType tree.PolicyType) int

// Policy retrieves the policy of the specified type at the given index (i),
// where i < PolicyCount for the specified type.
Policy(polType tree.PolicyType, i int) Policy
// Policies returns all the policies defined for this table.
Policies() *Policies
}

// CheckConstraint represents a check constraint on a table. Check constraints
Expand All @@ -212,6 +211,10 @@ type CheckConstraint interface {
// ColumnOrdinal returns the table column ordinal of the ith column in this
// constraint.
ColumnOrdinal(i int) int

// IsRLSConstraint is true if this is a constraint used to enforce
// row-level security policies.
IsRLSConstraint() bool
}

// TableStatistic is an interface to a table statistic. Each statistic is
Expand Down
5 changes: 5 additions & 0 deletions pkg/sql/opt/cat/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ func FormatTable(
}

for i := 0; i < tab.CheckCount(); i++ {
// We only show constraints that are constant and known when the catalog is
// built. For this reason, skip the one we add for row-level security.
if tab.Check(i).IsRLSConstraint() {
continue
}
child.Childf("CHECK (%s)", MaybeMarkRedactable(tab.Check(i).Constraint(), redactableValues))
}

Expand Down
6 changes: 4 additions & 2 deletions pkg/sql/opt/exec/explain/emit.go
Original file line number Diff line number Diff line change
Expand Up @@ -1366,8 +1366,10 @@ func (e *emitter) emitPolicies(
ob.AddField("policies", "row-level security enabled, no policies applied.")
} else {
var sb strings.Builder
for i := 0; i < table.PolicyCount(tree.PolicyTypePermissive); i++ {
policy := table.Policy(tree.PolicyTypePermissive, i)
policies := table.Policies()
// TODO(136742): Add support for restrictive policies.
for i := range policies.Permissive {
policy := policies.Permissive[i]
if applied.Policies.Contains(policy.ID) {
if sb.Len() > 0 {
sb.WriteString(", ")
Expand Down
13 changes: 6 additions & 7 deletions pkg/sql/opt/exec/explain/plan_gist_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,10 @@ func (u *unknownTable) IsMaterializedView() bool {
return false
}

func (u *unknownTable) LookupColumnOrdinal(descpb.ColumnID) (int, error) {
panic(errors.AssertionFailedf("not implemented"))
}

func (u *unknownTable) ColumnCount() int {
return 0
}
Expand Down Expand Up @@ -662,13 +666,8 @@ func (u *unknownTable) Trigger(i int) cat.Trigger {
// IsRowLevelSecurityEnabled is part of the cat.Table interface
func (u *unknownTable) IsRowLevelSecurityEnabled() bool { return false }

// PolicyCount is part of the cat.Table interface
func (u *unknownTable) PolicyCount(polType tree.PolicyType) int { return 0 }

// Policy is part of the cat.Table interface
func (u *unknownTable) Policy(polType tree.PolicyType, i int) cat.Policy {
panic(errors.AssertionFailedf("not implemented"))
}
// Policies is part of the cat.Table interface.
func (u *unknownTable) Policies() *cat.Policies { return nil }

var _ cat.Table = &unknownTable{}

Expand Down
45 changes: 39 additions & 6 deletions pkg/sql/opt/optbuilder/mutation_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -869,9 +869,31 @@ func (mb *mutationBuilder) addCheckConstraintCols(isUpdate bool) {
projectionsScope := mb.outScope.replace()
projectionsScope.appendColumnsFromScope(mb.outScope)
mutationCols := mb.mutationColumnIDs()
var seenRLSConstraint bool

for i, n := 0, mb.tab.CheckCount(); i < n; i++ {
check := mb.tab.Check(i)

// For tables with RLS enabled, we create a synthetic check constraint
// to enforce the policies. Since this check varies based on the role
// and command used, it must be generated each time it is needed rather
// than being included with the table's actual check constraints.
if check.IsRLSConstraint() {
if seenRLSConstraint {
panic(errors.AssertionFailedf("a table should only have one RLS constraint"))
}
seenRLSConstraint = true
chkBuilder := optRLSConstraintBuilder{
tab: mb.tab,
md: mb.md,
tabMeta: mb.md.TableMeta(mb.tabID),
oc: mb.b.catalog,
user: mb.b.checkPrivilegeUser,
isUpdate: isUpdate,
}
check = chkBuilder.Build(mb.b.ctx)
}

expr, err := parser.ParseExpr(check.Constraint())
if err != nil {
panic(err)
Expand All @@ -881,19 +903,30 @@ func (mb *mutationBuilder) addCheckConstraintCols(isUpdate bool) {

// Use an anonymous name because the column cannot be referenced
// in other expressions.
colName := scopeColName("").WithMetadataName(fmt.Sprintf("check%d", i+1))
colName := scopeColName("")
if check.IsRLSConstraint() {
colName = colName.WithMetadataName("rls")
} else {
colName = colName.WithMetadataName(fmt.Sprintf("check%d", i+1))
}
scopeCol := projectionsScope.addColumn(colName, texpr)

// TODO(ridwanmsharif): Maybe we can avoid building constraints here
// and instead use the constraints stored in the table metadata.
referencedCols := &opt.ColSet{}
mb.b.buildScalar(texpr, mb.outScope, projectionsScope, scopeCol, referencedCols)

// If the mutation is not an UPDATE, track the synthesized check
// columns in checkColIDS. If the mutation is an UPDATE, only track
// the check columns if the columns referenced in the check
// expression are being mutated.
if !isUpdate || referencedCols.Intersects(mutationCols) {
// For non-UPDATE mutations, track the synthesized check columns in
// checkColIDs. For UPDATE mutations, track the check columns in two
// scenarios:
// - If the check expression is a real check constraint and the columns
// referenced in the check expression are being mutated.
// - If the check expression is a synthetic one used for row-level
// security (RLS). Since it's not a real check expression, different
// expressions can exist for read and write operations. This means it's
// possible to read a row whose column values would violate the write
// expression.
if !isUpdate || check.IsRLSConstraint() || referencedCols.Intersects(mutationCols) {
mb.checkColIDs[i] = scopeCol.id

// TODO(michae2): Under weaker isolation levels we need to use shared
Expand Down
154 changes: 150 additions & 4 deletions pkg/sql/opt/optbuilder/row_level_security.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
package optbuilder

import (
"context"
"fmt"
"strings"

"github.com/cockroachdb/cockroach/pkg/security/username"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/catpb"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
"github.com/cockroachdb/cockroach/pkg/sql/opt"
"github.com/cockroachdb/cockroach/pkg/sql/opt/cat"
"github.com/cockroachdb/cockroach/pkg/sql/opt/memo"
"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/intsets"
"github.com/cockroachdb/errors"
)

Expand Down Expand Up @@ -48,9 +54,8 @@ func (b *Builder) buildRowLevelSecurityUsingExpression(
tabMeta *opt.TableMeta, tableScope *scope, cmdScope cat.PolicyCommandScope,
) opt.ScalarExpr {
var policiesUsed opt.PolicyIDSet
for i := 0; i < tabMeta.Table.PolicyCount(tree.PolicyTypePermissive); i++ {
policy := tabMeta.Table.Policy(tree.PolicyTypePermissive, i)

policies := tabMeta.Table.Policies()
for _, policy := range policies.Permissive {
if !policy.AppliesToRole(b.checkPrivilegeUser) || !b.policyAppliesToCommandScope(policy, cmdScope) {
continue
}
Expand Down Expand Up @@ -102,3 +107,144 @@ func (b *Builder) policyAppliesToCommandScope(
panic(errors.AssertionFailedf("unknown policy command %v", cmd))
}
}

// optRLSConstraintBuilder is used synthesize a check constraint to enforce the
// RLS policies for new rows.
type optRLSConstraintBuilder struct {
tab cat.Table
md *opt.Metadata
tabMeta *opt.TableMeta
oc cat.Catalog
user username.SQLUsername
isUpdate bool
}

// Build will construct a CheckConstraint to enforce the policies for the
// current user and command.
func (r *optRLSConstraintBuilder) Build(ctx context.Context) cat.CheckConstraint {
expr, colIDs := r.genExpression(ctx)
if expr == "" {
panic(fmt.Sprintf("must return some expression but empty string returned for user: %v", r.user))
}
return &rlsCheckConstraint{
constraint: expr,
colIDs: colIDs,
tab: r.tab,
}
}

// genExpression builds the expression that will be used within the check
// constraint built for RLS.
func (r *optRLSConstraintBuilder) genExpression(ctx context.Context) (string, []int) {
var sb strings.Builder

// colIDs tracks the column IDs referenced in all the policy expressions
// that are applied. We use a set as we need to combine the columns used
// for multiple policies.
var colIDs intsets.Fast

// Admin users are exempt from any RLS policies.
isAdmin, err := r.oc.UserHasAdminRole(ctx, r.user)
if err != nil {
panic(err)
}
r.md.SetRLSEnabled(r.user, isAdmin, r.tabMeta.MetaID)
if isAdmin {
// Return a constraint check that always passes.
return "true", nil
}

var policiesUsed opt.PolicyIDSet
for i := range r.tab.Policies().Permissive {
p := &r.tab.Policies().Permissive[i]

if !p.AppliesToRole(r.user) || !r.policyAppliesToCommand(p, r.isUpdate) {
continue
}
policiesUsed.Add(p.ID)
var expr string
// If the WITH CHECK expression is missing, we default to the USING
// expression. If both are missing, then this policy doesn't apply and can
// be skipped.
if p.WithCheckExpr == "" {
if p.UsingExpr == "" {
continue
}
expr = p.UsingExpr
for _, id := range p.UsingColumnIDs {
colIDs.Add(int(id))
}
} else {
expr = p.WithCheckExpr
for _, id := range p.WithCheckColumnIDs {
colIDs.Add(int(id))
}
}
if sb.Len() != 0 {
sb.WriteString(" OR ")
}
sb.WriteString("(")
sb.WriteString(expr)
sb.WriteString(")")
// TODO(136742): Add support for multiple policies.
r.md.GetRLSMeta().AddPoliciesUsed(r.tabMeta.MetaID, policiesUsed)
break
}

// TODO(136742): Add support for restrictive policies.

// If no policies apply, then we will add a false check as nothing is allowed
// to be written.
if sb.Len() == 0 {
r.md.GetRLSMeta().NoPoliciesApplied = true
return "false", nil
}

return sb.String(), colIDs.Ordered()
}

// policyAppliesToCommand will return true iff the command set in the policy
// applies to the current mutation action.
func (r *optRLSConstraintBuilder) policyAppliesToCommand(policy *cat.Policy, isUpdate bool) bool {
switch policy.Command {
case catpb.PolicyCommand_ALL:
return true
case catpb.PolicyCommand_SELECT, catpb.PolicyCommand_DELETE:
return false
case catpb.PolicyCommand_INSERT:
return !isUpdate
case catpb.PolicyCommand_UPDATE:
return isUpdate
default:
panic(errors.AssertionFailedf("unknown policy command %v", policy.Command))
}
}

// rlsCheckConstraint is an implementation of cat.CheckConstraint for the
// check constraint built to enforce the RLS policies on write.
type rlsCheckConstraint struct {
constraint string
colIDs []int
tab cat.Table
}

// Constraint implements the cat.CheckConstraint interface.
func (r *rlsCheckConstraint) Constraint() string { return r.constraint }

// Validated implements the cat.CheckConstraint interface.
func (r *rlsCheckConstraint) Validated() bool { return true }

// ColumnCount implements the cat.CheckConstraint interface.
func (r *rlsCheckConstraint) ColumnCount() int { return len(r.colIDs) }

// ColumnOrdinal implements the cat.CheckConstraint interface.
func (r *rlsCheckConstraint) ColumnOrdinal(i int) int {
ord, err := r.tab.LookupColumnOrdinal(descpb.ColumnID(r.colIDs[i]))
if err != nil {
panic(err)
}
return ord
}

// IsRLSConstraint implements the cat.CheckConstraint interface.
func (r *rlsCheckConstraint) IsRLSConstraint() bool { return true }
Loading