From ee8de9f70f24839b53fe1baca5a69d1caa828592 Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Mon, 25 Nov 2024 13:44:11 +1000 Subject: [PATCH] fix: more robust support for navigating custom structures via jsonpointer --- jsonpointer/jsonpointer.go | 73 ++++++++++---- jsonpointer/jsonpointer_test.go | 162 +++++++++++++++++++++++++++++++- jsonschema/oas31/core/value.go | 6 +- marshaller/node.go | 4 +- 4 files changed, 223 insertions(+), 22 deletions(-) diff --git a/jsonpointer/jsonpointer.go b/jsonpointer/jsonpointer.go index 6ba2713..df84cf6 100644 --- a/jsonpointer/jsonpointer.go +++ b/jsonpointer/jsonpointer.go @@ -10,9 +10,14 @@ import ( ) const ( - ErrNotFound = errors.Error("not found") + // ErrNotFound is returned when the target is not found. + ErrNotFound = errors.Error("not found") + // ErrInvalidPath is returned when the path is invalid. ErrInvalidPath = errors.Error("invalid path") - ErrValidation = errors.Error("validation error") + // ErrValidation is returned when the jsonpointer is invalid. + ErrValidation = errors.Error("validation error") + // ErrSkipInterface is returned when this implementation of the interface is not applicable to the current type. + ErrSkipInterface = errors.Error("skip interface") ) const ( @@ -171,26 +176,41 @@ type IndexNavigable interface { // NavigableNoder is an interface that can be implemented by a struct to allow returning an alternative node to evaluate instead of the struct itself. type NavigableNoder interface { - GetNavigableNode() any + GetNavigableNode() (any, error) } func getStructTarget(sourceVal reflect.Value, currentPart navigationPart, stack []navigationPart, currentPath string, o *options) (any, []navigationPart, error) { - sourceValElem := reflect.Indirect(sourceVal) - - if currentPart.Type != partTypeKey { - return nil, nil, ErrInvalidPath.Wrap(fmt.Errorf("expected key, got %s at %s", currentPart.Type, currentPath)) - } - - if sourceVal.Type().Implements(reflect.TypeOf((*KeyNavigable)(nil)).Elem()) { - return getNavigableWithKeyTarget(sourceVal, currentPart, stack, currentPath, o) + if sourceVal.Type().Implements(reflect.TypeOf((*NavigableNoder)(nil)).Elem()) { + val, stack, err := getNavigableNoderTarget(sourceVal, currentPart, stack, currentPath, o) + if err != nil { + if !errors.Is(err, ErrSkipInterface) { + return nil, nil, err + } + } else { + return val, stack, nil + } } - if sourceVal.Type().Implements(reflect.TypeOf((*IndexNavigable)(nil)).Elem()) { - return getNavigableWithIndexTarget(sourceVal, currentPart, stack, currentPath, o) + switch currentPart.Type { + case partTypeKey: + return getKeyBasedStructTarget(sourceVal, currentPart, stack, currentPath, o) + case partTypeIndex: + return getIndexBasedStructTarget(sourceVal, currentPart, stack, currentPath, o) + default: + return nil, nil, ErrInvalidPath.Wrap(fmt.Errorf("expected key or index, got %s at %s", currentPart.Type, currentPath)) } +} - if sourceVal.Type().Implements(reflect.TypeOf((*NavigableNoder)(nil)).Elem()) { - return getNavigableNoderTarget(sourceVal, currentPart, stack, currentPath, o) +func getKeyBasedStructTarget(sourceVal reflect.Value, currentPart navigationPart, stack []navigationPart, currentPath string, o *options) (any, []navigationPart, error) { + if sourceVal.Type().Implements(reflect.TypeOf((*KeyNavigable)(nil)).Elem()) { + val, stack, err := getNavigableWithKeyTarget(sourceVal, currentPart, stack, currentPath, o) + if err != nil { + if !errors.Is(err, ErrSkipInterface) { + return nil, nil, err + } + } else { + return val, stack, nil + } } if sourceVal.Kind() == reflect.Ptr && sourceVal.IsNil() { @@ -199,6 +219,8 @@ func getStructTarget(sourceVal reflect.Value, currentPart navigationPart, stack key := currentPart.unescapeValue() + sourceValElem := reflect.Indirect(sourceVal) + for i := 0; i < sourceValElem.NumField(); i++ { field := sourceValElem.Type().Field(i) if !field.IsExported() { @@ -227,6 +249,22 @@ func getStructTarget(sourceVal reflect.Value, currentPart navigationPart, stack return nil, nil, ErrNotFound.Wrap(fmt.Errorf("key %s not found in %v at %s", key, sourceVal.Type(), currentPath)) } +func getIndexBasedStructTarget(sourceVal reflect.Value, currentPart navigationPart, stack []navigationPart, currentPath string, o *options) (any, []navigationPart, error) { + if sourceVal.Type().Implements(reflect.TypeOf((*IndexNavigable)(nil)).Elem()) { + val, stack, err := getNavigableWithIndexTarget(sourceVal, currentPart, stack, currentPath, o) + if err != nil { + if errors.Is(err, ErrSkipInterface) { + return nil, nil, fmt.Errorf("can't navigate by index on %s at %s", sourceVal.Type(), currentPath) + } + return nil, nil, err + } else { + return val, stack, nil + } + } else { + return nil, nil, ErrNotFound.Wrap(fmt.Errorf("expected IndexNavigable, got %s at %s", sourceVal.Kind(), currentPath)) + } +} + func getNavigableWithKeyTarget(sourceVal reflect.Value, currentPart navigationPart, stack []navigationPart, currentPath string, o *options) (any, []navigationPart, error) { if sourceVal.Kind() == reflect.Ptr && sourceVal.IsNil() { return nil, nil, ErrNotFound.Wrap(fmt.Errorf("source is nil at %s", currentPath)) @@ -277,7 +315,10 @@ func getNavigableNoderTarget(sourceVal reflect.Value, currentPart navigationPart return nil, nil, ErrNotFound.Wrap(fmt.Errorf("expected navigableNoder, got %s at %s", sourceVal.Kind(), currentPath)) } - value := nn.GetNavigableNode() + value, err := nn.GetNavigableNode() + if err != nil { + return nil, nil, err + } return getTarget(value, currentPart, stack, currentPath, o) } diff --git a/jsonpointer/jsonpointer_test.go b/jsonpointer/jsonpointer_test.go index 85609ce..50e59d4 100644 --- a/jsonpointer/jsonpointer_test.go +++ b/jsonpointer/jsonpointer_test.go @@ -2,6 +2,7 @@ package jsonpointer import ( "errors" + "fmt" "testing" "github.com/speakeasy-api/openapi/sequencedmap" @@ -382,7 +383,166 @@ func TestGetTarget_Error(t *testing.T) { source: TestStruct{}, pointer: JSONPointer("/1"), }, - wantErr: errors.New("invalid path -- expected key, got index at /1"), + wantErr: errors.New("not found -- expected IndexNavigable, got struct at /1"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target, err := GetTarget(tt.args.source, tt.args.pointer, tt.args.opts...) + assert.EqualError(t, err, tt.wantErr.Error()) + assert.Nil(t, target) + }) + } +} + +type InterfaceTestStruct struct { + typ string + valuesByKey map[string]any + valuesByIndex []any + Field1 any + Field2 any +} + +var ( + _ KeyNavigable = (*InterfaceTestStruct)(nil) + _ IndexNavigable = (*InterfaceTestStruct)(nil) +) + +func (t InterfaceTestStruct) NavigateWithKey(key string) (any, error) { + switch t.typ { + case "map": + return t.valuesByKey[key], nil + case "struct": + return nil, ErrSkipInterface + case "slice": + return nil, ErrInvalidPath + default: + return nil, fmt.Errorf("unknown type %s", t.typ) + } +} + +func (t InterfaceTestStruct) NavigateWithIndex(index int) (any, error) { + switch t.typ { + case "map": + return nil, ErrInvalidPath + case "struct": + return nil, ErrSkipInterface + case "slice": + return t.valuesByIndex[index], nil + default: + return nil, fmt.Errorf("unknown type %s", t.typ) + } +} + +type NavigableNodeWrapper struct { + typ string + NavigableNode InterfaceTestStruct + Field1 any + Field2 any +} + +var _ NavigableNoder = (*NavigableNodeWrapper)(nil) + +func (n NavigableNodeWrapper) GetNavigableNode() (any, error) { + switch n.typ { + case "wrapper": + return n.NavigableNode, nil + case "struct": + return nil, ErrSkipInterface + case "other": + return nil, ErrInvalidPath + default: + return nil, fmt.Errorf("unknown type %s", n.typ) + } +} + +func TestGetTarget_WithInterfaces_Success(t *testing.T) { + type args struct { + source any + pointer JSONPointer + opts []option + } + tests := []struct { + name string + args args + want any + }{ + { + name: "KeyNavigable succeeds", + args: args{ + source: InterfaceTestStruct{typ: "map", valuesByKey: map[string]any{"key1": "value1"}}, + pointer: JSONPointer("/key1"), + }, + want: "value1", + }, + { + name: "IndexNavigable succeeds", + args: args{ + source: InterfaceTestStruct{typ: "slice", valuesByIndex: []any{"value1", "value2"}}, + pointer: JSONPointer("/1"), + }, + want: "value2", + }, + { + name: "Struct is navigable", + args: args{ + source: InterfaceTestStruct{typ: "struct", Field1: "value1"}, + pointer: JSONPointer("/Field1"), + }, + want: "value1", + }, + { + name: "NavigableNoder succeeds", + args: args{ + source: NavigableNodeWrapper{typ: "wrapper", NavigableNode: InterfaceTestStruct{typ: "struct", Field1: "value1"}}, + pointer: JSONPointer("/Field1"), + }, + want: "value1", + }, + { + name: "NavigableNoder struct is navigable", + args: args{ + source: NavigableNodeWrapper{typ: "struct", Field2: "value2"}, + pointer: JSONPointer("/Field2"), + }, + want: "value2", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target, err := GetTarget(tt.args.source, tt.args.pointer, tt.args.opts...) + require.NoError(t, err) + assert.Equal(t, tt.want, target) + }) + } +} + +func TestGetTarget_WithInterfaces_Error(t *testing.T) { + type args struct { + source any + pointer JSONPointer + opts []option + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "Error returned for invalid KeyNavigable type", + args: args{ + source: InterfaceTestStruct{typ: "slice", valuesByIndex: []any{"value1", "value2"}}, + pointer: JSONPointer("/key2"), + }, + wantErr: errors.New("not found -- invalid path"), + }, + { + name: "Error returned for invalid IndexNavigable type", + args: args{ + source: InterfaceTestStruct{typ: "struct", Field1: "value1"}, + pointer: JSONPointer("/1"), + }, + wantErr: errors.New("can't navigate by index on jsonpointer.InterfaceTestStruct at /1"), }, } for _, tt := range tests { diff --git a/jsonschema/oas31/core/value.go b/jsonschema/oas31/core/value.go index 840e828..9692193 100644 --- a/jsonschema/oas31/core/value.go +++ b/jsonschema/oas31/core/value.go @@ -72,11 +72,11 @@ func (v *EitherValue[L, R]) SyncChanges(ctx context.Context, model any, valueNod } } -func (v *EitherValue[L, R]) GetNavigableNode() any { +func (v *EitherValue[L, R]) GetNavigableNode() (any, error) { if v.Left != nil { - return v.Left + return v.Left, nil } - return v.Right + return v.Right, nil } func unmarshalValue[T any](ctx context.Context, node *yaml.Node) (*T, []error) { diff --git a/marshaller/node.go b/marshaller/node.go index a52483a..645c94a 100644 --- a/marshaller/node.go +++ b/marshaller/node.go @@ -120,6 +120,6 @@ func (n Node[V]) GetMapValueNodeOrRoot(key string, rootNode *yaml.Node) *yaml.No return n.ValueNode } -func (n Node[V]) GetNavigableNode() any { - return n.Value +func (n Node[V]) GetNavigableNode() (any, error) { + return n.Value, nil }