From 03250fd9de66b052e55bbc0ab0e13984cbb53f0a Mon Sep 17 00:00:00 2001 From: Tanner Stirrat Date: Tue, 5 Nov 2024 10:38:49 -0700 Subject: [PATCH] Plumb through a notion of the source folder --- pkg/composableschemadsl/compiler/compiler.go | 8 ++++++++ pkg/composableschemadsl/compiler/importer.go | 17 +++++++++++++--- .../compiler/importer_test.go | 20 ++++++++++++++----- .../compiler/translator.go | 9 +++++++-- 4 files changed, 44 insertions(+), 10 deletions(-) diff --git a/pkg/composableschemadsl/compiler/compiler.go b/pkg/composableschemadsl/compiler/compiler.go index 1c16762248..dd942a7550 100644 --- a/pkg/composableschemadsl/compiler/compiler.go +++ b/pkg/composableschemadsl/compiler/compiler.go @@ -53,6 +53,9 @@ type config struct { skipValidation bool objectTypePrefix *string existingNames []string + // In an import context, this is the folder containing + // the importing schema (as opposed to imported schemas) + sourceFolder string } func SkipValidation() Option { return func(cfg *config) { cfg.skipValidation = true } } @@ -73,6 +76,10 @@ func ExistingNames(names []string) Option { return func(cfg *config) { cfg.existingNames = names } } +func SourceFolder(sourceFolder string) Option { + return func(cfg *config) { cfg.sourceFolder = sourceFolder } +} + type Option func(*config) type ObjectPrefixOption func(*config) @@ -100,6 +107,7 @@ func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*Co schemaString: schema.SchemaString, skipValidate: cfg.skipValidation, existingNames: cfg.existingNames, + sourceFolder: cfg.sourceFolder, }, root) if err != nil { var errorWithNode errorWithNode diff --git a/pkg/composableschemadsl/compiler/importer.go b/pkg/composableschemadsl/compiler/importer.go index 7022fce98a..f175b27951 100644 --- a/pkg/composableschemadsl/compiler/importer.go +++ b/pkg/composableschemadsl/compiler/importer.go @@ -10,10 +10,17 @@ import ( "github.com/rs/zerolog/log" ) +type ImportContext struct { + pathSegments []string + sourceFolder string + names *mapz.Set[string] +} + const SchemaFileSuffix = ".zed" -func ImportFile(pathSegments []string, existingNames *mapz.Set[string]) (*CompiledSchema, error) { - filepath := constructFilePath(pathSegments) +func ImportFile(importContext *ImportContext) (*CompiledSchema, error) { + relativeFilepath := constructFilePath(importContext.pathSegments) + filepath := path.Join(importContext.sourceFolder, relativeFilepath) var schemaBytes []byte schemaBytes, err := os.ReadFile(filepath) @@ -27,7 +34,11 @@ func ImportFile(pathSegments []string, existingNames *mapz.Set[string]) (*Compil // TODO: should this point to the schema file? What is this for? Source: input.Source("schema"), SchemaString: string(schemaBytes), - }, AllowUnprefixedObjectType(), ExistingNames(existingNames.AsSlice())) + }, + AllowUnprefixedObjectType(), + ExistingNames(importContext.names.AsSlice()), + SourceFolder(importContext.sourceFolder), + ) if err != nil { return nil, err } diff --git a/pkg/composableschemadsl/compiler/importer_test.go b/pkg/composableschemadsl/compiler/importer_test.go index 7e54a529f7..5292d061d3 100644 --- a/pkg/composableschemadsl/compiler/importer_test.go +++ b/pkg/composableschemadsl/compiler/importer_test.go @@ -3,18 +3,19 @@ package compiler_test import ( "fmt" "os" + "path" "testing" - "github.com/stretchr/testify/require" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" - "github.com/authzed/spicedb/pkg/composableschemadsl/input" "github.com/authzed/spicedb/pkg/composableschemadsl/compiler" "github.com/authzed/spicedb/pkg/composableschemadsl/generator" + "github.com/authzed/spicedb/pkg/composableschemadsl/input" ) type importerTest struct { - name string + name string folder string } @@ -27,6 +28,10 @@ func (it *importerTest) input() string { return string(b) } +func (it *importerTest) relativePath() string { + return fmt.Sprintf("importer-test/%s", it.folder) +} + func (it *importerTest) expected() string { b, err := os.ReadFile(fmt.Sprintf("importer-test/%s/expected.zed", it.folder)) if err != nil { @@ -44,6 +49,9 @@ func (it *importerTest) writeExpected(schema string) { } func TestImporter(t *testing.T) { + workingDir, err := os.Getwd() + require.NoError(t, err) + importerTests := []importerTest{ {"simple local import", "simple-local"}, {"simple local import", "simple-local-with-hop"}, @@ -55,12 +63,15 @@ func TestImporter(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() + sourceFolder := path.Join(workingDir, test.relativePath()) + inputSchema := test.input() compiled, err := compiler.Compile(compiler.InputSchema{ Source: input.Source("schema"), SchemaString: inputSchema, - }, compiler.AllowUnprefixedObjectType()) + }, compiler.AllowUnprefixedObjectType(), + compiler.SourceFolder(sourceFolder)) require.NoError(t, err) generated, _, err := generator.GenerateSchema(compiled.OrderedDefinitions) @@ -76,5 +87,4 @@ func TestImporter(t *testing.T) { } }) } - } diff --git a/pkg/composableschemadsl/compiler/translator.go b/pkg/composableschemadsl/compiler/translator.go index 6a216dea95..927c963e17 100644 --- a/pkg/composableschemadsl/compiler/translator.go +++ b/pkg/composableschemadsl/compiler/translator.go @@ -24,6 +24,7 @@ type translationContext struct { schemaString string skipValidate bool existingNames []string + sourceFolder string } func (tctx translationContext) prefixedPath(definitionName string) (string, error) { @@ -690,7 +691,7 @@ func addWithCaveats(tctx translationContext, typeRefNode *dslNode, ref *core.All return nil } -func translateImport(_ translationContext, importNode *dslNode, names *mapz.Set[string]) (*CompiledSchema, error) { +func translateImport(tctx translationContext, importNode *dslNode, names *mapz.Set[string]) (*CompiledSchema, error) { // NOTE: this function currently just grabs everything that's in the target file. // TODO: only grab the requested definitions // TODO: import cycle tracking @@ -706,5 +707,9 @@ func translateImport(_ translationContext, importNode *dslNode, names *mapz.Set[ pathSegments = append(pathSegments, segment) } - return ImportFile(pathSegments, names) + return ImportFile(&ImportContext{ + names: names, + pathSegments: pathSegments, + sourceFolder: tctx.sourceFolder, + }) }