Skip to content

Commit

Permalink
fix infine loop (#234)
Browse files Browse the repository at this point in the history
  • Loading branch information
Reuven Harrison authored May 3, 2023
1 parent dbca7ff commit d287cc8
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 65 deletions.
16 changes: 0 additions & 16 deletions diff/circular_refs.go

This file was deleted.

7 changes: 4 additions & 3 deletions diff/schema_circular_refs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package diff

import (
"github.com/getkin/kin-openapi/openapi3"
"github.com/tufin/oasdiff/utils"
)

type circularRefStatus int
Expand All @@ -12,15 +13,15 @@ const (
circularRefStatusNoDiff
)

func getCircularRefsDiff(visited1, visited2 visitedRefs, schema1, schema2 *openapi3.SchemaRef) circularRefStatus {
func getCircularRefsDiff(visited1, visited2 utils.VisitedRefs, schema1, schema2 *openapi3.SchemaRef) circularRefStatus {

if schema1 == nil || schema2 == nil ||
schema1.Value == nil || schema2.Value == nil {
return circularRefStatusNone
}

circular1 := visited1.isVisited(schema1.Ref)
circular2 := visited2.isVisited(schema2.Ref)
circular1 := visited1.IsVisited(schema1.Ref)
circular2 := visited2.IsVisited(schema2.Ref)

// neither are circular
if !circular1 && !circular2 {
Expand Down
8 changes: 4 additions & 4 deletions diff/schema_diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,13 @@ func getSchemaDiffInternal(config *Config, state *state, schema1, schema2 *opena

// mark visited schema references to avoid infinite loops
if schema1.Ref != "" {
state.visitedSchemasBase.add(schema1.Ref)
defer state.visitedSchemasBase.remove(schema1.Ref)
state.visitedSchemasBase.Add(schema1.Ref)
defer state.visitedSchemasBase.Remove(schema1.Ref)
}

if schema2.Ref != "" {
state.visitedSchemasRevision.add(schema2.Ref)
defer state.visitedSchemasRevision.remove(schema2.Ref)
state.visitedSchemasRevision.Add(schema2.Ref)
defer state.visitedSchemasRevision.Remove(schema2.Ref)
}

result := SchemaDiff{}
Expand Down
10 changes: 6 additions & 4 deletions diff/state.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package diff

import "github.com/tufin/oasdiff/utils"

type direction int

const (
Expand All @@ -8,16 +10,16 @@ const (
)

type state struct {
visitedSchemasBase visitedRefs
visitedSchemasRevision visitedRefs
visitedSchemasBase utils.VisitedRefs
visitedSchemasRevision utils.VisitedRefs
cache directionalSchemaDiffCache
direction direction
}

func newState() *state {
return &state{
visitedSchemasBase: visitedRefs{},
visitedSchemasRevision: visitedRefs{},
visitedSchemasBase: utils.VisitedRefs{},
visitedSchemasRevision: utils.VisitedRefs{},
cache: newDirectionalSchemaDiffCache(),
direction: directionRequest,
}
Expand Down
88 changes: 50 additions & 38 deletions lint/check-regexp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,24 @@ import (

"github.com/getkin/kin-openapi/openapi3"
"github.com/tufin/oasdiff/load"
"github.com/tufin/oasdiff/utils"
)

type regexCtx struct {
source string
cache map[string]error
type state struct {
source string
cache map[string]error
visitedRefs utils.VisitedRefs
}

func newRegexCtx(source string) *regexCtx {
return &regexCtx{
source: source,
cache: map[string]error{},
func newState(source string) *state {
return &state{
source: source,
cache: map[string]error{},
visitedRefs: utils.VisitedRefs{},
}
}

func (context *regexCtx) validate(pattern string) error {
func (context *state) validate(pattern string) error {
if result, ok := context.cache[pattern]; ok {
return result
}
Expand All @@ -31,119 +34,128 @@ func (context *regexCtx) validate(pattern string) error {
// *** THIS IS A TEMPORARY IMPLEMENTATION ***
// SHOULD USE ECMA 262, SEE: https://swagger.io/docs/specification/data-models/data-types/#pattern

func RegexCheck(source string, s *load.OpenAPISpecInfo) []*Error {
func RegexCheck(source string, spec *load.OpenAPISpecInfo) []*Error {

result := make([]*Error, 0)

if s == nil || s.Spec == nil {
if spec == nil || spec.Spec == nil {
return result
}

context := newRegexCtx(source)
s := newState(source)

for _, path := range s.Spec.Paths {
result = append(result, checkParameters(path.Parameters, context)...)
result = append(result, checkOperations(path.Operations(), context)...)
for _, path := range spec.Spec.Paths {
result = append(result, checkParameters(path.Parameters, s)...)
result = append(result, checkOperations(path.Operations(), s)...)
}

return result
}

func checkOperations(operations map[string]*openapi3.Operation, context *regexCtx) []*Error {
func checkOperations(operations map[string]*openapi3.Operation, s *state) []*Error {
result := make([]*Error, 0)
for _, op := range operations {

result = append(result, checkParameters(op.Parameters, context)...)
result = append(result, checkParameters(op.Parameters, s)...)

if op.RequestBody != nil {
for _, mediaType := range op.RequestBody.Value.Content {
result = append(result, checkSchemaRef(mediaType.Schema, context)...)
result = append(result, checkSchemaRef(mediaType.Schema, s)...)
}
}

for _, response := range op.Responses {
for _, mediaType := range response.Value.Content {
result = append(result, checkSchemaRef(mediaType.Schema, context)...)
result = append(result, checkSchemaRef(mediaType.Schema, s)...)
}
for _, header := range response.Value.Headers {
result = append(result, checkSchemaRef(header.Value.Schema, context)...)
result = append(result, checkSchemaRef(header.Value.Schema, s)...)
}
}

for _, callback := range op.Callbacks {
for _, pathItem := range *callback.Value {
result = append(result, checkParameters(pathItem.Parameters, context)...)
result = append(result, checkOperations(pathItem.Operations(), context)...)
result = append(result, checkParameters(pathItem.Parameters, s)...)
result = append(result, checkOperations(pathItem.Operations(), s)...)
}
}
}
return result
}

func checkParameters(parameters openapi3.Parameters, context *regexCtx) []*Error {
func checkParameters(parameters openapi3.Parameters, s *state) []*Error {
result := make([]*Error, 0)
for _, parameter := range parameters {
if parameter.Value == nil {
continue
}
if parameter.Value.Schema != nil {
result = append(result, checkSchemaRef(parameter.Value.Schema, context)...)
result = append(result, checkSchemaRef(parameter.Value.Schema, s)...)
}
for _, mediaType := range parameter.Value.Content {
if mediaType.Schema != nil {
result = append(result, checkSchemaRef(mediaType.Schema, context)...)
result = append(result, checkSchemaRef(mediaType.Schema, s)...)
}
}
}
return result
}

func checkSchema(schema *openapi3.Schema, context *regexCtx) []*Error {
func checkSchema(schema *openapi3.Schema, s *state) []*Error {
result := make([]*Error, 0)

if err := checkRegex(schema.Pattern, context); err != nil {
if err := checkRegex(schema.Pattern, s); err != nil {
result = append(result, err)
}

for _, subSchema := range schema.OneOf {
result = append(result, checkSchemaRef(subSchema, context)...)
result = append(result, checkSchemaRef(subSchema, s)...)
}
for _, subSchema := range schema.AnyOf {
result = append(result, checkSchemaRef(subSchema, context)...)
result = append(result, checkSchemaRef(subSchema, s)...)
}
for _, subSchema := range schema.AllOf {
result = append(result, checkSchemaRef(subSchema, context)...)
result = append(result, checkSchemaRef(subSchema, s)...)
}
if schema.Not != nil {
result = append(result, checkSchemaRef(schema.Not, context)...)
result = append(result, checkSchemaRef(schema.Not, s)...)
}
if schema.Items != nil {
result = append(result, checkSchemaRef(schema.Items, context)...)
result = append(result, checkSchemaRef(schema.Items, s)...)
}
for _, subSchema := range schema.Properties {
result = append(result, checkSchemaRef(subSchema, context)...)
result = append(result, checkSchemaRef(subSchema, s)...)
}
if schema.AdditionalProperties.Schema != nil {
result = append(result, checkSchemaRef(schema.AdditionalProperties.Schema, context)...)
result = append(result, checkSchemaRef(schema.AdditionalProperties.Schema, s)...)
}
return result
}

func checkSchemaRef(schema *openapi3.SchemaRef, context *regexCtx) []*Error {
return checkSchema(schema.Value, context)
func checkSchemaRef(schema *openapi3.SchemaRef, s *state) []*Error {
if s.visitedRefs.IsVisited(schema.Ref) {
return nil
}
// mark visited schema references to avoid infinite loops
if schema.Ref != "" {
s.visitedRefs.Add(schema.Ref)
defer s.visitedRefs.Remove(schema.Ref)
}

return checkSchema(schema.Value, s)
}

func checkRegex(pattern string, context *regexCtx) *Error {
func checkRegex(pattern string, s *state) *Error {
if pattern == "" {
return nil
}

if err := context.validate(pattern); err != nil {
if err := s.validate(pattern); err != nil {
return &Error{
Id: "invalid-regex-pattern",
Level: LEVEL_ERROR,
Text: err.Error(),
Source: context.source,
Source: s.source,
}
}

Expand Down
7 changes: 7 additions & 0 deletions lint/check-regexp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,10 @@ func TestRegexCheck_Embedded(t *testing.T) {
require.Equal(t, "invalid-regex-pattern", errs[i].Id)
}
}

func TestRegexCheck_Circular(t *testing.T) {

const source = "../data/circular2.yaml"
errs := lint.Run(*lint.NewConfig([]lint.Check{lint.RegexCheck}), source, loadFrom(t, source))
require.Empty(t, errs)
}
16 changes: 16 additions & 0 deletions utils/circular_refs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package utils

type VisitedRefs map[string]struct{}

func (v VisitedRefs) Add(refName string) {
v[refName] = struct{}{}
}

func (v VisitedRefs) Remove(refName string) {
delete(v, refName)
}

func (v VisitedRefs) IsVisited(refName string) bool {
_, ok := v[refName]
return ok
}

0 comments on commit d287cc8

Please sign in to comment.