diff --git a/serve/method.go b/serve/method.go index 068087c..306d407 100644 --- a/serve/method.go +++ b/serve/method.go @@ -7,7 +7,6 @@ import ( "io" "io/fs" - "github.com/google/go-jsonnet" statuspb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -21,20 +20,17 @@ type method struct { desc protoreflect.MethodDescriptor filename string fs fs.FS - makeVM MakeVM + eval Evaluator } -func newMethod(md protoreflect.MethodDescriptor, fs fs.FS, makeVM MakeVM) method { +func newMethod(md protoreflect.MethodDescriptor, fs fs.FS, eval Evaluator) method { pkg, svc := md.ParentFile().Package(), md.Parent().Name() filename := fmt.Sprintf("%s.%s.%s.jsonnet", pkg, svc, md.Name()) - if makeVM == nil { - makeVM = jsonnet.MakeVM - } return method{ desc: md, filename: filename, fs: fs, - makeVM: makeVM, + eval: eval, } } @@ -66,7 +62,7 @@ func (m method) unaryClientCall(ss grpc.ServerStream) error { return err } - return m.evalJsonnet(input, ss) + return m.evaluate(input, ss) } func (m method) streamingClientCall(ss grpc.ServerStream) error { @@ -88,7 +84,7 @@ func (m method) streamingClientCall(ss grpc.ServerStream) error { return err } - return m.evalJsonnet(input, ss) + return m.evaluate(input, ss) } func (m method) streamingBidiCall(ss grpc.ServerStream) error { @@ -108,21 +104,15 @@ func (m method) streamingBidiCall(ss grpc.ServerStream) error { if err != nil { return err } - if err := m.evalJsonnet(input, ss); err != nil { + if err := m.evaluate(input, ss); err != nil { return err } } return nil } -func (m method) evalJsonnet(input string, ss grpc.ServerStream) error { - vm := m.makeVM() - vm.TLACode("input", input) - b, err := fs.ReadFile(m.fs, m.filename) - if err != nil { - return err - } - output, err := vm.EvaluateAnonymousSnippet(m.filename, string(b)) +func (m method) evaluate(input string, ss grpc.ServerStream) error { + output, err := m.eval(string(m.desc.FullName()), input, m.fs) if err != nil { return err } diff --git a/serve/server.go b/serve/server.go index 0662944..e78d5ab 100644 --- a/serve/server.go +++ b/serve/server.go @@ -42,12 +42,31 @@ func WithLogger(logger Logger) Option { } func WithVM(makeVM MakeVM) Option { + return WithEvaluator(JsonnetEvaluator(makeVM)) +} + +type Evaluator func(method, input string, vfs fs.FS) (output string, err error) + +func WithEvaluator(evaluator Evaluator) Option { return func(s *Server) error { - s.makeVM = makeVM + s.eval = evaluator return nil } } +func JsonnetEvaluator(makeVM MakeVM) Evaluator { + return func(method, input string, vfs fs.FS) (output string, err error) { + vm := makeVM() + vm.TLACode("input", input) + filename := method + ".jsonnet" + b, err := fs.ReadFile(vfs, filename) + if err != nil { + return "", err + } + return vm.EvaluateAnonymousSnippet(filename, string(b)) + } +} + type Server struct { methodDir string protoSet string @@ -57,7 +76,7 @@ type Server struct { gs *grpc.Server files *protoregistry.Files fs fs.FS - makeVM MakeVM + eval Evaluator } var errUnknownHandler = errors.New("Unknown handler") @@ -69,6 +88,7 @@ func NewServer(methodDir, protoSet string, options ...Option) (*Server, error) { protoSet: protoSet, log: NewLogger(os.Stderr, LogLevelError), } + options = append([]Option{WithVM(jsonnet.MakeVM)}, options...) for _, opt := range options { if err := opt(s); err != nil { return nil, err @@ -133,7 +153,7 @@ func (s *Server) loadMethods() error { for i := 0; i < sds.Len(); i++ { mds := sds.Get(i).Methods() for j := 0; j < mds.Len(); j++ { - m := newMethod(mds.Get(j), methodFS, s.makeVM) + m := newMethod(mds.Get(j), methodFS, s.eval) s.methods[m.fullMethod()] = m } }