Skip to content

Commit

Permalink
Properly consider bases when validating interceptors
Browse files Browse the repository at this point in the history
Fixes #3664
  • Loading branch information
raphael committed Mar 9, 2025
1 parent 8037232 commit d6e7b3e
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 37 deletions.
108 changes: 71 additions & 37 deletions expr/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,63 +46,97 @@ func (i *InterceptorExpr) validate(m *MethodExpr) *eval.ValidationErrors {
verr := new(eval.ValidationErrors)

if i.ReadPayload != nil || i.WritePayload != nil {
payloadObj := AsObject(m.Payload.Type)
if payloadObj == nil {
if !IsObject(m.Payload.Type) {
verr.Add(m, "interceptor %q cannot be applied because the method payload is not an object", i.Name)
}
if i.ReadPayload != nil {
i.validateAttributeAccess(m, "read payload", verr, payloadObj, i.ReadPayload)
}
if i.WritePayload != nil {
i.validateAttributeAccess(m, "write payload", verr, payloadObj, i.WritePayload)
} else {
payload := DupAtt(m.Payload)
if m.Payload.Bases != nil {
for _, base := range m.Payload.Bases {
if ut, ok := base.(UserType); ok {
payload.Merge(ut.Attribute())
}
}
}
if i.ReadPayload != nil {
i.validateAttributeAccess(m, "read payload", verr, AsObject(payload.Type), i.ReadPayload)
}
if i.WritePayload != nil {
i.validateAttributeAccess(m, "write payload", verr, AsObject(payload.Type), i.WritePayload)
}
}
}

if i.ReadResult != nil || i.WriteResult != nil {
if m.IsResultStreaming() {
verr.Add(m, "interceptor %q cannot be applied because the method result is streaming", i.Name)
}
resultObj := AsObject(m.Result.Type)
if resultObj == nil {
if !IsObject(m.Result.Type) {
verr.Add(m, "interceptor %q cannot be applied because the method result is not an object", i.Name)
}
if i.ReadResult != nil {
i.validateAttributeAccess(m, "read result", verr, resultObj, i.ReadResult)
}
if i.WriteResult != nil {
i.validateAttributeAccess(m, "write result", verr, resultObj, i.WriteResult)
} else {
result := DupAtt(m.Result)
if m.Result.Bases != nil {
for _, base := range m.Result.Bases {
if ut, ok := base.(UserType); ok {
result.Merge(ut.Attribute())
}
}
}
if i.ReadResult != nil {
i.validateAttributeAccess(m, "read result", verr, AsObject(result.Type), i.ReadResult)
}
if i.WriteResult != nil {
i.validateAttributeAccess(m, "write result", verr, AsObject(result.Type), i.WriteResult)
}
}
}

if i.ReadStreamingPayload != nil || i.WriteStreamingPayload != nil {
if !m.IsPayloadStreaming() {
if !m.IsPayloadStreaming() || m.StreamingPayload == nil {
verr.Add(m, "interceptor %q cannot be applied because the method payload is not streaming", i.Name)
}
payloadObj := AsObject(m.StreamingPayload.Type)
if payloadObj == nil {
verr.Add(m, "interceptor %q cannot be applied because the method payload is not an object", i.Name)
}
if i.ReadStreamingPayload != nil {
i.validateAttributeAccess(m, "read streaming payload", verr, payloadObj, i.ReadStreamingPayload)
}
if i.WriteStreamingPayload != nil {
i.validateAttributeAccess(m, "write streaming payload", verr, payloadObj, i.WriteStreamingPayload)
} else {
if !IsObject(m.StreamingPayload.Type) {
verr.Add(m, "interceptor %q cannot be applied because the method payload is not an object", i.Name)
} else {
payload := DupAtt(m.StreamingPayload)
if m.StreamingPayload.Bases != nil {
for _, base := range m.StreamingPayload.Bases {
if ut, ok := base.(UserType); ok {
payload.Merge(ut.Attribute())
}
}
}
if i.ReadStreamingPayload != nil {
i.validateAttributeAccess(m, "read streaming payload", verr, AsObject(payload.Type), i.ReadStreamingPayload)
}
if i.WriteStreamingPayload != nil {
i.validateAttributeAccess(m, "write streaming payload", verr, AsObject(payload.Type), i.WriteStreamingPayload)
}
}
}
}

if i.ReadStreamingResult != nil || i.WriteStreamingResult != nil {
if !m.IsResultStreaming() {
verr.Add(m, "interceptor %q cannot be applied because the method result is not streaming", i.Name)
}
resultObj := AsObject(m.Result.Type)
if resultObj == nil {
verr.Add(m, "interceptor %q cannot be applied because the method result is not an object", i.Name)
}
if i.ReadStreamingResult != nil {
i.validateAttributeAccess(m, "read streaming result", verr, resultObj, i.ReadStreamingResult)
}
if i.WriteStreamingResult != nil {
i.validateAttributeAccess(m, "write streaming result", verr, resultObj, i.WriteStreamingResult)
} else {
if !IsObject(m.Result.Type) {
verr.Add(m, "interceptor %q cannot be applied because the method result is not an object", i.Name)
} else {
result := DupAtt(m.Result)
if m.Result.Bases != nil {
for _, base := range m.Result.Bases {
if ut, ok := base.(UserType); ok {
result.Merge(ut.Attribute())
}
}
}
if i.ReadStreamingResult != nil {
i.validateAttributeAccess(m, "read streaming result", verr, AsObject(result.Type), i.ReadStreamingResult)
}
if i.WriteStreamingResult != nil {
i.validateAttributeAccess(m, "write streaming result", verr, AsObject(result.Type), i.WriteStreamingResult)
}
}
}
}

Expand Down
196 changes: 196 additions & 0 deletions expr/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package expr

import (
"testing"
)

func TestInterceptorExpr_Validate(t *testing.T) {
cases := map[string]struct {
intercept *InterceptorExpr
method *MethodExpr
wantErrors []string
}{
"valid-payload": {
intercept: makeInterceptor(t, "test-interceptor", withReadPayload(t, namedAttr(t, "foo"))),
method: makeMethod(t, withPayload(t, namedAttr(t, "foo"))),
},
"valid-write-payload": {
intercept: makeInterceptor(t, "test-interceptor", withWritePayload(t, namedAttr(t, "foo"))),
method: makeMethod(t, withPayload(t, namedAttr(t, "foo"))),
},
"payload-with-base": {
intercept: makeInterceptor(t, "test-interceptor", withReadPayload(t, namedAttr(t, "bar"))),
method: makeMethod(t,
withPayload(t, namedAttr(t, "foo")),
withPayloadBases(t, &UserTypeExpr{
AttributeExpr: &AttributeExpr{
Type: &Object{namedAttr(t, "bar")},
},
}),
),
},
"result-with-base": {
intercept: makeInterceptor(t, "test-interceptor", withReadResult(t, namedAttr(t, "bar"))),
method: makeMethod(t,
withResult(t, namedAttr(t, "foo")),
withResultBases(t, &UserTypeExpr{
AttributeExpr: &AttributeExpr{
Type: &Object{namedAttr(t, "bar")},
},
}),
),
},
"invalid-payload-not-object": {
intercept: makeInterceptor(t, "test-interceptor", withReadPayload(t, namedAttr(t, "foo"))),
method: makeMethod(t, func(m *MethodExpr) {
m.Payload = &AttributeExpr{
Type: String,
}
}),
wantErrors: []string{
`interceptor "test-interceptor" cannot be applied because the method payload is not an object`,
},
},
"invalid-streaming-payload-not-streaming": {
intercept: makeInterceptor(t, "test-interceptor", withReadStreamingPayload(t, namedAttr(t, "foo"))),
method: makeMethod(t, func(m *MethodExpr) {
m.Payload = &AttributeExpr{Type: &Object{}}
}),
wantErrors: []string{
`interceptor "test-interceptor" cannot be applied because the method payload is not streaming`,
},
},
"invalid-streaming-result-not-streaming": {
intercept: makeInterceptor(t, "test-interceptor", withReadStreamingResult(t, namedAttr(t, "foo"))),
method: makeMethod(t, func(m *MethodExpr) {
m.Result = &AttributeExpr{Type: &Object{}}
}),
wantErrors: []string{
`interceptor "test-interceptor" cannot be applied because the method result is not streaming`,
},
},
"invalid-attribute-access": {
intercept: makeInterceptor(t, "test-interceptor", withReadPayload(t, namedAttr(t, "bar"))),
method: makeMethod(t, withPayload(t, namedAttr(t, "foo"))),
wantErrors: []string{
`interceptor "test-interceptor" cannot read payload attribute "bar": attribute does not exist`,
},
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
verr := tc.intercept.validate(tc.method)
if len(tc.wantErrors) != len(verr.Errors) {
t.Errorf("got %d errors, expected %d", len(verr.Errors), len(tc.wantErrors))
}
for i, err := range verr.Errors {
if i >= len(tc.wantErrors) {
break
}
if got := err.Error(); got != tc.wantErrors[i] {
t.Errorf("got error %q, expected %q", got, tc.wantErrors[i])
}
}
})
}
}

// Test helpers (at the end of the file)
func withWritePayload(t *testing.T, attrs *NamedAttributeExpr) func(*InterceptorExpr) {
t.Helper()
return func(i *InterceptorExpr) {
i.WritePayload = &AttributeExpr{Type: &Object{attrs}}
}
}

func withResultBases(t *testing.T, bases ...DataType) func(*MethodExpr) {
t.Helper()
return func(m *MethodExpr) {
if m.Result == nil {
m.Result = &AttributeExpr{Type: &Object{}}
}
m.Result.Bases = bases
}
}

func makeInterceptor(t *testing.T, name string, opts ...func(*InterceptorExpr)) *InterceptorExpr {
t.Helper()
i := &InterceptorExpr{Name: name}
for _, opt := range opts {
opt(i)
}
return i
}

// Helper functions need to be updated to handle empty attributes properly
func withReadPayload(t *testing.T, attrs *NamedAttributeExpr) func(*InterceptorExpr) {
t.Helper()
return func(i *InterceptorExpr) {
i.ReadPayload = &AttributeExpr{Type: &Object{attrs}}
}
}

func withReadResult(t *testing.T, attrs *NamedAttributeExpr) func(*InterceptorExpr) {
t.Helper()
return func(i *InterceptorExpr) {
i.ReadResult = &AttributeExpr{Type: &Object{attrs}}
}
}

func withReadStreamingPayload(t *testing.T, attrs *NamedAttributeExpr) func(*InterceptorExpr) {
t.Helper()
return func(i *InterceptorExpr) {
i.ReadStreamingPayload = &AttributeExpr{Type: &Object{attrs}}
}
}

func withReadStreamingResult(t *testing.T, attrs *NamedAttributeExpr) func(*InterceptorExpr) {
t.Helper()
return func(i *InterceptorExpr) {
i.ReadStreamingResult = &AttributeExpr{Type: &Object{attrs}}
}
}

func makeMethod(t *testing.T, opts ...func(*MethodExpr)) *MethodExpr {
t.Helper()
m := &MethodExpr{}
for _, opt := range opts {
opt(m)
}
return m
}

func withPayload(t *testing.T, attrs *NamedAttributeExpr) func(*MethodExpr) {
t.Helper()
return func(m *MethodExpr) {
m.Payload = &AttributeExpr{Type: &Object{attrs}}
}
}

func withPayloadBases(t *testing.T, bases ...DataType) func(*MethodExpr) {
t.Helper()
return func(m *MethodExpr) {
if m.Payload == nil {
m.Payload = &AttributeExpr{Type: &Object{}}
}
m.Payload.Bases = bases
}
}

func withResult(t *testing.T, attrs *NamedAttributeExpr) func(*MethodExpr) {
t.Helper()
return func(m *MethodExpr) {
m.Result = &AttributeExpr{Type: &Object{attrs}}
}
}

func namedAttr(t *testing.T, name string) *NamedAttributeExpr {
t.Helper()
return &NamedAttributeExpr{
Name: name,
Attribute: &AttributeExpr{
Type: String,
},
}
}

0 comments on commit d6e7b3e

Please sign in to comment.