Skip to content

Commit

Permalink
Plumb through a notion of the source folder
Browse files Browse the repository at this point in the history
  • Loading branch information
tstirrat15 committed Nov 5, 2024
1 parent eaab825 commit 03250fd
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 10 deletions.
8 changes: 8 additions & 0 deletions pkg/composableschemadsl/compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ type config struct {
skipValidation bool
objectTypePrefix *string
existingNames []string
// In an import context, this is the folder containing
// the importing schema (as opposed to imported schemas)
sourceFolder string
}

func SkipValidation() Option { return func(cfg *config) { cfg.skipValidation = true } }
Expand All @@ -73,6 +76,10 @@ func ExistingNames(names []string) Option {
return func(cfg *config) { cfg.existingNames = names }
}

func SourceFolder(sourceFolder string) Option {
return func(cfg *config) { cfg.sourceFolder = sourceFolder }
}

type Option func(*config)

type ObjectPrefixOption func(*config)
Expand Down Expand Up @@ -100,6 +107,7 @@ func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*Co
schemaString: schema.SchemaString,
skipValidate: cfg.skipValidation,
existingNames: cfg.existingNames,
sourceFolder: cfg.sourceFolder,
}, root)
if err != nil {
var errorWithNode errorWithNode
Expand Down
17 changes: 14 additions & 3 deletions pkg/composableschemadsl/compiler/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,17 @@ import (
"github.com/rs/zerolog/log"
)

type ImportContext struct {
pathSegments []string
sourceFolder string
names *mapz.Set[string]
}

const SchemaFileSuffix = ".zed"

func ImportFile(pathSegments []string, existingNames *mapz.Set[string]) (*CompiledSchema, error) {
filepath := constructFilePath(pathSegments)
func ImportFile(importContext *ImportContext) (*CompiledSchema, error) {
relativeFilepath := constructFilePath(importContext.pathSegments)
filepath := path.Join(importContext.sourceFolder, relativeFilepath)

var schemaBytes []byte
schemaBytes, err := os.ReadFile(filepath)
Expand All @@ -27,7 +34,11 @@ func ImportFile(pathSegments []string, existingNames *mapz.Set[string]) (*Compil
// TODO: should this point to the schema file? What is this for?
Source: input.Source("schema"),
SchemaString: string(schemaBytes),
}, AllowUnprefixedObjectType(), ExistingNames(existingNames.AsSlice()))
},
AllowUnprefixedObjectType(),
ExistingNames(importContext.names.AsSlice()),
SourceFolder(importContext.sourceFolder),
)
if err != nil {
return nil, err
}
Expand Down
20 changes: 15 additions & 5 deletions pkg/composableschemadsl/compiler/importer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@ package compiler_test
import (
"fmt"
"os"
"path"
"testing"

"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/authzed/spicedb/pkg/composableschemadsl/input"
"github.com/authzed/spicedb/pkg/composableschemadsl/compiler"
"github.com/authzed/spicedb/pkg/composableschemadsl/generator"
"github.com/authzed/spicedb/pkg/composableschemadsl/input"
)

type importerTest struct {
name string
name string
folder string
}

Expand All @@ -27,6 +28,10 @@ func (it *importerTest) input() string {
return string(b)
}

func (it *importerTest) relativePath() string {
return fmt.Sprintf("importer-test/%s", it.folder)
}

func (it *importerTest) expected() string {
b, err := os.ReadFile(fmt.Sprintf("importer-test/%s/expected.zed", it.folder))
if err != nil {
Expand All @@ -44,6 +49,9 @@ func (it *importerTest) writeExpected(schema string) {
}

func TestImporter(t *testing.T) {
workingDir, err := os.Getwd()
require.NoError(t, err)

importerTests := []importerTest{
{"simple local import", "simple-local"},
{"simple local import", "simple-local-with-hop"},
Expand All @@ -55,12 +63,15 @@ func TestImporter(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()

sourceFolder := path.Join(workingDir, test.relativePath())

inputSchema := test.input()

compiled, err := compiler.Compile(compiler.InputSchema{
Source: input.Source("schema"),
SchemaString: inputSchema,
}, compiler.AllowUnprefixedObjectType())
}, compiler.AllowUnprefixedObjectType(),
compiler.SourceFolder(sourceFolder))
require.NoError(t, err)

generated, _, err := generator.GenerateSchema(compiled.OrderedDefinitions)
Expand All @@ -76,5 +87,4 @@ func TestImporter(t *testing.T) {
}
})
}

}
9 changes: 7 additions & 2 deletions pkg/composableschemadsl/compiler/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type translationContext struct {
schemaString string
skipValidate bool
existingNames []string
sourceFolder string
}

func (tctx translationContext) prefixedPath(definitionName string) (string, error) {
Expand Down Expand Up @@ -690,7 +691,7 @@ func addWithCaveats(tctx translationContext, typeRefNode *dslNode, ref *core.All
return nil
}

func translateImport(_ translationContext, importNode *dslNode, names *mapz.Set[string]) (*CompiledSchema, error) {
func translateImport(tctx translationContext, importNode *dslNode, names *mapz.Set[string]) (*CompiledSchema, error) {
// NOTE: this function currently just grabs everything that's in the target file.
// TODO: only grab the requested definitions
// TODO: import cycle tracking
Expand All @@ -706,5 +707,9 @@ func translateImport(_ translationContext, importNode *dslNode, names *mapz.Set[
pathSegments = append(pathSegments, segment)
}

return ImportFile(pathSegments, names)
return ImportFile(&ImportContext{
names: names,
pathSegments: pathSegments,
sourceFolder: tctx.sourceFolder,
})
}

0 comments on commit 03250fd

Please sign in to comment.