From 2286f4d3d5405c4fba6cb26cbf47384c7dacfd9f Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 14:52:03 -0800 Subject: [PATCH 01/17] Config refactoring --- actions/fetch.go | 30 ++++++++++++++-------- actions/fetch_test.go | 3 ++- actions/fv.go | 21 +++++++++++---- actions/validate.go | 6 +++-- config/config.go | 20 --------------- internal/testfinder/testfinder.go | 3 +-- jobs/jobs.go | 12 +++------ model/finders.go | 38 ++++++++++++++++++++++++++-- server/gql/agency_resolver_test.go | 2 +- server/gql/loaders.go | 3 +-- server/gql/mutation_resolver.go | 10 ++++---- server/gql/mutation_resolver_test.go | 6 ++--- server/gql/resolver.go | 3 +-- server/gql/resolver_test.go | 2 +- server/gql/server.go | 17 ++++++------- server/rest/rest.go | 5 ++-- server/rest/rest_test.go | 5 ++-- server/server_cmd.go | 23 +++++++++-------- workers/fetch_enqueue_worker.go | 4 +-- workers/gbfs_fetch_worker.go | 6 ++--- workers/gbfs_fetch_worker_test.go | 4 +-- workers/rt_fetch_worker.go | 2 +- workers/static_fetch_worker.go | 2 +- workers/workers.go | 3 +-- 24 files changed, 128 insertions(+), 102 deletions(-) delete mode 100644 config/config.go diff --git a/actions/fetch.go b/actions/fetch.go index 25406e25..e3d539b5 100644 --- a/actions/fetch.go +++ b/actions/fetch.go @@ -18,16 +18,19 @@ import ( "github.com/interline-io/transitland-lib/tldb" "github.com/interline-io/transitland-mw/auth/authn" "github.com/interline-io/transitland-mw/auth/authz" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/dbutil" "github.com/interline-io/transitland-server/model" "github.com/jmoiron/sqlx" "google.golang.org/protobuf/proto" ) -func StaticFetch(ctx context.Context, cfg config.Config, dbf model.Finder, feedId string, feedSrc io.Reader, feedUrl string, checker model.Checker) (*model.FeedVersionFetchResult, error) { +func StaticFetch(ctx context.Context, feedId string, feedSrc io.Reader, feedUrl string) (*model.FeedVersionFetchResult, error) { + frs := model.ForContext(ctx) + cfg := frs.Config + dbf := frs.Finder + urlType := "static_current" - feed, err := fetchCheckFeed(ctx, dbf, checker, feedId, urlType, feedUrl) + feed, err := fetchCheckFeed(ctx, feedId, urlType, feedUrl) if err != nil { return nil, err } @@ -86,8 +89,10 @@ func StaticFetch(ctx context.Context, cfg config.Config, dbf model.Finder, feedI return &mr, nil } -func RTFetch(ctx context.Context, cfg config.Config, dbf model.Finder, rtf model.RTFinder, target string, feedId string, feedUrl string, urlType string, checker model.Checker) error { - feed, err := fetchCheckFeed(ctx, dbf, checker, feedId, urlType, feedUrl) +func RTFetch(ctx context.Context, target string, feedId string, feedUrl string, urlType string) error { + frs := model.ForContext(ctx) + + feed, err := fetchCheckFeed(ctx, feedId, urlType, feedUrl) if err != nil { return err } @@ -100,15 +105,15 @@ func RTFetch(ctx context.Context, cfg config.Config, dbf model.Finder, rtf model FeedID: feed.ID, URLType: urlType, FeedURL: feedUrl, - Storage: cfg.RTStorage, - Secrets: cfg.Secrets, + Storage: frs.Config.RTStorage, + Secrets: frs.Config.Secrets, FetchedAt: time.Now().In(time.UTC), } // Make request var rtMsg *pb.FeedMessage var fetchErr error - if err := tldb.NewPostgresAdapterFromDBX(dbf.DBX()).Tx(func(atx tldb.Adapter) error { + if err := tldb.NewPostgresAdapterFromDBX(frs.Finder.DBX()).Tx(func(atx tldb.Adapter) error { m, fr, err := fetch.RTFetch(atx, fetchOpts) if err != nil { return err @@ -129,7 +134,7 @@ func RTFetch(ctx context.Context, cfg config.Config, dbf model.Finder, rtf model return errors.New("invalid rt data") } key := fmt.Sprintf("rtdata:%s:%s", target, urlType) - return rtf.AddData(key, rtdata) + return frs.RTFinder.AddData(key, rtdata) } type CheckFetchWaitResult struct { @@ -250,9 +255,12 @@ func chunkBy[T any](items []T, chunkSize int) (chunks [][]T) { return append(chunks, items) } -func fetchCheckFeed(ctx context.Context, dbf model.Finder, checker model.Checker, feedId string, urlType string, url string) (*model.Feed, error) { +func fetchCheckFeed(ctx context.Context, feedId string, urlType string, url string) (*model.Feed, error) { + frs := model.ForContext(ctx) + checker := frs.Checker + // Check feed exists - feeds, err := dbf.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &feedId}) + feeds, err := frs.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &feedId}) if err != nil { return nil, err } diff --git a/actions/fetch_test.go b/actions/fetch_test.go index c1b0d98b..fe8a43b0 100644 --- a/actions/fetch_test.go +++ b/actions/fetch_test.go @@ -111,8 +111,9 @@ func TestStaticFetchWorker(t *testing.T) { // Setup job feedUrl := ts.URL + "/" + tc.serveFile testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { + ctx := model.WithFinders(context.Background(), te) // Run job - if result, err := StaticFetch(context.Background(), te.Config, te.Finder, tc.feedId, nil, feedUrl, nil); err != nil && !tc.expectError { + if result, err := StaticFetch(ctx, tc.feedId, nil, feedUrl); err != nil && !tc.expectError { _ = result t.Fatal("unexpected error", err) } else if err == nil && tc.expectError { diff --git a/actions/fv.go b/actions/fv.go index a5956d5a..e2200666 100644 --- a/actions/fv.go +++ b/actions/fv.go @@ -10,11 +10,14 @@ import ( "github.com/interline-io/transitland-lib/tl/tt" "github.com/interline-io/transitland-lib/tldb" "github.com/interline-io/transitland-mw/auth/authz" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/model" ) -func FeedVersionImport(ctx context.Context, cfg config.Config, dbf model.Finder, checker model.Checker, fvid int) (*model.FeedVersionImportResult, error) { +func FeedVersionImport(ctx context.Context, fvid int) (*model.FeedVersionImportResult, error) { + frs := model.ForContext(ctx) + checker := frs.Checker + cfg := frs.Config + dbf := frs.Finder if checker == nil { return nil, authz.ErrUnauthorized } @@ -38,7 +41,10 @@ func FeedVersionImport(ctx context.Context, cfg config.Config, dbf model.Finder, return &mr, nil } -func FeedVersionUnimport(ctx context.Context, cfg config.Config, dbf model.Finder, checker model.Checker, fvid int) (*model.FeedVersionUnimportResult, error) { +func FeedVersionUnimport(ctx context.Context, fvid int) (*model.FeedVersionUnimportResult, error) { + frs := model.ForContext(ctx) + checker := frs.Checker + dbf := frs.Finder if checker == nil { return nil, authz.ErrUnauthorized } @@ -59,7 +65,10 @@ func FeedVersionUnimport(ctx context.Context, cfg config.Config, dbf model.Finde return &mr, nil } -func FeedVersionUpdate(ctx context.Context, cfg config.Config, dbf model.Finder, checker model.Checker, fvid int, values model.FeedVersionSetInput) error { +func FeedVersionUpdate(ctx context.Context, fvid int, values model.FeedVersionSetInput) error { + frs := model.ForContext(ctx) + checker := frs.Checker + dbf := frs.Finder if checker == nil { return authz.ErrUnauthorized } @@ -93,7 +102,9 @@ func FeedVersionUpdate(ctx context.Context, cfg config.Config, dbf model.Finder, return nil } -func FeedVersionDelete(ctx context.Context, cfg config.Config, dbf model.Finder, checker model.Checker, fvid int) (*model.FeedVersionDeleteResult, error) { +func FeedVersionDelete(ctx context.Context, fvid int) (*model.FeedVersionDeleteResult, error) { + frs := model.ForContext(ctx) + checker := frs.Checker if checker == nil { return nil, authz.ErrUnauthorized } diff --git a/actions/validate.go b/actions/validate.go index bdf8826f..80c5bbad 100644 --- a/actions/validate.go +++ b/actions/validate.go @@ -12,7 +12,6 @@ import ( "github.com/interline-io/transitland-lib/tl/tt" "github.com/interline-io/transitland-lib/tlcsv" "github.com/interline-io/transitland-lib/validator" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/model" ) @@ -21,7 +20,10 @@ type hasGeometries interface { } // ValidateUpload takes a file Reader and produces a validation package containing errors, warnings, file infos, service levels, etc. -func ValidateUpload(ctx context.Context, cfg config.Config, src io.Reader, feedURL *string, rturls []string) (*model.ValidationResult, error) { +func ValidateUpload(ctx context.Context, src io.Reader, feedURL *string, rturls []string) (*model.ValidationResult, error) { + frs := model.ForContext(ctx) + cfg := frs.Config + // Check inputs rturlsok := []string{} for _, rturl := range rturls { diff --git a/config/config.go b/config/config.go deleted file mode 100644 index b5b4bdb8..00000000 --- a/config/config.go +++ /dev/null @@ -1,20 +0,0 @@ -package config - -import ( - "github.com/interline-io/transitland-lib/tl" - "github.com/interline-io/transitland-server/internal/clock" -) - -// Config is in a separate package to avoid import cycles. - -type Config struct { - Storage string - RTStorage string - ValidateLargeFiles bool - DisableImage bool - RestPrefix string - DBURL string - RedisURL string - Clock clock.Clock - Secrets []tl.Secret -} diff --git a/internal/testfinder/testfinder.go b/internal/testfinder/testfinder.go index 5107052b..fd788f0e 100644 --- a/internal/testfinder/testfinder.go +++ b/internal/testfinder/testfinder.go @@ -10,7 +10,6 @@ import ( "github.com/interline-io/transitland-lib/rt" "github.com/interline-io/transitland-mw/auth/authz" "github.com/interline-io/transitland-mw/auth/azcheck" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/finders/dbfinder" "github.com/interline-io/transitland-server/finders/gbfsfinder" "github.com/interline-io/transitland-server/finders/rtfinder" @@ -34,7 +33,7 @@ func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Finders if opts.Clock == nil { opts.Clock = &clock.Real{} } - cfg := config.Config{ + cfg := model.Config{ Clock: opts.Clock, Storage: t.TempDir(), RTStorage: t.TempDir(), diff --git a/jobs/jobs.go b/jobs/jobs.go index 1c0ba54c..de18834c 100644 --- a/jobs/jobs.go +++ b/jobs/jobs.go @@ -7,7 +7,6 @@ import ( "encoding/json" "github.com/interline-io/transitland-lib/tl" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/model" "github.com/rs/zerolog" ) @@ -45,13 +44,10 @@ func (job *Job) HexKey() (string, error) { // JobOptions is configuration passed to worker. type JobOptions struct { - Finder model.Finder - RTFinder model.RTFinder - GbfsFinder model.GbfsFinder - JobQueue JobQueue - Logger zerolog.Logger - Config config.Config - Secrets []tl.Secret + Finders model.Finders + JobQueue JobQueue + Logger zerolog.Logger + Secrets []tl.Secret } // GetWorker returns a new worker for this job type diff --git a/model/finders.go b/model/finders.go index 280e71ab..7f62e96f 100644 --- a/model/finders.go +++ b/model/finders.go @@ -5,20 +5,54 @@ import ( "time" "github.com/interline-io/transitland-lib/rt/pb" + "github.com/interline-io/transitland-lib/tl" "github.com/interline-io/transitland-lib/tl/tt" "github.com/interline-io/transitland-mw/auth/authz" - "github.com/interline-io/transitland-server/config" + "github.com/interline-io/transitland-server/internal/clock" "github.com/interline-io/transitland-server/internal/gbfs" + "github.com/rs/zerolog" "github.com/jmoiron/sqlx" ) +var finderCtxKey = &contextKey{"finderConfig"} + +type contextKey struct { + name string +} + +func ForContext(ctx context.Context) Finders { + raw, ok := ctx.Value(finderCtxKey).(Finders) + if !ok { + return Finders{} + } + return raw +} + +func WithFinders(ctx context.Context, fs Finders) context.Context { + r := context.WithValue(ctx, finderCtxKey, fs) + return r +} + +type Config struct { + Storage string + RTStorage string + ValidateLargeFiles bool + DisableImage bool + RestPrefix string + DBURL string + RedisURL string + Clock clock.Clock + Secrets []tl.Secret +} + type Finders struct { - Config config.Config + Config Config Finder Finder RTFinder RTFinder GbfsFinder GbfsFinder Checker Checker + Logger zerolog.Logger } // Finder provides all necessary database methods diff --git a/server/gql/agency_resolver_test.go b/server/gql/agency_resolver_test.go index c911da33..36018c4a 100644 --- a/server/gql/agency_resolver_test.go +++ b/server/gql/agency_resolver_test.go @@ -315,7 +315,7 @@ func TestAgencyResolver_Authz(t *testing.T) { FGAModelTuples: fgaTestTuples, } te := testfinder.FindersWithOptions(t, teOpts) - srv, _ := NewServer(te.Config, te.Finder, te.RTFinder, te.GbfsFinder, te.Checker) + srv, _ := NewServer(te) testcases := []testcase{ { name: "basic", diff --git a/server/gql/loaders.go b/server/gql/loaders.go index 2a6013ec..2fa46362 100644 --- a/server/gql/loaders.go +++ b/server/gql/loaders.go @@ -9,7 +9,6 @@ import ( dataloader "github.com/graph-gophers/dataloader/v7" "github.com/interline-io/log" "github.com/interline-io/transitland-lib/tl/tt" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/model" ) @@ -136,7 +135,7 @@ func NewLoaders(dbf model.Finder) *Loaders { return loaders } -func loaderMiddleware(cfg config.Config, finder model.Finder, next http.Handler) http.Handler { +func loaderMiddleware(cfg model.Config, finder model.Finder, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // This is per request scoped loaders/cache // Is this OK to use as a long term cache? diff --git a/server/gql/mutation_resolver.go b/server/gql/mutation_resolver.go index 35f50f1a..ffe0f222 100644 --- a/server/gql/mutation_resolver.go +++ b/server/gql/mutation_resolver.go @@ -20,7 +20,7 @@ func (r *mutationResolver) ValidateGtfs(ctx context.Context, file *graphql.Uploa if file != nil { src = file.File } - return actions.ValidateUpload(ctx, r.cfg, src, url, rturls) + return actions.ValidateUpload(ctx, src, url, rturls) } func (r *mutationResolver) FeedVersionFetch(ctx context.Context, file *graphql.Upload, url *string, feedId string) (*model.FeedVersionFetchResult, error) { @@ -32,19 +32,19 @@ func (r *mutationResolver) FeedVersionFetch(ctx context.Context, file *graphql.U if url != nil { feedUrl = *url } - return actions.StaticFetch(ctx, r.cfg, r.finder, feedId, feedSrc, feedUrl, r.authzChecker) + return actions.StaticFetch(ctx, feedId, feedSrc, feedUrl) } func (r *mutationResolver) FeedVersionImport(ctx context.Context, fvid int) (*model.FeedVersionImportResult, error) { - return actions.FeedVersionImport(ctx, r.cfg, r.finder, r.authzChecker, fvid) + return actions.FeedVersionImport(ctx, fvid) } func (r *mutationResolver) FeedVersionUnimport(ctx context.Context, fvid int) (*model.FeedVersionUnimportResult, error) { - return actions.FeedVersionUnimport(ctx, r.cfg, r.finder, r.authzChecker, fvid) + return actions.FeedVersionUnimport(ctx, fvid) } func (r *mutationResolver) FeedVersionUpdate(ctx context.Context, fvid int, values model.FeedVersionSetInput) (*model.FeedVersion, error) { - err := actions.FeedVersionUpdate(ctx, r.cfg, r.finder, r.authzChecker, fvid, values) + err := actions.FeedVersionUpdate(ctx, fvid, values) return nil, err } diff --git a/server/gql/mutation_resolver_test.go b/server/gql/mutation_resolver_test.go index eb0bce0d..c281b6a2 100644 --- a/server/gql/mutation_resolver_test.go +++ b/server/gql/mutation_resolver_test.go @@ -27,7 +27,7 @@ func TestFeedVersionFetchResolver(t *testing.T) { })) t.Run("found sha1", func(t *testing.T) { testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { - srv, _ := NewServer(te.Config, te.Finder, nil, nil, nil) + srv, _ := NewServer(te) srv = ancheck.AdminDefaultMiddleware("test")(srv) // Run all requests as admin // Run all requests as admin c := client.New(srv) @@ -165,7 +165,7 @@ func TestValidateGtfsResolver(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { - srv, _ := NewServer(te.Config, te.Finder, nil, nil, nil) + srv, _ := NewServer(te) srv = ancheck.UserDefaultMiddleware("test")(srv) // Run all requests as user c := client.New(srv) queryTestcase(t, c, tc) @@ -174,7 +174,7 @@ func TestValidateGtfsResolver(t *testing.T) { } t.Run("requires user access", func(t *testing.T) { testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { - srv, _ := NewServer(te.Config, te.Finder, nil, nil, nil) // all requests run as anonymous context by default + srv, _ := NewServer(te) // all requests run as anonymous context by default c := client.New(srv) resp := make(map[string]interface{}) err := c.Post(`mutation($url:String!) {validate_gtfs(url:$url){success}}`, &resp, client.Var("url", ts200.URL)) diff --git a/server/gql/resolver.go b/server/gql/resolver.go index 3f9b0951..153bdc2f 100644 --- a/server/gql/resolver.go +++ b/server/gql/resolver.go @@ -4,7 +4,6 @@ import ( "context" "strconv" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/generated/gqlout" "github.com/interline-io/transitland-server/internal/xy" "github.com/interline-io/transitland-server/model" @@ -54,7 +53,7 @@ func atoi(v string) int { // Resolver . type Resolver struct { - cfg config.Config + cfg model.Config rtfinder model.RTFinder finder model.Finder gbfsFinder model.GbfsFinder diff --git a/server/gql/resolver_test.go b/server/gql/resolver_test.go index a5e7be86..b1ffc8bc 100644 --- a/server/gql/resolver_test.go +++ b/server/gql/resolver_test.go @@ -56,7 +56,7 @@ func newTestClient(t testing.TB) (*client.Client, model.Finders) { func newTestClientWithClock(t testing.TB, cl clock.Clock, rtfiles []testfinder.RTJsonFile) (*client.Client, model.Finders) { te := testfinder.Finders(t, cl, rtfiles) - srv, _ := NewServer(te.Config, te.Finder, te.RTFinder, te.GbfsFinder, te.Checker) + srv, _ := NewServer(te) srvMiddleware := ancheck.NewUserDefaultMiddleware(func() authn.User { return authn.NewCtxUser("testuser", "", "").WithRoles("testrole") }) return client.New(srvMiddleware(srv)), te } diff --git a/server/gql/server.go b/server/gql/server.go index 1a05dec3..9517ea0d 100644 --- a/server/gql/server.go +++ b/server/gql/server.go @@ -8,19 +8,18 @@ import ( "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/handler" "github.com/interline-io/transitland-mw/auth/authn" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/generated/gqlout" "github.com/interline-io/transitland-server/model" ) -func NewServer(cfg config.Config, dbfinder model.Finder, rtfinder model.RTFinder, gbfsFinder model.GbfsFinder, checker model.Checker) (http.Handler, error) { +func NewServer(te model.Finders) (http.Handler, error) { c := gqlout.Config{Resolvers: &Resolver{ - cfg: cfg, - finder: dbfinder, - rtfinder: rtfinder, - gbfsFinder: gbfsFinder, - fvslCache: newFvslCache(dbfinder), - authzChecker: checker, + cfg: te.Config, + finder: te.Finder, + rtfinder: te.RTFinder, + gbfsFinder: te.GbfsFinder, + fvslCache: newFvslCache(te.Finder), + authzChecker: te.Checker, }} c.Directives.HasRole = func(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (interface{}, error) { user := authn.ForContext(ctx) @@ -31,6 +30,6 @@ func NewServer(cfg config.Config, dbfinder model.Finder, rtfinder model.RTFinder } // Setup server srv := handler.NewDefaultServer(gqlout.NewExecutableSchema(c)) - graphqlServer := loaderMiddleware(cfg, dbfinder, srv) + graphqlServer := loaderMiddleware(te.Config, te.Finder, srv) return graphqlServer, nil } diff --git a/server/rest/rest.go b/server/rest/rest.go index 8838ae2e..514d8b21 100644 --- a/server/rest/rest.go +++ b/server/rest/rest.go @@ -17,7 +17,6 @@ import ( "github.com/interline-io/log" "github.com/interline-io/transitland-mw/auth/ancheck" "github.com/interline-io/transitland-mw/meters" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/util" "github.com/interline-io/transitland-server/model" ) @@ -33,12 +32,12 @@ const MAXRADIUS = 100 * 1000.0 // restConfig holds the base config and the graphql handler type restConfig struct { - config.Config + model.Config srv http.Handler } // NewServer . -func NewServer(cfg config.Config, srv http.Handler) (http.Handler, error) { +func NewServer(cfg model.Config, srv http.Handler) (http.Handler, error) { restcfg := restConfig{Config: cfg, srv: srv} r := chi.NewRouter() diff --git a/server/rest/rest_test.go b/server/rest/rest_test.go index 5c47786d..82fdce64 100644 --- a/server/rest/rest_test.go +++ b/server/rest/rest_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/clock" "github.com/interline-io/transitland-server/internal/testfinder" "github.com/interline-io/transitland-server/internal/testutil" @@ -36,14 +35,14 @@ func testRestConfig(t testing.TB) (http.Handler, model.Finders) { t.Fatal(err) } te := testfinder.Finders(t, &clock.Mock{T: when}, testfinder.DefaultRTJson()) - srv, err := gql.NewServer(te.Config, te.Finder, te.RTFinder, te.GbfsFinder, te.Checker) + srv, err := gql.NewServer(te) if err != nil { panic(err) } return srv, te } -func testRestServer(t testing.TB, cfg config.Config, srv http.Handler) (http.Handler, error) { +func testRestServer(t testing.TB, cfg model.Config, srv http.Handler) (http.Handler, error) { return NewServer(cfg, srv) } diff --git a/server/server_cmd.go b/server/server_cmd.go index f8911b78..bada2a45 100644 --- a/server/server_cmd.go +++ b/server/server_cmd.go @@ -28,7 +28,6 @@ import ( "github.com/interline-io/transitland-mw/lmw" "github.com/interline-io/transitland-mw/meters" "github.com/interline-io/transitland-mw/metrics" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/finders/dbfinder" "github.com/interline-io/transitland-server/finders/gbfsfinder" "github.com/interline-io/transitland-server/finders/rtfinder" @@ -63,7 +62,7 @@ type Command struct { metricsConfig metrics.Config AuthConfig ancheck.AuthConfig CheckerConfig azcheck.CheckerConfig - config.Config + model.Config } func (cmd *Command) Parse(args []string) error { @@ -271,7 +270,14 @@ func (cmd *Command) Run() error { } // GraphQL API - graphqlServer, err := gql.NewServer(cfg, dbFinder, rtFinder, gbfsFinder, checker) + te := model.Finders{ + Config: cfg, + Finder: dbFinder, + RTFinder: rtFinder, + GbfsFinder: gbfsFinder, + Checker: checker, + } + graphqlServer, err := gql.NewServer(te) if err != nil { return err } @@ -321,12 +327,9 @@ func (cmd *Command) Run() error { // Start workers/api jobWorkers := 8 jobOptions := jobs.JobOptions{ - Logger: log.Logger, - JobQueue: jobQueue, - Finder: dbFinder, - RTFinder: rtFinder, - GbfsFinder: gbfsFinder, - Config: cfg, + Finders: te, + Logger: log.Logger, + JobQueue: jobQueue, } // Add metrics // jobQueue.Use(metrics.NewJobMiddleware("", metricProvider.NewJobMetric("default"))) @@ -340,7 +343,7 @@ func (cmd *Command) Run() error { } if cmd.EnableJobsApi { log.Infof("Enabling job api") - jobServer, err := workers.NewServer(cfg, "", jobWorkers, jobOptions) + jobServer, err := workers.NewServer("", jobWorkers, jobOptions) if err != nil { return err } diff --git a/workers/fetch_enqueue_worker.go b/workers/fetch_enqueue_worker.go index 75e96686..3f28788f 100644 --- a/workers/fetch_enqueue_worker.go +++ b/workers/fetch_enqueue_worker.go @@ -17,9 +17,9 @@ type FetchEnqueueWorker struct { func (w *FetchEnqueueWorker) Run(ctx context.Context, job jobs.Job) error { opts := job.Opts - db := opts.Finder.DBX() + db := opts.Finders.Finder.DBX() now := time.Now().In(time.UTC) - feeds, err := job.Opts.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{}) + feeds, err := job.Opts.Finders.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{}) if err != nil { return err } diff --git a/workers/gbfs_fetch_worker.go b/workers/gbfs_fetch_worker.go index afd82dd8..e5e2d08a 100644 --- a/workers/gbfs_fetch_worker.go +++ b/workers/gbfs_fetch_worker.go @@ -20,7 +20,7 @@ type GbfsFetchWorker struct { func (w *GbfsFetchWorker) Run(ctx context.Context, job jobs.Job) error { log := job.Opts.Logger.With().Str("feed_id", w.FeedID).Str("url", w.Url).Logger() - gfeeds, err := job.Opts.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &w.FeedID}) + gfeeds, err := job.Opts.Finders.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &w.FeedID}) if err != nil { log.Error().Err(err).Msg("gbfsfetch worker: error loading source feed") return err @@ -40,7 +40,7 @@ func (w *GbfsFetchWorker) Run(ctx context.Context, job jobs.Job) error { opts.FeedURL = w.Url } feeds, result, err := gbfs.Fetch( - tldb.NewPostgresAdapterFromDBX(job.Opts.Finder.DBX()), + tldb.NewPostgresAdapterFromDBX(job.Opts.Finders.Finder.DBX()), opts, ) if err != nil { @@ -54,7 +54,7 @@ func (w *GbfsFetchWorker) Run(ctx context.Context, job jobs.Job) error { for _, feed := range feeds { if feed.SystemInformation != nil { key := fmt.Sprintf("%s:%s", w.FeedID, feed.SystemInformation.Language.Val) - job.Opts.GbfsFinder.AddData(ctx, key, feed) + job.Opts.Finders.GbfsFinder.AddData(ctx, key, feed) } } log.Info().Msg("gbfs fetch worker: success") diff --git a/workers/gbfs_fetch_worker_test.go b/workers/gbfs_fetch_worker_test.go index 091e967f..0bb75e45 100644 --- a/workers/gbfs_fetch_worker_test.go +++ b/workers/gbfs_fetch_worker_test.go @@ -20,9 +20,7 @@ func TestGbfsFetchWorker(t *testing.T) { testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { job := jobs.Job{} - job.Opts.Finder = te.Finder - job.Opts.RTFinder = te.RTFinder - job.Opts.GbfsFinder = te.GbfsFinder + job.Opts.Finders = te w := GbfsFetchWorker{ Url: ts.URL + "/gbfs.json", FeedID: "test-gbfs", diff --git a/workers/rt_fetch_worker.go b/workers/rt_fetch_worker.go index f2833feb..2e515beb 100644 --- a/workers/rt_fetch_worker.go +++ b/workers/rt_fetch_worker.go @@ -17,7 +17,7 @@ type RTFetchWorker struct { func (w *RTFetchWorker) Run(ctx context.Context, job jobs.Job) error { log := job.Opts.Logger.With().Str("target", w.Target).Str("source_feed_id", w.SourceFeedID).Str("source_type", w.SourceType).Str("url", w.Url).Logger() - err := actions.RTFetch(ctx, job.Opts.Config, job.Opts.Finder, job.Opts.RTFinder, w.Target, w.SourceFeedID, w.Url, w.SourceType, nil) + err := actions.RTFetch(ctx, w.Target, w.SourceFeedID, w.Url, w.SourceType) if err != nil { log.Error().Err(err).Msg("rtfetch worker: request failed") return err diff --git a/workers/static_fetch_worker.go b/workers/static_fetch_worker.go index 08db35f5..b0ba4b9c 100644 --- a/workers/static_fetch_worker.go +++ b/workers/static_fetch_worker.go @@ -16,7 +16,7 @@ type StaticFetchWorker struct { func (w *StaticFetchWorker) Run(ctx context.Context, job jobs.Job) error { log := job.Opts.Logger.With().Str("feed_id", w.FeedID).Str("feed_url", w.FeedUrl).Logger() - if result, err := actions.StaticFetch(ctx, job.Opts.Config, job.Opts.Finder, w.FeedID, nil, w.FeedUrl, nil); err != nil { + if result, err := actions.StaticFetch(ctx, w.FeedID, nil, w.FeedUrl); err != nil { log.Error().Err(err).Msg("staticfetch worker: request failed") return err } else if result.FetchError != nil { diff --git a/workers/workers.go b/workers/workers.go index 055d5644..9f06bf18 100644 --- a/workers/workers.go +++ b/workers/workers.go @@ -7,7 +7,6 @@ import ( "net/http" "github.com/go-chi/chi/v5" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/util" "github.com/interline-io/transitland-server/jobs" ) @@ -44,7 +43,7 @@ func GetWorker(job jobs.Job) (jobs.JobWorker, error) { } // NewServer creates a simple api for submitting and running jobs. -func NewServer(cfg config.Config, queueName string, workers int, jo jobs.JobOptions) (http.Handler, error) { +func NewServer(queueName string, workers int, jo jobs.JobOptions) (http.Handler, error) { r := chi.NewRouter() r.HandleFunc("/add", wrapHandler(addJobRequest, jo)) r.HandleFunc("/run", wrapHandler(runJobRequest, jo)) From 7bc48fe1c953f6a30c1a01c1469d3d1bc3a3577c Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 15:04:23 -0800 Subject: [PATCH 02/17] Context AddFinders mw --- actions/fetch.go | 3 +++ model/finders.go | 10 ++++++++++ server/gql/mutation_resolver_test.go | 1 + 3 files changed, 14 insertions(+) diff --git a/actions/fetch.go b/actions/fetch.go index e3d539b5..4e6eb694 100644 --- a/actions/fetch.go +++ b/actions/fetch.go @@ -257,6 +257,9 @@ func chunkBy[T any](items []T, chunkSize int) (chunks [][]T) { func fetchCheckFeed(ctx context.Context, feedId string, urlType string, url string) (*model.Feed, error) { frs := model.ForContext(ctx) + if frs.Finder == nil { + panic("no finder") + } checker := frs.Checker // Check feed exists diff --git a/model/finders.go b/model/finders.go index 7f62e96f..c0db321b 100644 --- a/model/finders.go +++ b/model/finders.go @@ -2,6 +2,7 @@ package model import ( "context" + "net/http" "time" "github.com/interline-io/transitland-lib/rt/pb" @@ -34,6 +35,15 @@ func WithFinders(ctx context.Context, fs Finders) context.Context { return r } +func AddFinders(te Finders) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(WithFinders(r.Context(), te)) + next.ServeHTTP(w, r) + }) + } +} + type Config struct { Storage string RTStorage string diff --git a/server/gql/mutation_resolver_test.go b/server/gql/mutation_resolver_test.go index c281b6a2..3745bb0e 100644 --- a/server/gql/mutation_resolver_test.go +++ b/server/gql/mutation_resolver_test.go @@ -28,6 +28,7 @@ func TestFeedVersionFetchResolver(t *testing.T) { t.Run("found sha1", func(t *testing.T) { testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { srv, _ := NewServer(te) + srv = model.AddFinders(te)(srv) srv = ancheck.AdminDefaultMiddleware("test")(srv) // Run all requests as admin // Run all requests as admin c := client.New(srv) From f1ee5f0e97b432fee9d303e027bb0443e30a8f75 Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 15:30:06 -0800 Subject: [PATCH 03/17] WIP --- model/finders.go | 2 -- server/gql/agency_resolver.go | 2 +- server/gql/gbfs_resolver.go | 4 +-- server/gql/query_resolver.go | 18 +++++----- server/gql/resolver.go | 8 ++--- server/gql/route_resolver.go | 6 ++-- server/gql/server.go | 8 ++--- server/gql/stop_resolver.go | 16 ++++----- server/gql/stop_time_resolver.go | 6 ++-- server/gql/trip_resolver.go | 6 ++-- server/rest/feed_version_download.go | 14 +++++--- server/rest/feed_version_download_test.go | 4 +-- server/rest/rest.go | 41 +++++++++++------------ server/rest/rest_test.go | 6 ++-- server/server_cmd.go | 7 ++-- 15 files changed, 71 insertions(+), 77 deletions(-) diff --git a/model/finders.go b/model/finders.go index c0db321b..b8fc48d0 100644 --- a/model/finders.go +++ b/model/finders.go @@ -48,8 +48,6 @@ type Config struct { Storage string RTStorage string ValidateLargeFiles bool - DisableImage bool - RestPrefix string DBURL string RedisURL string Clock clock.Clock diff --git a/server/gql/agency_resolver.go b/server/gql/agency_resolver.go index e8b903bf..7923592d 100644 --- a/server/gql/agency_resolver.go +++ b/server/gql/agency_resolver.go @@ -35,6 +35,6 @@ func (r *agencyResolver) Operator(ctx context.Context, obj *model.Agency) (*mode } func (r *agencyResolver) Alerts(ctx context.Context, obj *model.Agency, active *bool, limit *int) ([]*model.Alert, error) { - rtAlerts := r.rtfinder.FindAlertsForAgency(obj, checkLimit(limit), active) + rtAlerts := r.frs.RTFinder.FindAlertsForAgency(obj, checkLimit(limit), active) return rtAlerts, nil } diff --git a/server/gql/gbfs_resolver.go b/server/gql/gbfs_resolver.go index e74490e1..789380e4 100644 --- a/server/gql/gbfs_resolver.go +++ b/server/gql/gbfs_resolver.go @@ -7,9 +7,9 @@ import ( ) func (r *queryResolver) Bikes(ctx context.Context, limit *int, where *model.GbfsBikeRequest) ([]*model.GbfsFreeBikeStatus, error) { - return r.gbfsFinder.FindBikes(ctx, checkLimit(limit), where) + return r.frs.GbfsFinder.FindBikes(ctx, checkLimit(limit), where) } func (r *queryResolver) Docks(ctx context.Context, limit *int, where *model.GbfsDockRequest) ([]*model.GbfsStationInformation, error) { - return r.gbfsFinder.FindDocks(ctx, checkLimit(limit), where) + return r.frs.GbfsFinder.FindDocks(ctx, checkLimit(limit), where) } diff --git a/server/gql/query_resolver.go b/server/gql/query_resolver.go index 41175fc0..d4fb9836 100644 --- a/server/gql/query_resolver.go +++ b/server/gql/query_resolver.go @@ -16,7 +16,7 @@ const MAX_RADIUS = 100_000 type queryResolver struct{ *Resolver } func (r *queryResolver) Me(ctx context.Context) (*model.Me, error) { - me, err := r.authzChecker.Me(ctx, &authz.MeRequest{}) + me, err := r.frs.Checker.Me(ctx, &authz.MeRequest{}) if err != nil { return nil, err } @@ -43,7 +43,7 @@ func (r *queryResolver) Agencies(ctx context.Context, limit *int, after *int, id return nil, errors.New("bbox too large") } } - return r.finder.FindAgencies(ctx, checkLimit(limit), checkCursor(after), ids, where) + return r.frs.Finder.FindAgencies(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Routes(ctx context.Context, limit *int, after *int, ids []int, where *model.RouteFilter) ([]*model.Route, error) { @@ -56,7 +56,7 @@ func (r *queryResolver) Routes(ctx context.Context, limit *int, after *int, ids return nil, errors.New("bbox too large") } } - return r.finder.FindRoutes(ctx, checkLimit(limit), checkCursor(after), ids, where) + return r.frs.Finder.FindRoutes(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Stops(ctx context.Context, limit *int, after *int, ids []int, where *model.StopFilter) ([]*model.Stop, error) { @@ -69,12 +69,12 @@ func (r *queryResolver) Stops(ctx context.Context, limit *int, after *int, ids [ return nil, errors.New("bbox too large") } } - return r.finder.FindStops(ctx, checkLimit(limit), checkCursor(after), ids, where) + return r.frs.Finder.FindStops(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Trips(ctx context.Context, limit *int, after *int, ids []int, where *model.TripFilter) ([]*model.Trip, error) { addMetric(ctx, "trips") - return r.finder.FindTrips(ctx, checkLimit(limit), checkCursor(after), ids, where) + return r.frs.Finder.FindTrips(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) FeedVersions(ctx context.Context, limit *int, after *int, ids []int, where *model.FeedVersionFilter) ([]*model.FeedVersion, error) { @@ -87,7 +87,7 @@ func (r *queryResolver) FeedVersions(ctx context.Context, limit *int, after *int return nil, errors.New("bbox too large") } } - return r.finder.FindFeedVersions(ctx, checkLimit(limit), checkCursor(after), ids, where) + return r.frs.Finder.FindFeedVersions(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Feeds(ctx context.Context, limit *int, after *int, ids []int, where *model.FeedFilter) ([]*model.Feed, error) { @@ -100,7 +100,7 @@ func (r *queryResolver) Feeds(ctx context.Context, limit *int, after *int, ids [ return nil, errors.New("bbox too large") } } - return r.finder.FindFeeds(ctx, checkLimit(limit), checkCursor(after), ids, where) + return r.frs.Finder.FindFeeds(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Operators(ctx context.Context, limit *int, after *int, ids []int, where *model.OperatorFilter) ([]*model.Operator, error) { @@ -113,11 +113,11 @@ func (r *queryResolver) Operators(ctx context.Context, limit *int, after *int, i return nil, errors.New("bbox too large") } } - return r.finder.FindOperators(ctx, checkLimit(limit), checkCursor(after), ids, where) + return r.frs.Finder.FindOperators(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Places(ctx context.Context, limit *int, after *int, level *model.PlaceAggregationLevel, where *model.PlaceFilter) ([]*model.Place, error) { - return r.finder.FindPlaces(ctx, checkLimit(limit), checkCursor(after), nil, level, where) + return r.frs.Finder.FindPlaces(ctx, checkLimit(limit), checkCursor(after), nil, level, where) } func addMetric(ctx context.Context, resolverName string) { diff --git a/server/gql/resolver.go b/server/gql/resolver.go index 153bdc2f..c47dee74 100644 --- a/server/gql/resolver.go +++ b/server/gql/resolver.go @@ -53,12 +53,8 @@ func atoi(v string) int { // Resolver . type Resolver struct { - cfg model.Config - rtfinder model.RTFinder - finder model.Finder - gbfsFinder model.GbfsFinder - authzChecker model.Checker - fvslCache *fvslCache + frs model.Finders + fvslCache *fvslCache } // Query . diff --git a/server/gql/route_resolver.go b/server/gql/route_resolver.go index f47bca87..5ed36ec8 100644 --- a/server/gql/route_resolver.go +++ b/server/gql/route_resolver.go @@ -63,7 +63,7 @@ func (r *routeResolver) Headways(ctx context.Context, obj *model.Route, limit *i func (r *routeResolver) RouteStopBuffer(ctx context.Context, obj *model.Route, radius *float64) (*model.RouteStopBuffer, error) { // TODO: remove n+1 (which is tricky, what if multiple radius specified in different parts of query) p := model.RouteStopBufferParam{Radius: radius, EntityID: obj.ID} - ents, err := r.finder.RouteStopBuffer(ctx, &p) + ents, err := r.frs.Finder.RouteStopBuffer(ctx, &p) if err != nil { return nil, err } @@ -74,7 +74,7 @@ func (r *routeResolver) RouteStopBuffer(ctx context.Context, obj *model.Route, r } func (r *routeResolver) Alerts(ctx context.Context, obj *model.Route, active *bool, limit *int) ([]*model.Alert, error) { - return r.rtfinder.FindAlertsForRoute(obj, checkLimit(limit), active), nil + return r.frs.RTFinder.FindAlertsForRoute(obj, checkLimit(limit), active), nil } func (r *routeResolver) Patterns(ctx context.Context, obj *model.Route) ([]*model.RouteStopPattern, error) { @@ -124,7 +124,7 @@ type routePatternResolver struct{ *Resolver } func (r *routePatternResolver) Trips(ctx context.Context, obj *model.RouteStopPattern, limit *int) ([]*model.Trip, error) { // TODO: N+1 query - trips, err := r.finder.FindTrips(ctx, checkLimit(limit), nil, nil, &model.TripFilter{StopPatternID: &obj.StopPatternID, RouteIds: []int{obj.RouteID}}) + trips, err := r.frs.Finder.FindTrips(ctx, checkLimit(limit), nil, nil, &model.TripFilter{StopPatternID: &obj.StopPatternID, RouteIds: []int{obj.RouteID}}) return trips, err } diff --git a/server/gql/server.go b/server/gql/server.go index 9517ea0d..1669eb79 100644 --- a/server/gql/server.go +++ b/server/gql/server.go @@ -14,12 +14,8 @@ import ( func NewServer(te model.Finders) (http.Handler, error) { c := gqlout.Config{Resolvers: &Resolver{ - cfg: te.Config, - finder: te.Finder, - rtfinder: te.RTFinder, - gbfsFinder: te.GbfsFinder, - fvslCache: newFvslCache(te.Finder), - authzChecker: te.Checker, + frs: te, + fvslCache: newFvslCache(te.Finder), }} c.Directives.HasRole = func(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (interface{}, error) { user := authn.ForContext(ctx) diff --git a/server/gql/stop_resolver.go b/server/gql/stop_resolver.go index 0a8b2c4b..b20402f7 100644 --- a/server/gql/stop_resolver.go +++ b/server/gql/stop_resolver.go @@ -99,13 +99,13 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit if where != nil { // Convert where.Next into departure date and time window if where.Next != nil { - loc, ok := r.rtfinder.StopTimezone(obj.ID, obj.StopTimezone) + loc, ok := r.frs.RTFinder.StopTimezone(obj.ID, obj.StopTimezone) if !ok { return nil, errors.New("timezone not available for stop") } serviceDate := time.Now().In(loc) - if r.cfg.Clock != nil { - serviceDate = r.cfg.Clock.Now().In(loc) + if r.frs.Config.Clock != nil { + serviceDate = r.frs.Config.Clock.Now().In(loc) } st, et := 0, 0 st = serviceDate.Hour()*3600 + serviceDate.Minute()*60 + serviceDate.Second() @@ -163,13 +163,13 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit for _, st := range sts { ft := model.Trip{} ft.FeedVersionID = obj.FeedVersionID - ft.TripID, _ = r.rtfinder.GetGtfsTripID(atoi(st.TripID)) // TODO! - if ste, ok := r.rtfinder.FindStopTimeUpdate(&ft, st); ok { + ft.TripID, _ = r.frs.RTFinder.GetGtfsTripID(atoi(st.TripID)) // TODO! + if ste, ok := r.frs.RTFinder.FindStopTimeUpdate(&ft, st); ok { st.RTStopTimeUpdate = ste } } // Handle added trips; these must specify stop_id in StopTimeUpdates - for _, rtTrip := range r.rtfinder.GetAddedTripsForStop(obj) { + for _, rtTrip := range r.frs.RTFinder.GetAddedTripsForStop(obj) { for _, stu := range rtTrip.StopTimeUpdate { if stu.GetStopId() != obj.StopID { continue @@ -196,7 +196,7 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit } func (r *stopResolver) Alerts(ctx context.Context, obj *model.Stop, active *bool, limit *int) ([]*model.Alert, error) { - rtAlerts := r.rtfinder.FindAlertsForStop(obj, checkLimit(limit), active) + rtAlerts := r.frs.RTFinder.FindAlertsForStop(obj, checkLimit(limit), active) return rtAlerts, nil } @@ -223,7 +223,7 @@ func (r *stopResolver) Directions(ctx context.Context, obj *model.Stop, from *mo func (r *stopResolver) NearbyStops(ctx context.Context, obj *model.Stop, limit *int, radius *float64) ([]*model.Stop, error) { c := obj.Coordinates() - nearbyStops, err := r.finder.FindStops(ctx, checkLimit(limit), nil, nil, &model.StopFilter{Near: &model.PointRadius{Lon: c[0], Lat: c[1], Radius: checkFloat(radius, 0, MAX_RADIUS)}}) + nearbyStops, err := r.frs.Finder.FindStops(ctx, checkLimit(limit), nil, nil, &model.StopFilter{Near: &model.PointRadius{Lon: c[0], Lat: c[1], Radius: checkFloat(radius, 0, MAX_RADIUS)}}) return nearbyStops, err } diff --git a/server/gql/stop_time_resolver.go b/server/gql/stop_time_resolver.go index e76688f7..1dfa1de1 100644 --- a/server/gql/stop_time_resolver.go +++ b/server/gql/stop_time_resolver.go @@ -26,7 +26,7 @@ func (r *stopTimeResolver) Trip(ctx context.Context, obj *model.StopTime) (*mode t := model.Trip{} t.FeedVersionID = obj.FeedVersionID t.TripID = obj.RTTripID - a, err := r.rtfinder.MakeTrip(&t) + a, err := r.frs.RTFinder.MakeTrip(&t) return a, err } return For(ctx).TripsByID.Load(ctx, atoi(obj.TripID))() @@ -34,7 +34,7 @@ func (r *stopTimeResolver) Trip(ctx context.Context, obj *model.StopTime) (*mode func (r *stopTimeResolver) Arrival(ctx context.Context, obj *model.StopTime) (*model.StopTimeEvent, error) { // lookup timezone - loc, ok := r.rtfinder.StopTimezone(atoi(obj.StopID), "") + loc, ok := r.frs.RTFinder.StopTimezone(atoi(obj.StopID), "") if !ok { return nil, errors.New("timezone not available for stop") } @@ -54,7 +54,7 @@ func (r *stopTimeResolver) Arrival(ctx context.Context, obj *model.StopTime) (*m func (r *stopTimeResolver) Departure(ctx context.Context, obj *model.StopTime) (*model.StopTimeEvent, error) { // lookup timezone - loc, ok := r.rtfinder.StopTimezone(atoi(obj.StopID), "") + loc, ok := r.frs.RTFinder.StopTimezone(atoi(obj.StopID), "") if !ok { return nil, errors.New("timezone not available for stop") } diff --git a/server/gql/trip_resolver.go b/server/gql/trip_resolver.go index 30c285f6..1fcd08f4 100644 --- a/server/gql/trip_resolver.go +++ b/server/gql/trip_resolver.go @@ -45,7 +45,7 @@ func (r *tripResolver) Frequencies(ctx context.Context, obj *model.Trip, limit * func (r *tripResolver) ScheduleRelationship(ctx context.Context, obj *model.Trip) (*model.ScheduleRelationship, error) { msr := model.ScheduleRelationshipScheduled - if rtt := r.rtfinder.FindTrip(obj); rtt != nil { + if rtt := r.frs.RTFinder.FindTrip(obj); rtt != nil { sr := rtt.GetTrip().GetScheduleRelationship().String() switch sr { case "SCHEDULED": @@ -64,7 +64,7 @@ func (r *tripResolver) ScheduleRelationship(ctx context.Context, obj *model.Trip } func (r *tripResolver) Timestamp(ctx context.Context, obj *model.Trip) (*time.Time, error) { - if rtt := r.rtfinder.FindTrip(obj); rtt != nil { + if rtt := r.frs.RTFinder.FindTrip(obj); rtt != nil { t := time.Unix(int64(rtt.GetTimestamp()), 0).In(time.UTC) return &t, nil } @@ -72,6 +72,6 @@ func (r *tripResolver) Timestamp(ctx context.Context, obj *model.Trip) (*time.Ti } func (r *tripResolver) Alerts(ctx context.Context, obj *model.Trip, active *bool, limit *int) ([]*model.Alert, error) { - rtAlerts := r.rtfinder.FindAlertsForTrip(obj, checkLimit(limit), active) + rtAlerts := r.frs.RTFinder.FindAlertsForTrip(obj, checkLimit(limit), active) return rtAlerts, nil } diff --git a/server/rest/feed_version_download.go b/server/rest/feed_version_download.go index 4bf60751..62b87b57 100644 --- a/server/rest/feed_version_download.go +++ b/server/rest/feed_version_download.go @@ -13,6 +13,7 @@ import ( "github.com/interline-io/transitland-lib/tl/request" "github.com/interline-io/transitland-mw/meters" "github.com/interline-io/transitland-server/internal/util" + "github.com/interline-io/transitland-server/model" "github.com/tidwall/gjson" ) @@ -32,7 +33,7 @@ query($feed_onestop_id: String!, $ids: [Int!]) { // Query redirects user to download the given fv from S3 public URL // assuming that redistribution is allowed for the feed. -func feedVersionDownloadLatestHandler(cfg restConfig, w http.ResponseWriter, r *http.Request) { +func feedVersionDownloadLatestHandler(graphqlHandler http.Handler, w http.ResponseWriter, r *http.Request) { key := chi.URLParam(r, "feed_key") gvars := hw{} if key == "" { @@ -45,7 +46,7 @@ func feedVersionDownloadLatestHandler(cfg restConfig, w http.ResponseWriter, r * } // Check if we're allowed to redistribute feed and look up latest feed version - feedResponse, err := makeGraphQLRequest(r.Context(), cfg.srv, latestFeedVersionQuery, gvars) + feedResponse, err := makeGraphQLRequest(r.Context(), graphqlHandler, latestFeedVersionQuery, gvars) if err != nil { http.Error(w, util.MakeJsonError("server error"), http.StatusInternalServerError) return @@ -83,7 +84,9 @@ func feedVersionDownloadLatestHandler(cfg restConfig, w http.ResponseWriter, r * } apiMeter.Meter("feed-version-downloads", 1.0, dims) } - serveFromStorage(w, r, cfg.Storage, fvsha1) + + frs := model.ForContext(r.Context()) + serveFromStorage(w, r, frs.Config.Storage, fvsha1) } const feedVersionFileQuery = ` @@ -102,7 +105,7 @@ query($feed_version_sha1:String!, $ids: [Int!]) { // Query redirects user to download the given fv from S3 public URL // assuming that redistribution is allowed for the feed. -func feedVersionDownloadHandler(cfg restConfig, w http.ResponseWriter, r *http.Request) { +func feedVersionDownloadHandler(graphqlHandler http.Handler, w http.ResponseWriter, r *http.Request) { gvars := hw{} key := chi.URLParam(r, "feed_version_key") if key == "" { @@ -114,7 +117,7 @@ func feedVersionDownloadHandler(cfg restConfig, w http.ResponseWriter, r *http.R gvars["feed_version_sha1"] = key } // Check if we're allowed to redistribute feed - checkfv, err := makeGraphQLRequest(r.Context(), cfg.srv, feedVersionFileQuery, gvars) + checkfv, err := makeGraphQLRequest(r.Context(), graphqlHandler, feedVersionFileQuery, gvars) if err != nil { http.Error(w, util.MakeJsonError("server error"), http.StatusInternalServerError) return @@ -159,6 +162,7 @@ func feedVersionDownloadHandler(cfg restConfig, w http.ResponseWriter, r *http.R apiMeter.Meter("feed-version-downloads", 1.0, dims) } + cfg := model.ForContext(r.Context()).Config serveFromStorage(w, r, cfg.Storage, fvsha1) } diff --git a/server/rest/feed_version_download_test.go b/server/rest/feed_version_download_test.go index dacf6881..850bdc19 100644 --- a/server/rest/feed_version_download_test.go +++ b/server/rest/feed_version_download_test.go @@ -18,7 +18,7 @@ func TestFeedVersionDownloadRequest(t *testing.T) { } srv, te := testRestConfig(t) te.Config.Storage = g - restSrv, err := testRestServer(t, te.Config, srv) + restSrv, err := testRestServer(t, Config{Config: te.Config}, srv) if err != nil { t.Fatal(err) } @@ -119,7 +119,7 @@ func TestFeedDownloadLatestRequest(t *testing.T) { } srv, te := testRestConfig(t) te.Config.Storage = g - restSrv, err := testRestServer(t, te.Config, srv) + restSrv, err := testRestServer(t, Config{Config: te.Config}, srv) if err != nil { t.Fatal(err) } diff --git a/server/rest/rest.go b/server/rest/rest.go index 514d8b21..83248815 100644 --- a/server/rest/rest.go +++ b/server/rest/rest.go @@ -30,38 +30,37 @@ var MAXLIMIT = 1_000 // MAXRADIUS is the maximum point search radius const MAXRADIUS = 100 * 1000.0 -// restConfig holds the base config and the graphql handler -type restConfig struct { +type Config struct { + DisableImage bool + RestPrefix string model.Config - srv http.Handler } // NewServer . -func NewServer(cfg model.Config, srv http.Handler) (http.Handler, error) { - restcfg := restConfig{Config: cfg, srv: srv} +func NewServer(cfg Config, graphqlHandler http.Handler) (http.Handler, error) { r := chi.NewRouter() - feedHandler := makeHandler(restcfg, "feeds", func() apiHandler { return &FeedRequest{} }) - feedVersionHandler := makeHandler(restcfg, "feedVersions", func() apiHandler { return &FeedVersionRequest{} }) - agencyHandler := makeHandler(restcfg, "agencies", func() apiHandler { return &AgencyRequest{} }) - routeHandler := makeHandler(restcfg, "routes", func() apiHandler { return &RouteRequest{} }) - tripHandler := makeHandler(restcfg, "trips", func() apiHandler { return &TripRequest{} }) - stopHandler := makeHandler(restcfg, "stops", func() apiHandler { return &StopRequest{} }) - stopDepartureHandler := makeHandler(restcfg, "stopDepartures", func() apiHandler { return &StopDepartureRequest{} }) - operatorHandler := makeHandler(restcfg, "operators", func() apiHandler { return &OperatorRequest{} }) + feedHandler := makeHandler(cfg, graphqlHandler, "feeds", func() apiHandler { return &FeedRequest{} }) + feedVersionHandler := makeHandler(cfg, graphqlHandler, "feedVersions", func() apiHandler { return &FeedVersionRequest{} }) + agencyHandler := makeHandler(cfg, graphqlHandler, "agencies", func() apiHandler { return &AgencyRequest{} }) + routeHandler := makeHandler(cfg, graphqlHandler, "routes", func() apiHandler { return &RouteRequest{} }) + tripHandler := makeHandler(cfg, graphqlHandler, "trips", func() apiHandler { return &TripRequest{} }) + stopHandler := makeHandler(cfg, graphqlHandler, "stops", func() apiHandler { return &StopRequest{} }) + stopDepartureHandler := makeHandler(cfg, graphqlHandler, "stopDepartures", func() apiHandler { return &StopDepartureRequest{} }) + operatorHandler := makeHandler(cfg, graphqlHandler, "operators", func() apiHandler { return &OperatorRequest{} }) r.HandleFunc("/feeds.{format}", feedHandler) r.HandleFunc("/feeds", feedHandler) r.HandleFunc("/feeds/{feed_key}.{format}", feedHandler) r.HandleFunc("/feeds/{feed_key}", feedHandler) - r.Handle("/feeds/{feed_key}/download_latest_feed_version", ancheck.RoleRequired("tl_download_fv_current")(makeHandlerFunc(restcfg, "feedVersionDownloadLatest", feedVersionDownloadLatestHandler))) + r.Handle("/feeds/{feed_key}/download_latest_feed_version", ancheck.RoleRequired("tl_download_fv_current")(makeHandlerFunc(graphqlHandler, "feedVersionDownloadLatest", feedVersionDownloadLatestHandler))) r.HandleFunc("/feed_versions.{format}", feedVersionHandler) r.HandleFunc("/feed_versions", feedVersionHandler) r.HandleFunc("/feed_versions/{feed_version_key}.{format}", feedVersionHandler) r.HandleFunc("/feed_versions/{feed_version_key}", feedVersionHandler) r.HandleFunc("/feeds/{feed_key}/feed_versions", feedVersionHandler) - r.Handle("/feed_versions/{feed_version_key}/download", ancheck.RoleRequired("tl_download_fv_historic")(makeHandlerFunc(restcfg, "feedVersionDownload", feedVersionDownloadHandler))) + r.Handle("/feed_versions/{feed_version_key}/download", ancheck.RoleRequired("tl_download_fv_historic")(makeHandlerFunc(graphqlHandler, "feedVersionDownload", feedVersionDownloadHandler))) r.HandleFunc("/agencies.{format}", agencyHandler) r.HandleFunc("/agencies", agencyHandler) @@ -185,7 +184,7 @@ func queryToMap(vars url.Values) map[string]string { } // makeHandler wraps an apiHandler into an HandlerFunc and performs common checks. -func makeHandler(cfg restConfig, handlerName string, f func() apiHandler) http.HandlerFunc { +func makeHandler(cfg Config, graphqlHandler http.Handler, handlerName string, f func() apiHandler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ent := f() opts := queryToMap(r.URL.Query()) @@ -238,7 +237,7 @@ func makeHandler(cfg restConfig, handlerName string, f func() apiHandler) http.H } // Make the request - response, err := makeRequest(r.Context(), cfg, ent, format, r.URL) + response, err := makeRequest(r.Context(), cfg, graphqlHandler, ent, format, r.URL) if err != nil { http.Error(w, util.MakeJsonError(err.Error()), http.StatusInternalServerError) return @@ -297,9 +296,9 @@ func makeGraphQLRequest(ctx context.Context, srv http.Handler, query string, var } // makeRequest prepares an apiHandler and makes the request. -func makeRequest(ctx context.Context, cfg restConfig, ent apiHandler, format string, u *url.URL) ([]byte, error) { +func makeRequest(ctx context.Context, cfg Config, graphqlHandler http.Handler, ent apiHandler, format string, u *url.URL) ([]byte, error) { query, vars := ent.Query() - response, err := makeGraphQLRequest(ctx, cfg.srv, query, vars) + response, err := makeGraphQLRequest(ctx, graphqlHandler, query, vars) if err != nil { vjson, _ := json.Marshal(vars) log.Error().Err(err).Str("query", query).Str("vars", string(vjson)).Msgf("graphql request failed") @@ -374,12 +373,12 @@ func renderGeojsonl(response map[string]any) ([]byte, error) { return ret, nil } -func makeHandlerFunc(cfg restConfig, handlerName string, f func(restConfig, http.ResponseWriter, *http.Request)) http.HandlerFunc { +func makeHandlerFunc(graphqlHandler http.Handler, handlerName string, f func(http.Handler, http.ResponseWriter, *http.Request)) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if apiMeter := meters.ForContext(r.Context()); apiMeter != nil { apiMeter.AddDimension("rest", "handler", handlerName) } - f(cfg, w, r) + f(graphqlHandler, w, r) } } diff --git a/server/rest/rest_test.go b/server/rest/rest_test.go index 82fdce64..29bdd91d 100644 --- a/server/rest/rest_test.go +++ b/server/rest/rest_test.go @@ -42,7 +42,7 @@ func testRestConfig(t testing.TB) (http.Handler, model.Finders) { return srv, te } -func testRestServer(t testing.TB, cfg model.Config, srv http.Handler) (http.Handler, error) { +func testRestServer(t testing.TB, cfg Config, srv http.Handler) (http.Handler, error) { return NewServer(cfg, srv) } @@ -61,8 +61,8 @@ type testRest struct { f func(*testing.T, string) } -func testquery(t *testing.T, srv http.Handler, te model.Finders, tc testRest) { - data, err := makeRequest(context.TODO(), restConfig{srv: srv, Config: te.Config}, tc.h, tc.format, nil) +func testquery(t *testing.T, graphqlHandler http.Handler, te model.Finders, tc testRest) { + data, err := makeRequest(context.TODO(), Config{Config: te.Config}, graphqlHandler, tc.h, tc.format, nil) if err != nil { t.Error(err) return diff --git a/server/server_cmd.go b/server/server_cmd.go index bada2a45..4dc75f65 100644 --- a/server/server_cmd.go +++ b/server/server_cmd.go @@ -62,6 +62,7 @@ type Command struct { metricsConfig metrics.Config AuthConfig ancheck.AuthConfig CheckerConfig azcheck.CheckerConfig + RestConfig rest.Config model.Config } @@ -77,9 +78,9 @@ func (cmd *Command) Parse(args []string) error { fl.StringVar(&cmd.RedisURL, "redisurl", "", "Redis URL (default: $TL_REDIS_URL)") fl.StringVar(&cmd.Storage, "storage", "", "Static storage backend") fl.StringVar(&cmd.RTStorage, "rt-storage", "", "RT storage backend") - fl.StringVar(&cmd.RestPrefix, "rest-prefix", "", "REST prefix for generating pagination links") fl.BoolVar(&cmd.ValidateLargeFiles, "validate-large-files", false, "Allow validation of large files") - fl.BoolVar(&cmd.DisableImage, "disable-image", false, "Disable image generation") + fl.StringVar(&cmd.RestConfig.RestPrefix, "rest-prefix", "", "REST prefix for generating pagination links") + fl.BoolVar(&cmd.RestConfig.DisableImage, "disable-image", false, "Disable image generation") // Server config fl.StringVar(&cmd.Port, "port", "8080", "") @@ -293,7 +294,7 @@ func (cmd *Command) Run() error { // REST API if !cmd.DisableRest { - restServer, err := rest.NewServer(cfg, graphqlServer) + restServer, err := rest.NewServer(cmd.RestConfig, graphqlServer) if err != nil { return err } From ac86afb6986163d1832f7f0dc734c53db8b1c8fc Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 15:46:01 -0800 Subject: [PATCH 04/17] WIP --- actions/fetch.go | 9 ++-- actions/fv.go | 3 +- actions/validate.go | 3 +- internal/testfinder/testfinder.go | 9 ++-- model/finders.go | 24 ++++----- server/gql/loaders.go | 4 +- server/gql/server.go | 2 +- server/gql/stop_resolver.go | 4 +- server/rest/feed_version_download.go | 4 +- server/rest/feed_version_download_test.go | 8 +-- server/rest/rest.go | 1 - server/rest/rest_test.go | 2 +- server/server_cmd.go | 59 ++++++++++++----------- 13 files changed, 62 insertions(+), 70 deletions(-) diff --git a/actions/fetch.go b/actions/fetch.go index 4e6eb694..dfc4287e 100644 --- a/actions/fetch.go +++ b/actions/fetch.go @@ -26,7 +26,6 @@ import ( func StaticFetch(ctx context.Context, feedId string, feedSrc io.Reader, feedUrl string) (*model.FeedVersionFetchResult, error) { frs := model.ForContext(ctx) - cfg := frs.Config dbf := frs.Finder urlType := "static_current" @@ -43,8 +42,8 @@ func StaticFetch(ctx context.Context, feedId string, feedSrc io.Reader, feedUrl FeedID: feed.ID, URLType: urlType, FeedURL: feedUrl, - Storage: cfg.Storage, - Secrets: cfg.Secrets, + Storage: frs.Storage, + Secrets: frs.Secrets, FetchedAt: time.Now().In(time.UTC), AllowFTPFetch: true, } @@ -105,8 +104,8 @@ func RTFetch(ctx context.Context, target string, feedId string, feedUrl string, FeedID: feed.ID, URLType: urlType, FeedURL: feedUrl, - Storage: frs.Config.RTStorage, - Secrets: frs.Config.Secrets, + Storage: frs.RTStorage, + Secrets: frs.Secrets, FetchedAt: time.Now().In(time.UTC), } diff --git a/actions/fv.go b/actions/fv.go index e2200666..487ae034 100644 --- a/actions/fv.go +++ b/actions/fv.go @@ -16,7 +16,6 @@ import ( func FeedVersionImport(ctx context.Context, fvid int) (*model.FeedVersionImportResult, error) { frs := model.ForContext(ctx) checker := frs.Checker - cfg := frs.Config dbf := frs.Finder if checker == nil { return nil, authz.ErrUnauthorized @@ -28,7 +27,7 @@ func FeedVersionImport(ctx context.Context, fvid int) (*model.FeedVersionImportR } opts := importer.Options{ FeedVersionID: fvid, - Storage: cfg.Storage, + Storage: frs.Storage, } db := tldb.NewPostgresAdapterFromDBX(dbf.DBX()) fr, fe := importer.MainImportFeedVersion(db, opts) diff --git a/actions/validate.go b/actions/validate.go index 80c5bbad..fae0692c 100644 --- a/actions/validate.go +++ b/actions/validate.go @@ -22,7 +22,6 @@ type hasGeometries interface { // ValidateUpload takes a file Reader and produces a validation package containing errors, warnings, file infos, service levels, etc. func ValidateUpload(ctx context.Context, src io.Reader, feedURL *string, rturls []string) (*model.ValidationResult, error) { frs := model.ForContext(ctx) - cfg := frs.Config // Check inputs rturlsok := []string{} @@ -88,7 +87,7 @@ func ValidateUpload(ctx context.Context, src io.Reader, feedURL *string, rturls MaxRTMessageSize: 10_000_000, ValidateRealtimeMessages: rturls, } - if cfg.ValidateLargeFiles { + if frs.ValidateLargeFiles { opts.CheckFileLimits = false } diff --git a/internal/testfinder/testfinder.go b/internal/testfinder/testfinder.go index fd788f0e..fa56272d 100644 --- a/internal/testfinder/testfinder.go +++ b/internal/testfinder/testfinder.go @@ -33,11 +33,6 @@ func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Finders if opts.Clock == nil { opts.Clock = &clock.Real{} } - cfg := model.Config{ - Clock: opts.Clock, - Storage: t.TempDir(), - RTStorage: t.TempDir(), - } // Setup Checker checkerCfg := azcheck.CheckerConfig{ @@ -77,11 +72,13 @@ func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Finders gbf := gbfsfinder.NewFinder(nil) return model.Finders{ - Config: cfg, Finder: dbf, RTFinder: rtf, GbfsFinder: gbf, Checker: checker, + Clock: opts.Clock, + Storage: t.TempDir(), + RTStorage: t.TempDir(), } } diff --git a/model/finders.go b/model/finders.go index b8fc48d0..80dcce29 100644 --- a/model/finders.go +++ b/model/finders.go @@ -44,23 +44,17 @@ func AddFinders(te Finders) func(http.Handler) http.Handler { } } -type Config struct { - Storage string - RTStorage string - ValidateLargeFiles bool - DBURL string - RedisURL string +type Finders struct { + Finder Finder + RTFinder RTFinder + GbfsFinder GbfsFinder + Checker Checker Clock clock.Clock Secrets []tl.Secret -} - -type Finders struct { - Config Config - Finder Finder - RTFinder RTFinder - GbfsFinder GbfsFinder - Checker Checker - Logger zerolog.Logger + ValidateLargeFiles bool + Storage string + RTStorage string + Logger zerolog.Logger } // Finder provides all necessary database methods diff --git a/server/gql/loaders.go b/server/gql/loaders.go index 2fa46362..b9f6e5b9 100644 --- a/server/gql/loaders.go +++ b/server/gql/loaders.go @@ -135,11 +135,11 @@ func NewLoaders(dbf model.Finder) *Loaders { return loaders } -func loaderMiddleware(cfg model.Config, finder model.Finder, next http.Handler) http.Handler { +func loaderMiddleware(frs model.Finders, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // This is per request scoped loaders/cache // Is this OK to use as a long term cache? - loaders := NewLoaders(finder) + loaders := NewLoaders(frs.Finder) nextCtx := context.WithValue(r.Context(), loadersKey, loaders) r = r.WithContext(nextCtx) next.ServeHTTP(w, r) diff --git a/server/gql/server.go b/server/gql/server.go index 1669eb79..f43ca776 100644 --- a/server/gql/server.go +++ b/server/gql/server.go @@ -26,6 +26,6 @@ func NewServer(te model.Finders) (http.Handler, error) { } // Setup server srv := handler.NewDefaultServer(gqlout.NewExecutableSchema(c)) - graphqlServer := loaderMiddleware(te.Config, te.Finder, srv) + graphqlServer := loaderMiddleware(te, srv) return graphqlServer, nil } diff --git a/server/gql/stop_resolver.go b/server/gql/stop_resolver.go index b20402f7..345a9833 100644 --- a/server/gql/stop_resolver.go +++ b/server/gql/stop_resolver.go @@ -104,8 +104,8 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit return nil, errors.New("timezone not available for stop") } serviceDate := time.Now().In(loc) - if r.frs.Config.Clock != nil { - serviceDate = r.frs.Config.Clock.Now().In(loc) + if r.frs.Clock != nil { + serviceDate = r.frs.Clock.Now().In(loc) } st, et := 0, 0 st = serviceDate.Hour()*3600 + serviceDate.Minute()*60 + serviceDate.Second() diff --git a/server/rest/feed_version_download.go b/server/rest/feed_version_download.go index 62b87b57..16db1bba 100644 --- a/server/rest/feed_version_download.go +++ b/server/rest/feed_version_download.go @@ -86,7 +86,7 @@ func feedVersionDownloadLatestHandler(graphqlHandler http.Handler, w http.Respon } frs := model.ForContext(r.Context()) - serveFromStorage(w, r, frs.Config.Storage, fvsha1) + serveFromStorage(w, r, frs.Storage, fvsha1) } const feedVersionFileQuery = ` @@ -162,7 +162,7 @@ func feedVersionDownloadHandler(graphqlHandler http.Handler, w http.ResponseWrit apiMeter.Meter("feed-version-downloads", 1.0, dims) } - cfg := model.ForContext(r.Context()).Config + cfg := model.ForContext(r.Context()) serveFromStorage(w, r, cfg.Storage, fvsha1) } diff --git a/server/rest/feed_version_download_test.go b/server/rest/feed_version_download_test.go index 850bdc19..b5bd66e8 100644 --- a/server/rest/feed_version_download_test.go +++ b/server/rest/feed_version_download_test.go @@ -17,8 +17,8 @@ func TestFeedVersionDownloadRequest(t *testing.T) { return } srv, te := testRestConfig(t) - te.Config.Storage = g - restSrv, err := testRestServer(t, Config{Config: te.Config}, srv) + te.Storage = g + restSrv, err := testRestServer(t, Config{}, srv) if err != nil { t.Fatal(err) } @@ -118,8 +118,8 @@ func TestFeedDownloadLatestRequest(t *testing.T) { return } srv, te := testRestConfig(t) - te.Config.Storage = g - restSrv, err := testRestServer(t, Config{Config: te.Config}, srv) + te.Storage = g + restSrv, err := testRestServer(t, Config{}, srv) if err != nil { t.Fatal(err) } diff --git a/server/rest/rest.go b/server/rest/rest.go index 83248815..2a3290d5 100644 --- a/server/rest/rest.go +++ b/server/rest/rest.go @@ -33,7 +33,6 @@ const MAXRADIUS = 100 * 1000.0 type Config struct { DisableImage bool RestPrefix string - model.Config } // NewServer . diff --git a/server/rest/rest_test.go b/server/rest/rest_test.go index 29bdd91d..b8c915df 100644 --- a/server/rest/rest_test.go +++ b/server/rest/rest_test.go @@ -62,7 +62,7 @@ type testRest struct { } func testquery(t *testing.T, graphqlHandler http.Handler, te model.Finders, tc testRest) { - data, err := makeRequest(context.TODO(), Config{Config: te.Config}, graphqlHandler, tc.h, tc.format, nil) + data, err := makeRequest(context.TODO(), Config{}, graphqlHandler, tc.h, tc.format, nil) if err != nil { t.Error(err) return diff --git a/server/server_cmd.go b/server/server_cmd.go index 4dc75f65..b55a670a 100644 --- a/server/server_cmd.go +++ b/server/server_cmd.go @@ -43,27 +43,32 @@ import ( ) type Command struct { - Timeout int - Port string - LongQueryDuration int - DisableGraphql bool - DisableRest bool - EnablePlayground bool - EnableAdminApi bool - EnableJobsApi bool - EnableWorkers bool - EnableProfiler bool - EnableRateLimits bool - LoadAdmins bool - QueuePrefix string - SecretsFile string - AuthMiddlewares arrayFlags - metersConfig meters.Config - metricsConfig metrics.Config - AuthConfig ancheck.AuthConfig - CheckerConfig azcheck.CheckerConfig - RestConfig rest.Config - model.Config + Timeout int + Port string + LongQueryDuration int + DisableGraphql bool + DisableRest bool + EnablePlayground bool + EnableAdminApi bool + EnableJobsApi bool + EnableWorkers bool + EnableProfiler bool + EnableRateLimits bool + LoadAdmins bool + QueuePrefix string + SecretsFile string + Storage string + RTStorage string + ValidateLargeFiles bool + DBURL string + RedisURL string + AuthMiddlewares arrayFlags + metersConfig meters.Config + metricsConfig metrics.Config + AuthConfig ancheck.AuthConfig + CheckerConfig azcheck.CheckerConfig + RestConfig rest.Config + secrets []tl.Secret } func (cmd *Command) Parse(args []string) error { @@ -161,16 +166,14 @@ func (cmd *Command) Parse(args []string) error { } secrets = rr.Secrets } - cmd.Config.Secrets = secrets + cmd.secrets = secrets return nil } func (cmd *Command) Run() error { - cfg := cmd.Config - // Open database var db sqlx.Ext - dbx, err := dbutil.OpenDB(cfg.DBURL) + dbx, err := dbutil.OpenDB(cmd.DBURL) if err != nil { return err } @@ -182,7 +185,7 @@ func (cmd *Command) Run() error { // Open redis var redisClient *redis.Client if cmd.RedisURL != "" { - rOpts, err := getRedisOpts(cfg.RedisURL) + rOpts, err := getRedisOpts(cmd.RedisURL) if err != nil { return err } @@ -272,11 +275,13 @@ func (cmd *Command) Run() error { // GraphQL API te := model.Finders{ - Config: cfg, Finder: dbFinder, RTFinder: rtFinder, GbfsFinder: gbfsFinder, Checker: checker, + Secrets: cmd.secrets, + Storage: cmd.Storage, + RTStorage: cmd.RTStorage, } graphqlServer, err := gql.NewServer(te) if err != nil { From 475b803cc5e6531fad5f227d20437f02d1bb66f8 Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 15:57:42 -0800 Subject: [PATCH 05/17] Rename model.Finders to model.Config --- actions/fetch_test.go | 4 +-- internal/testfinder/testfinder.go | 14 ++++---- jobs/jobs.go | 2 +- model/config.go | 51 ++++++++++++++++++++++++++++ model/finders.go | 45 ------------------------ server/gql/agency_resolver.go | 2 +- server/gql/gbfs_resolver.go | 4 +-- server/gql/loaders.go | 2 +- server/gql/mutation_resolver_test.go | 8 ++--- server/gql/query_resolver.go | 18 +++++----- server/gql/resolver.go | 1 - server/gql/resolver_test.go | 4 +-- server/gql/route_resolver.go | 6 ++-- server/gql/server.go | 4 +-- server/gql/stop_resolver.go | 16 ++++----- server/gql/stop_resolver_test.go | 8 ++--- server/gql/stop_time_resolver.go | 6 ++-- server/gql/trip_resolver.go | 6 ++-- server/rest/feed_version_download.go | 4 +-- server/rest/rest_test.go | 4 +-- server/server_cmd.go | 2 +- workers/gbfs_fetch_worker_test.go | 2 +- 22 files changed, 109 insertions(+), 104 deletions(-) create mode 100644 model/config.go diff --git a/actions/fetch_test.go b/actions/fetch_test.go index fe8a43b0..df6ded09 100644 --- a/actions/fetch_test.go +++ b/actions/fetch_test.go @@ -110,8 +110,8 @@ func TestStaticFetchWorker(t *testing.T) { // Setup job feedUrl := ts.URL + "/" + tc.serveFile - testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { - ctx := model.WithFinders(context.Background(), te) + testfinder.FindersTxRollback(t, nil, nil, func(te model.Config) { + ctx := model.WithConfig(context.Background(), te) // Run job if result, err := StaticFetch(ctx, tc.feedId, nil, feedUrl); err != nil && !tc.expectError { _ = result diff --git a/internal/testfinder/testfinder.go b/internal/testfinder/testfinder.go index fa56272d..058c46f5 100644 --- a/internal/testfinder/testfinder.go +++ b/internal/testfinder/testfinder.go @@ -29,7 +29,7 @@ type TestFinderOptions struct { FGAModelTuples []authz.TupleKey } -func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Finders { +func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Config { if opts.Clock == nil { opts.Clock = &clock.Real{} } @@ -71,7 +71,7 @@ func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Finders // Setup GBFS gbf := gbfsfinder.NewFinder(nil) - return model.Finders{ + return model.Config{ Finder: dbf, RTFinder: rtf, GbfsFinder: gbf, @@ -82,17 +82,17 @@ func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Finders } } -func Finders(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile) model.Finders { +func Finders(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile) model.Config { db := testutil.MustOpenTestDB() return newFinders(t, db, TestFinderOptions{Clock: cl, RTJsons: rtJsons}) } -func FindersWithOptions(t testing.TB, opts TestFinderOptions) model.Finders { +func FindersWithOptions(t testing.TB, opts TestFinderOptions) model.Config { db := testutil.MustOpenTestDB() return newFinders(t, db, opts) } -func FindersTx(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile, cb func(model.Finders) error) { +func FindersTx(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile, cb func(model.Config) error) { // Check open DB db := testutil.MustOpenTestDB() // Start Txn @@ -110,8 +110,8 @@ func FindersTx(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile, cb func(model } } -func FindersTxRollback(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile, cb func(model.Finders)) { - FindersTx(t, cl, rtJsons, func(c model.Finders) error { +func FindersTxRollback(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile, cb func(model.Config)) { + FindersTx(t, cl, rtJsons, func(c model.Config) error { cb(c) return errors.New("rollback") }) diff --git a/jobs/jobs.go b/jobs/jobs.go index de18834c..d56a04a5 100644 --- a/jobs/jobs.go +++ b/jobs/jobs.go @@ -44,7 +44,7 @@ func (job *Job) HexKey() (string, error) { // JobOptions is configuration passed to worker. type JobOptions struct { - Finders model.Finders + Finders model.Config JobQueue JobQueue Logger zerolog.Logger Secrets []tl.Secret diff --git a/model/config.go b/model/config.go new file mode 100644 index 00000000..6a11779b --- /dev/null +++ b/model/config.go @@ -0,0 +1,51 @@ +package model + +import ( + "context" + "net/http" + + "github.com/interline-io/transitland-lib/tl" + "github.com/interline-io/transitland-server/internal/clock" + "github.com/rs/zerolog" +) + +var finderCtxKey = &contextKey{"finderConfig"} + +type contextKey struct { + name string +} + +func ForContext(ctx context.Context) Config { + raw, ok := ctx.Value(finderCtxKey).(Config) + if !ok { + return Config{} + } + return raw +} + +func WithConfig(ctx context.Context, fs Config) context.Context { + r := context.WithValue(ctx, finderCtxKey, fs) + return r +} + +func AddConfig(te Config) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(WithConfig(r.Context(), te)) + next.ServeHTTP(w, r) + }) + } +} + +type Config struct { + Finder Finder + RTFinder RTFinder + GbfsFinder GbfsFinder + Checker Checker + Clock clock.Clock + Secrets []tl.Secret + ValidateLargeFiles bool + Storage string + RTStorage string + Logger zerolog.Logger +} diff --git a/model/finders.go b/model/finders.go index 80dcce29..dc6b409e 100644 --- a/model/finders.go +++ b/model/finders.go @@ -2,61 +2,16 @@ package model import ( "context" - "net/http" "time" "github.com/interline-io/transitland-lib/rt/pb" - "github.com/interline-io/transitland-lib/tl" "github.com/interline-io/transitland-lib/tl/tt" "github.com/interline-io/transitland-mw/auth/authz" - "github.com/interline-io/transitland-server/internal/clock" "github.com/interline-io/transitland-server/internal/gbfs" - "github.com/rs/zerolog" "github.com/jmoiron/sqlx" ) -var finderCtxKey = &contextKey{"finderConfig"} - -type contextKey struct { - name string -} - -func ForContext(ctx context.Context) Finders { - raw, ok := ctx.Value(finderCtxKey).(Finders) - if !ok { - return Finders{} - } - return raw -} - -func WithFinders(ctx context.Context, fs Finders) context.Context { - r := context.WithValue(ctx, finderCtxKey, fs) - return r -} - -func AddFinders(te Finders) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r = r.WithContext(WithFinders(r.Context(), te)) - next.ServeHTTP(w, r) - }) - } -} - -type Finders struct { - Finder Finder - RTFinder RTFinder - GbfsFinder GbfsFinder - Checker Checker - Clock clock.Clock - Secrets []tl.Secret - ValidateLargeFiles bool - Storage string - RTStorage string - Logger zerolog.Logger -} - // Finder provides all necessary database methods type Finder interface { PermFinder diff --git a/server/gql/agency_resolver.go b/server/gql/agency_resolver.go index 7923592d..766ab272 100644 --- a/server/gql/agency_resolver.go +++ b/server/gql/agency_resolver.go @@ -35,6 +35,6 @@ func (r *agencyResolver) Operator(ctx context.Context, obj *model.Agency) (*mode } func (r *agencyResolver) Alerts(ctx context.Context, obj *model.Agency, active *bool, limit *int) ([]*model.Alert, error) { - rtAlerts := r.frs.RTFinder.FindAlertsForAgency(obj, checkLimit(limit), active) + rtAlerts := model.ForContext(ctx).RTFinder.FindAlertsForAgency(obj, checkLimit(limit), active) return rtAlerts, nil } diff --git a/server/gql/gbfs_resolver.go b/server/gql/gbfs_resolver.go index 789380e4..cd5252a6 100644 --- a/server/gql/gbfs_resolver.go +++ b/server/gql/gbfs_resolver.go @@ -7,9 +7,9 @@ import ( ) func (r *queryResolver) Bikes(ctx context.Context, limit *int, where *model.GbfsBikeRequest) ([]*model.GbfsFreeBikeStatus, error) { - return r.frs.GbfsFinder.FindBikes(ctx, checkLimit(limit), where) + return model.ForContext(ctx).GbfsFinder.FindBikes(ctx, checkLimit(limit), where) } func (r *queryResolver) Docks(ctx context.Context, limit *int, where *model.GbfsDockRequest) ([]*model.GbfsStationInformation, error) { - return r.frs.GbfsFinder.FindDocks(ctx, checkLimit(limit), where) + return model.ForContext(ctx).GbfsFinder.FindDocks(ctx, checkLimit(limit), where) } diff --git a/server/gql/loaders.go b/server/gql/loaders.go index b9f6e5b9..56da6318 100644 --- a/server/gql/loaders.go +++ b/server/gql/loaders.go @@ -135,7 +135,7 @@ func NewLoaders(dbf model.Finder) *Loaders { return loaders } -func loaderMiddleware(frs model.Finders, next http.Handler) http.Handler { +func loaderMiddleware(frs model.Config, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // This is per request scoped loaders/cache // Is this OK to use as a long term cache? diff --git a/server/gql/mutation_resolver_test.go b/server/gql/mutation_resolver_test.go index 3745bb0e..90dd2bd8 100644 --- a/server/gql/mutation_resolver_test.go +++ b/server/gql/mutation_resolver_test.go @@ -26,9 +26,9 @@ func TestFeedVersionFetchResolver(t *testing.T) { w.Write(buf) })) t.Run("found sha1", func(t *testing.T) { - testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { + testfinder.FindersTxRollback(t, nil, nil, func(te model.Config) { srv, _ := NewServer(te) - srv = model.AddFinders(te)(srv) + srv = model.AddConfig(te)(srv) srv = ancheck.AdminDefaultMiddleware("test")(srv) // Run all requests as admin // Run all requests as admin c := client.New(srv) @@ -165,7 +165,7 @@ func TestValidateGtfsResolver(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { + testfinder.FindersTxRollback(t, nil, nil, func(te model.Config) { srv, _ := NewServer(te) srv = ancheck.UserDefaultMiddleware("test")(srv) // Run all requests as user c := client.New(srv) @@ -174,7 +174,7 @@ func TestValidateGtfsResolver(t *testing.T) { }) } t.Run("requires user access", func(t *testing.T) { - testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { + testfinder.FindersTxRollback(t, nil, nil, func(te model.Config) { srv, _ := NewServer(te) // all requests run as anonymous context by default c := client.New(srv) resp := make(map[string]interface{}) diff --git a/server/gql/query_resolver.go b/server/gql/query_resolver.go index d4fb9836..a84186b6 100644 --- a/server/gql/query_resolver.go +++ b/server/gql/query_resolver.go @@ -16,7 +16,7 @@ const MAX_RADIUS = 100_000 type queryResolver struct{ *Resolver } func (r *queryResolver) Me(ctx context.Context) (*model.Me, error) { - me, err := r.frs.Checker.Me(ctx, &authz.MeRequest{}) + me, err := model.ForContext(ctx).Checker.Me(ctx, &authz.MeRequest{}) if err != nil { return nil, err } @@ -43,7 +43,7 @@ func (r *queryResolver) Agencies(ctx context.Context, limit *int, after *int, id return nil, errors.New("bbox too large") } } - return r.frs.Finder.FindAgencies(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindAgencies(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Routes(ctx context.Context, limit *int, after *int, ids []int, where *model.RouteFilter) ([]*model.Route, error) { @@ -56,7 +56,7 @@ func (r *queryResolver) Routes(ctx context.Context, limit *int, after *int, ids return nil, errors.New("bbox too large") } } - return r.frs.Finder.FindRoutes(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindRoutes(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Stops(ctx context.Context, limit *int, after *int, ids []int, where *model.StopFilter) ([]*model.Stop, error) { @@ -69,12 +69,12 @@ func (r *queryResolver) Stops(ctx context.Context, limit *int, after *int, ids [ return nil, errors.New("bbox too large") } } - return r.frs.Finder.FindStops(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindStops(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Trips(ctx context.Context, limit *int, after *int, ids []int, where *model.TripFilter) ([]*model.Trip, error) { addMetric(ctx, "trips") - return r.frs.Finder.FindTrips(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindTrips(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) FeedVersions(ctx context.Context, limit *int, after *int, ids []int, where *model.FeedVersionFilter) ([]*model.FeedVersion, error) { @@ -87,7 +87,7 @@ func (r *queryResolver) FeedVersions(ctx context.Context, limit *int, after *int return nil, errors.New("bbox too large") } } - return r.frs.Finder.FindFeedVersions(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindFeedVersions(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Feeds(ctx context.Context, limit *int, after *int, ids []int, where *model.FeedFilter) ([]*model.Feed, error) { @@ -100,7 +100,7 @@ func (r *queryResolver) Feeds(ctx context.Context, limit *int, after *int, ids [ return nil, errors.New("bbox too large") } } - return r.frs.Finder.FindFeeds(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindFeeds(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Operators(ctx context.Context, limit *int, after *int, ids []int, where *model.OperatorFilter) ([]*model.Operator, error) { @@ -113,11 +113,11 @@ func (r *queryResolver) Operators(ctx context.Context, limit *int, after *int, i return nil, errors.New("bbox too large") } } - return r.frs.Finder.FindOperators(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindOperators(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Places(ctx context.Context, limit *int, after *int, level *model.PlaceAggregationLevel, where *model.PlaceFilter) ([]*model.Place, error) { - return r.frs.Finder.FindPlaces(ctx, checkLimit(limit), checkCursor(after), nil, level, where) + return model.ForContext(ctx).Finder.FindPlaces(ctx, checkLimit(limit), checkCursor(after), nil, level, where) } func addMetric(ctx context.Context, resolverName string) { diff --git a/server/gql/resolver.go b/server/gql/resolver.go index c47dee74..4015b40f 100644 --- a/server/gql/resolver.go +++ b/server/gql/resolver.go @@ -53,7 +53,6 @@ func atoi(v string) int { // Resolver . type Resolver struct { - frs model.Finders fvslCache *fvslCache } diff --git a/server/gql/resolver_test.go b/server/gql/resolver_test.go index b1ffc8bc..25d32a98 100644 --- a/server/gql/resolver_test.go +++ b/server/gql/resolver_test.go @@ -46,7 +46,7 @@ func TestMain(m *testing.M) { // Test helpers -func newTestClient(t testing.TB) (*client.Client, model.Finders) { +func newTestClient(t testing.TB) (*client.Client, model.Config) { when, err := time.Parse("2006-01-02T15:04:05", "2022-09-01T00:00:00") if err != nil { t.Fatal(err) @@ -54,7 +54,7 @@ func newTestClient(t testing.TB) (*client.Client, model.Finders) { return newTestClientWithClock(t, &clock.Mock{T: when}, testfinder.DefaultRTJson()) } -func newTestClientWithClock(t testing.TB, cl clock.Clock, rtfiles []testfinder.RTJsonFile) (*client.Client, model.Finders) { +func newTestClientWithClock(t testing.TB, cl clock.Clock, rtfiles []testfinder.RTJsonFile) (*client.Client, model.Config) { te := testfinder.Finders(t, cl, rtfiles) srv, _ := NewServer(te) srvMiddleware := ancheck.NewUserDefaultMiddleware(func() authn.User { return authn.NewCtxUser("testuser", "", "").WithRoles("testrole") }) diff --git a/server/gql/route_resolver.go b/server/gql/route_resolver.go index 5ed36ec8..c1077052 100644 --- a/server/gql/route_resolver.go +++ b/server/gql/route_resolver.go @@ -63,7 +63,7 @@ func (r *routeResolver) Headways(ctx context.Context, obj *model.Route, limit *i func (r *routeResolver) RouteStopBuffer(ctx context.Context, obj *model.Route, radius *float64) (*model.RouteStopBuffer, error) { // TODO: remove n+1 (which is tricky, what if multiple radius specified in different parts of query) p := model.RouteStopBufferParam{Radius: radius, EntityID: obj.ID} - ents, err := r.frs.Finder.RouteStopBuffer(ctx, &p) + ents, err := model.ForContext(ctx).Finder.RouteStopBuffer(ctx, &p) if err != nil { return nil, err } @@ -74,7 +74,7 @@ func (r *routeResolver) RouteStopBuffer(ctx context.Context, obj *model.Route, r } func (r *routeResolver) Alerts(ctx context.Context, obj *model.Route, active *bool, limit *int) ([]*model.Alert, error) { - return r.frs.RTFinder.FindAlertsForRoute(obj, checkLimit(limit), active), nil + return model.ForContext(ctx).RTFinder.FindAlertsForRoute(obj, checkLimit(limit), active), nil } func (r *routeResolver) Patterns(ctx context.Context, obj *model.Route) ([]*model.RouteStopPattern, error) { @@ -124,7 +124,7 @@ type routePatternResolver struct{ *Resolver } func (r *routePatternResolver) Trips(ctx context.Context, obj *model.RouteStopPattern, limit *int) ([]*model.Trip, error) { // TODO: N+1 query - trips, err := r.frs.Finder.FindTrips(ctx, checkLimit(limit), nil, nil, &model.TripFilter{StopPatternID: &obj.StopPatternID, RouteIds: []int{obj.RouteID}}) + trips, err := model.ForContext(ctx).Finder.FindTrips(ctx, checkLimit(limit), nil, nil, &model.TripFilter{StopPatternID: &obj.StopPatternID, RouteIds: []int{obj.RouteID}}) return trips, err } diff --git a/server/gql/server.go b/server/gql/server.go index f43ca776..8e71f14b 100644 --- a/server/gql/server.go +++ b/server/gql/server.go @@ -12,9 +12,8 @@ import ( "github.com/interline-io/transitland-server/model" ) -func NewServer(te model.Finders) (http.Handler, error) { +func NewServer(te model.Config) (http.Handler, error) { c := gqlout.Config{Resolvers: &Resolver{ - frs: te, fvslCache: newFvslCache(te.Finder), }} c.Directives.HasRole = func(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (interface{}, error) { @@ -27,5 +26,6 @@ func NewServer(te model.Finders) (http.Handler, error) { // Setup server srv := handler.NewDefaultServer(gqlout.NewExecutableSchema(c)) graphqlServer := loaderMiddleware(te, srv) + graphqlServer = model.AddConfig(te)(graphqlServer) return graphqlServer, nil } diff --git a/server/gql/stop_resolver.go b/server/gql/stop_resolver.go index 345a9833..e7c5bead 100644 --- a/server/gql/stop_resolver.go +++ b/server/gql/stop_resolver.go @@ -99,13 +99,13 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit if where != nil { // Convert where.Next into departure date and time window if where.Next != nil { - loc, ok := r.frs.RTFinder.StopTimezone(obj.ID, obj.StopTimezone) + loc, ok := model.ForContext(ctx).RTFinder.StopTimezone(obj.ID, obj.StopTimezone) if !ok { return nil, errors.New("timezone not available for stop") } serviceDate := time.Now().In(loc) - if r.frs.Clock != nil { - serviceDate = r.frs.Clock.Now().In(loc) + if model.ForContext(ctx).Clock != nil { + serviceDate = model.ForContext(ctx).Clock.Now().In(loc) } st, et := 0, 0 st = serviceDate.Hour()*3600 + serviceDate.Minute()*60 + serviceDate.Second() @@ -163,13 +163,13 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit for _, st := range sts { ft := model.Trip{} ft.FeedVersionID = obj.FeedVersionID - ft.TripID, _ = r.frs.RTFinder.GetGtfsTripID(atoi(st.TripID)) // TODO! - if ste, ok := r.frs.RTFinder.FindStopTimeUpdate(&ft, st); ok { + ft.TripID, _ = model.ForContext(ctx).RTFinder.GetGtfsTripID(atoi(st.TripID)) // TODO! + if ste, ok := model.ForContext(ctx).RTFinder.FindStopTimeUpdate(&ft, st); ok { st.RTStopTimeUpdate = ste } } // Handle added trips; these must specify stop_id in StopTimeUpdates - for _, rtTrip := range r.frs.RTFinder.GetAddedTripsForStop(obj) { + for _, rtTrip := range model.ForContext(ctx).RTFinder.GetAddedTripsForStop(obj) { for _, stu := range rtTrip.StopTimeUpdate { if stu.GetStopId() != obj.StopID { continue @@ -196,7 +196,7 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit } func (r *stopResolver) Alerts(ctx context.Context, obj *model.Stop, active *bool, limit *int) ([]*model.Alert, error) { - rtAlerts := r.frs.RTFinder.FindAlertsForStop(obj, checkLimit(limit), active) + rtAlerts := model.ForContext(ctx).RTFinder.FindAlertsForStop(obj, checkLimit(limit), active) return rtAlerts, nil } @@ -223,7 +223,7 @@ func (r *stopResolver) Directions(ctx context.Context, obj *model.Stop, from *mo func (r *stopResolver) NearbyStops(ctx context.Context, obj *model.Stop, limit *int, radius *float64) ([]*model.Stop, error) { c := obj.Coordinates() - nearbyStops, err := r.frs.Finder.FindStops(ctx, checkLimit(limit), nil, nil, &model.StopFilter{Near: &model.PointRadius{Lon: c[0], Lat: c[1], Radius: checkFloat(radius, 0, MAX_RADIUS)}}) + nearbyStops, err := model.ForContext(ctx).Finder.FindStops(ctx, checkLimit(limit), nil, nil, &model.StopFilter{Near: &model.PointRadius{Lon: c[0], Lat: c[1], Radius: checkFloat(radius, 0, MAX_RADIUS)}}) return nearbyStops, err } diff --git a/server/gql/stop_resolver_test.go b/server/gql/stop_resolver_test.go index 5be3f844..47c1202a 100644 --- a/server/gql/stop_resolver_test.go +++ b/server/gql/stop_resolver_test.go @@ -87,7 +87,7 @@ func BenchmarkStopResolver(b *testing.B) { benchmarkTestcases(b, c, stopResolverTestcases(b, te)) } -func stopResolverTestcases(t testing.TB, te model.Finders) []testcase { +func stopResolverTestcases(t testing.TB, te model.Config) []testcase { bartStops := []string{"12TH", "16TH", "19TH", "19TH_N", "24TH", "ANTC", "ASHB", "BALB", "BAYF", "CAST", "CIVC", "COLS", "COLM", "CONC", "DALY", "DBRK", "DUBL", "DELN", "PLZA", "EMBR", "FRMT", "FTVL", "GLEN", "HAYW", "LAFY", "LAKE", "MCAR", "MCAR_S", "MLBR", "MONT", "NBRK", "NCON", "OAKL", "ORIN", "PITT", "PCTR", "PHIL", "POWL", "RICH", "ROCK", "SBRN", "SFIA", "SANL", "SHAY", "SSAN", "UCTY", "WCRK", "WARM", "WDUB", "WOAK"} caltrainRailStops := []string{"70011", "70012", "70021", "70022", "70031", "70032", "70041", "70042", "70051", "70052", "70061", "70062", "70071", "70072", "70081", "70082", "70091", "70092", "70101", "70102", "70111", "70112", "70121", "70122", "70131", "70132", "70141", "70142", "70151", "70152", "70161", "70162", "70171", "70172", "70191", "70192", "70201", "70202", "70211", "70212", "70221", "70222", "70231", "70232", "70241", "70242", "70251", "70252", "70261", "70262", "70271", "70272", "70281", "70282", "70291", "70292", "70301", "70302", "70311", "70312", "70321", "70322"} caltrainBusStops := []string{"777402", "777403"} @@ -510,7 +510,7 @@ func stopResolverTestcases(t testing.TB, te model.Finders) []testcase { return testcases } -func stopResolverCursorTestcases(t *testing.T, te model.Finders) []testcase { +func stopResolverCursorTestcases(t *testing.T, te model.Config) []testcase { // First 1000 stops... dbf := te.Finder allEnts, err := dbf.FindStops(context.Background(), nil, nil, nil, nil) @@ -562,7 +562,7 @@ func stopResolverCursorTestcases(t *testing.T, te model.Finders) []testcase { return testcases } -func stopResolverPreviousOnestopIDTestcases(t testing.TB, te model.Finders) []testcase { +func stopResolverPreviousOnestopIDTestcases(t testing.TB, te model.Config) []testcase { testcases := []testcase{ { name: "default", @@ -596,7 +596,7 @@ func stopResolverPreviousOnestopIDTestcases(t testing.TB, te model.Finders) []te return testcases } -func stopResolverLicenseTestcases(t testing.TB, te model.Finders) []testcase { +func stopResolverLicenseTestcases(t testing.TB, te model.Config) []testcase { q := ` query ($lic: LicenseFilter) { stops(limit: 10000, where: {license: $lic}) { diff --git a/server/gql/stop_time_resolver.go b/server/gql/stop_time_resolver.go index 1dfa1de1..0f6f70c9 100644 --- a/server/gql/stop_time_resolver.go +++ b/server/gql/stop_time_resolver.go @@ -26,7 +26,7 @@ func (r *stopTimeResolver) Trip(ctx context.Context, obj *model.StopTime) (*mode t := model.Trip{} t.FeedVersionID = obj.FeedVersionID t.TripID = obj.RTTripID - a, err := r.frs.RTFinder.MakeTrip(&t) + a, err := model.ForContext(ctx).RTFinder.MakeTrip(&t) return a, err } return For(ctx).TripsByID.Load(ctx, atoi(obj.TripID))() @@ -34,7 +34,7 @@ func (r *stopTimeResolver) Trip(ctx context.Context, obj *model.StopTime) (*mode func (r *stopTimeResolver) Arrival(ctx context.Context, obj *model.StopTime) (*model.StopTimeEvent, error) { // lookup timezone - loc, ok := r.frs.RTFinder.StopTimezone(atoi(obj.StopID), "") + loc, ok := model.ForContext(ctx).RTFinder.StopTimezone(atoi(obj.StopID), "") if !ok { return nil, errors.New("timezone not available for stop") } @@ -54,7 +54,7 @@ func (r *stopTimeResolver) Arrival(ctx context.Context, obj *model.StopTime) (*m func (r *stopTimeResolver) Departure(ctx context.Context, obj *model.StopTime) (*model.StopTimeEvent, error) { // lookup timezone - loc, ok := r.frs.RTFinder.StopTimezone(atoi(obj.StopID), "") + loc, ok := model.ForContext(ctx).RTFinder.StopTimezone(atoi(obj.StopID), "") if !ok { return nil, errors.New("timezone not available for stop") } diff --git a/server/gql/trip_resolver.go b/server/gql/trip_resolver.go index 1fcd08f4..f6b27990 100644 --- a/server/gql/trip_resolver.go +++ b/server/gql/trip_resolver.go @@ -45,7 +45,7 @@ func (r *tripResolver) Frequencies(ctx context.Context, obj *model.Trip, limit * func (r *tripResolver) ScheduleRelationship(ctx context.Context, obj *model.Trip) (*model.ScheduleRelationship, error) { msr := model.ScheduleRelationshipScheduled - if rtt := r.frs.RTFinder.FindTrip(obj); rtt != nil { + if rtt := model.ForContext(ctx).RTFinder.FindTrip(obj); rtt != nil { sr := rtt.GetTrip().GetScheduleRelationship().String() switch sr { case "SCHEDULED": @@ -64,7 +64,7 @@ func (r *tripResolver) ScheduleRelationship(ctx context.Context, obj *model.Trip } func (r *tripResolver) Timestamp(ctx context.Context, obj *model.Trip) (*time.Time, error) { - if rtt := r.frs.RTFinder.FindTrip(obj); rtt != nil { + if rtt := model.ForContext(ctx).RTFinder.FindTrip(obj); rtt != nil { t := time.Unix(int64(rtt.GetTimestamp()), 0).In(time.UTC) return &t, nil } @@ -72,6 +72,6 @@ func (r *tripResolver) Timestamp(ctx context.Context, obj *model.Trip) (*time.Ti } func (r *tripResolver) Alerts(ctx context.Context, obj *model.Trip, active *bool, limit *int) ([]*model.Alert, error) { - rtAlerts := r.frs.RTFinder.FindAlertsForTrip(obj, checkLimit(limit), active) + rtAlerts := model.ForContext(ctx).RTFinder.FindAlertsForTrip(obj, checkLimit(limit), active) return rtAlerts, nil } diff --git a/server/rest/feed_version_download.go b/server/rest/feed_version_download.go index 16db1bba..fec9ec09 100644 --- a/server/rest/feed_version_download.go +++ b/server/rest/feed_version_download.go @@ -162,8 +162,8 @@ func feedVersionDownloadHandler(graphqlHandler http.Handler, w http.ResponseWrit apiMeter.Meter("feed-version-downloads", 1.0, dims) } - cfg := model.ForContext(r.Context()) - serveFromStorage(w, r, cfg.Storage, fvsha1) + frs := model.ForContext(r.Context()) + serveFromStorage(w, r, frs.Storage, fvsha1) } func serveFromStorage(w http.ResponseWriter, r *http.Request, storage string, fvsha1 string) { diff --git a/server/rest/rest_test.go b/server/rest/rest_test.go index b8c915df..01e8c1ee 100644 --- a/server/rest/rest_test.go +++ b/server/rest/rest_test.go @@ -29,7 +29,7 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func testRestConfig(t testing.TB) (http.Handler, model.Finders) { +func testRestConfig(t testing.TB) (http.Handler, model.Config) { when, err := time.Parse("2006-01-02T15:04:05", "2018-06-01T00:00:00") if err != nil { t.Fatal(err) @@ -61,7 +61,7 @@ type testRest struct { f func(*testing.T, string) } -func testquery(t *testing.T, graphqlHandler http.Handler, te model.Finders, tc testRest) { +func testquery(t *testing.T, graphqlHandler http.Handler, te model.Config, tc testRest) { data, err := makeRequest(context.TODO(), Config{}, graphqlHandler, tc.h, tc.format, nil) if err != nil { t.Error(err) diff --git a/server/server_cmd.go b/server/server_cmd.go index b55a670a..29e88191 100644 --- a/server/server_cmd.go +++ b/server/server_cmd.go @@ -274,7 +274,7 @@ func (cmd *Command) Run() error { } // GraphQL API - te := model.Finders{ + te := model.Config{ Finder: dbFinder, RTFinder: rtFinder, GbfsFinder: gbfsFinder, diff --git a/workers/gbfs_fetch_worker_test.go b/workers/gbfs_fetch_worker_test.go index 0bb75e45..5e48ab83 100644 --- a/workers/gbfs_fetch_worker_test.go +++ b/workers/gbfs_fetch_worker_test.go @@ -18,7 +18,7 @@ func TestGbfsFetchWorker(t *testing.T) { ts := httptest.NewServer(&gbfs.TestGbfsServer{Language: "en", Path: testutil.RelPath("test/data/gbfs")}) defer ts.Close() - testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { + testfinder.FindersTxRollback(t, nil, nil, func(te model.Config) { job := jobs.Job{} job.Opts.Finders = te w := GbfsFetchWorker{ From 8c94948b39b60bfac7c8191bf32eb500ecf3669c Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 17:03:10 -0800 Subject: [PATCH 06/17] More test config refactoring --- actions/fetch.go | 28 ++-- actions/fetch_test.go | 4 +- actions/fv.go | 24 ++-- actions/validate.go | 4 +- .../testconfig.go} | 120 +++++++++--------- model/config.go | 34 ++--- server/gql/agency_resolver_test.go | 7 +- server/gql/fvsl_cache_test.go | 4 +- server/gql/loaders.go | 4 +- server/gql/mutation_resolver_test.go | 10 +- server/gql/resolver_test.go | 24 ++-- server/gql/rt_test.go | 47 ++++--- server/gql/server.go | 8 +- server/gql/stop_time_resolver_test.go | 13 +- server/rest/feed_version_download.go | 8 +- server/rest/rest_test.go | 17 ++- server/server_cmd.go | 27 ++-- workers/gbfs_fetch_worker_test.go | 4 +- 18 files changed, 194 insertions(+), 193 deletions(-) rename internal/{testfinder/testfinder.go => testconfig/testconfig.go} (72%) diff --git a/actions/fetch.go b/actions/fetch.go index dfc4287e..3d229ab9 100644 --- a/actions/fetch.go +++ b/actions/fetch.go @@ -25,8 +25,8 @@ import ( ) func StaticFetch(ctx context.Context, feedId string, feedSrc io.Reader, feedUrl string) (*model.FeedVersionFetchResult, error) { - frs := model.ForContext(ctx) - dbf := frs.Finder + cfg := model.ForContext(ctx) + dbf := cfg.Finder urlType := "static_current" feed, err := fetchCheckFeed(ctx, feedId, urlType, feedUrl) @@ -42,8 +42,8 @@ func StaticFetch(ctx context.Context, feedId string, feedSrc io.Reader, feedUrl FeedID: feed.ID, URLType: urlType, FeedURL: feedUrl, - Storage: frs.Storage, - Secrets: frs.Secrets, + Storage: cfg.Storage, + Secrets: cfg.Secrets, FetchedAt: time.Now().In(time.UTC), AllowFTPFetch: true, } @@ -89,7 +89,7 @@ func StaticFetch(ctx context.Context, feedId string, feedSrc io.Reader, feedUrl } func RTFetch(ctx context.Context, target string, feedId string, feedUrl string, urlType string) error { - frs := model.ForContext(ctx) + cfg := model.ForContext(ctx) feed, err := fetchCheckFeed(ctx, feedId, urlType, feedUrl) if err != nil { @@ -104,15 +104,15 @@ func RTFetch(ctx context.Context, target string, feedId string, feedUrl string, FeedID: feed.ID, URLType: urlType, FeedURL: feedUrl, - Storage: frs.RTStorage, - Secrets: frs.Secrets, + Storage: cfg.RTStorage, + Secrets: cfg.Secrets, FetchedAt: time.Now().In(time.UTC), } // Make request var rtMsg *pb.FeedMessage var fetchErr error - if err := tldb.NewPostgresAdapterFromDBX(frs.Finder.DBX()).Tx(func(atx tldb.Adapter) error { + if err := tldb.NewPostgresAdapterFromDBX(cfg.Finder.DBX()).Tx(func(atx tldb.Adapter) error { m, fr, err := fetch.RTFetch(atx, fetchOpts) if err != nil { return err @@ -133,7 +133,7 @@ func RTFetch(ctx context.Context, target string, feedId string, feedUrl string, return errors.New("invalid rt data") } key := fmt.Sprintf("rtdata:%s:%s", target, urlType) - return frs.RTFinder.AddData(key, rtdata) + return cfg.RTFinder.AddData(key, rtdata) } type CheckFetchWaitResult struct { @@ -220,7 +220,7 @@ func CheckFetchWaitBatch(ctx context.Context, db sqlx.Ext, feedIds []int, urlTyp return nil, err } for _, fetch := range lastFetches { - a, _ := checks[fetch.ID] + a := checks[fetch.ID] a.CheckedAt = now a.ID = fetch.ID a.OnestopID = fetch.OnestopID @@ -255,14 +255,14 @@ func chunkBy[T any](items []T, chunkSize int) (chunks [][]T) { } func fetchCheckFeed(ctx context.Context, feedId string, urlType string, url string) (*model.Feed, error) { - frs := model.ForContext(ctx) - if frs.Finder == nil { + cfg := model.ForContext(ctx) + if cfg.Finder == nil { panic("no finder") } - checker := frs.Checker + checker := cfg.Checker // Check feed exists - feeds, err := frs.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &feedId}) + feeds, err := cfg.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &feedId}) if err != nil { return nil, err } diff --git a/actions/fetch_test.go b/actions/fetch_test.go index df6ded09..23360242 100644 --- a/actions/fetch_test.go +++ b/actions/fetch_test.go @@ -11,7 +11,7 @@ import ( "github.com/interline-io/transitland-lib/dmfr" "github.com/interline-io/transitland-server/internal/dbutil" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" @@ -110,7 +110,7 @@ func TestStaticFetchWorker(t *testing.T) { // Setup job feedUrl := ts.URL + "/" + tc.serveFile - testfinder.FindersTxRollback(t, nil, nil, func(te model.Config) { + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(te model.Config) { ctx := model.WithConfig(context.Background(), te) // Run job if result, err := StaticFetch(ctx, tc.feedId, nil, feedUrl); err != nil && !tc.expectError { diff --git a/actions/fv.go b/actions/fv.go index 487ae034..a0db32bd 100644 --- a/actions/fv.go +++ b/actions/fv.go @@ -14,9 +14,9 @@ import ( ) func FeedVersionImport(ctx context.Context, fvid int) (*model.FeedVersionImportResult, error) { - frs := model.ForContext(ctx) - checker := frs.Checker - dbf := frs.Finder + cfg := model.ForContext(ctx) + checker := cfg.Checker + dbf := cfg.Finder if checker == nil { return nil, authz.ErrUnauthorized } @@ -27,7 +27,7 @@ func FeedVersionImport(ctx context.Context, fvid int) (*model.FeedVersionImportR } opts := importer.Options{ FeedVersionID: fvid, - Storage: frs.Storage, + Storage: cfg.Storage, } db := tldb.NewPostgresAdapterFromDBX(dbf.DBX()) fr, fe := importer.MainImportFeedVersion(db, opts) @@ -41,9 +41,9 @@ func FeedVersionImport(ctx context.Context, fvid int) (*model.FeedVersionImportR } func FeedVersionUnimport(ctx context.Context, fvid int) (*model.FeedVersionUnimportResult, error) { - frs := model.ForContext(ctx) - checker := frs.Checker - dbf := frs.Finder + cfg := model.ForContext(ctx) + checker := cfg.Checker + dbf := cfg.Finder if checker == nil { return nil, authz.ErrUnauthorized } @@ -65,9 +65,9 @@ func FeedVersionUnimport(ctx context.Context, fvid int) (*model.FeedVersionUnimp } func FeedVersionUpdate(ctx context.Context, fvid int, values model.FeedVersionSetInput) error { - frs := model.ForContext(ctx) - checker := frs.Checker - dbf := frs.Finder + cfg := model.ForContext(ctx) + checker := cfg.Checker + dbf := cfg.Finder if checker == nil { return authz.ErrUnauthorized } @@ -102,8 +102,8 @@ func FeedVersionUpdate(ctx context.Context, fvid int, values model.FeedVersionSe } func FeedVersionDelete(ctx context.Context, fvid int) (*model.FeedVersionDeleteResult, error) { - frs := model.ForContext(ctx) - checker := frs.Checker + cfg := model.ForContext(ctx) + checker := cfg.Checker if checker == nil { return nil, authz.ErrUnauthorized } diff --git a/actions/validate.go b/actions/validate.go index fae0692c..b8eba9e1 100644 --- a/actions/validate.go +++ b/actions/validate.go @@ -21,7 +21,7 @@ type hasGeometries interface { // ValidateUpload takes a file Reader and produces a validation package containing errors, warnings, file infos, service levels, etc. func ValidateUpload(ctx context.Context, src io.Reader, feedURL *string, rturls []string) (*model.ValidationResult, error) { - frs := model.ForContext(ctx) + cfg := model.ForContext(ctx) // Check inputs rturlsok := []string{} @@ -87,7 +87,7 @@ func ValidateUpload(ctx context.Context, src io.Reader, feedURL *string, rturls MaxRTMessageSize: 10_000_000, ValidateRealtimeMessages: rturls, } - if frs.ValidateLargeFiles { + if cfg.ValidateLargeFiles { opts.CheckFileLimits = false } diff --git a/internal/testfinder/testfinder.go b/internal/testconfig/testconfig.go similarity index 72% rename from internal/testfinder/testfinder.go rename to internal/testconfig/testconfig.go index 058c46f5..0a585252 100644 --- a/internal/testfinder/testfinder.go +++ b/internal/testconfig/testconfig.go @@ -1,4 +1,4 @@ -package testfinder +package testconfig import ( "context" @@ -6,6 +6,7 @@ import ( "fmt" "os" "testing" + "time" "github.com/interline-io/transitland-lib/rt" "github.com/interline-io/transitland-mw/auth/authz" @@ -22,17 +23,69 @@ import ( // Test helpers -type TestFinderOptions struct { - Clock clock.Clock +type Options struct { + When string RTJsons []RTJsonFile FGAModelFile string FGAModelTuples []authz.TupleKey } -func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Config { - if opts.Clock == nil { - opts.Clock = &clock.Real{} +func Config(t testing.TB, opts Options) model.Config { + db := testutil.MustOpenTestDB() + return newTestConfig(t, db, opts) +} + +func ConfigTx(t testing.TB, opts Options, cb func(model.Config) error) { + // Check open DB + db := testutil.MustOpenTestDB() + + // Start Txn + tx := db.MustBeginTx(context.Background(), nil) + defer tx.Rollback() + + // Get finders + testEnv := newTestConfig(t, tx, opts) + + // Commit or rollback + if err := cb(testEnv); err != nil { + //tx.Rollback() + } else { + tx.Commit() + } +} + +func ConfigTxRollback(t testing.TB, opts Options, cb func(model.Config)) { + ConfigTx(t, opts, func(c model.Config) error { + cb(c) + return errors.New("rollback") + }) +} + +type RTJsonFile struct { + Feed string + Ftype string + Fname string +} + +func DefaultRTJson() []RTJsonFile { + return []RTJsonFile{ + {"BA", "realtime_trip_updates", "BA.json"}, + {"BA", "realtime_alerts", "BA-alerts.json"}, + {"CT", "realtime_trip_updates", "CT.json"}, + } +} + +func newTestConfig(t testing.TB, db sqlx.Ext, opts Options) model.Config { + // Default time + if opts.When == "" { + opts.When = "2022-09-01T00:00:00" + } + + when, err := time.Parse("2006-01-02T15:04:05", opts.When) + if err != nil { + t.Fatal(err) } + cl := &clock.Mock{T: when} // Setup Checker checkerCfg := azcheck.CheckerConfig{ @@ -47,11 +100,11 @@ func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Config // Setup DB dbf := dbfinder.NewFinder(db, checker) - dbf.Clock = opts.Clock + dbf.Clock = cl // Setup RT rtf := rtfinder.NewFinder(rtfinder.NewLocalCache(), db) - rtf.Clock = opts.Clock + rtf.Clock = cl for _, rtj := range opts.RTJsons { fn := testutil.RelPath("test", "data", "rt", rtj.Fname) msg, err := rt.ReadFile(fn) @@ -76,57 +129,8 @@ func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Config RTFinder: rtf, GbfsFinder: gbf, Checker: checker, - Clock: opts.Clock, + Clock: cl, Storage: t.TempDir(), RTStorage: t.TempDir(), } } - -func Finders(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile) model.Config { - db := testutil.MustOpenTestDB() - return newFinders(t, db, TestFinderOptions{Clock: cl, RTJsons: rtJsons}) -} - -func FindersWithOptions(t testing.TB, opts TestFinderOptions) model.Config { - db := testutil.MustOpenTestDB() - return newFinders(t, db, opts) -} - -func FindersTx(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile, cb func(model.Config) error) { - // Check open DB - db := testutil.MustOpenTestDB() - // Start Txn - tx := db.MustBeginTx(context.Background(), nil) - defer tx.Rollback() - - // Get finders - testEnv := newFinders(t, tx, TestFinderOptions{Clock: cl, RTJsons: rtJsons}) - - // Commit or rollback - if err := cb(testEnv); err != nil { - //tx.Rollback() - } else { - tx.Commit() - } -} - -func FindersTxRollback(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile, cb func(model.Config)) { - FindersTx(t, cl, rtJsons, func(c model.Config) error { - cb(c) - return errors.New("rollback") - }) -} - -type RTJsonFile struct { - Feed string - Ftype string - Fname string -} - -func DefaultRTJson() []RTJsonFile { - return []RTJsonFile{ - {"BA", "realtime_trip_updates", "BA.json"}, - {"BA", "realtime_alerts", "BA-alerts.json"}, - {"CT", "realtime_trip_updates", "CT.json"}, - } -} diff --git a/model/config.go b/model/config.go index 6a11779b..c1b47eee 100644 --- a/model/config.go +++ b/model/config.go @@ -9,6 +9,19 @@ import ( "github.com/rs/zerolog" ) +type Config struct { + Finder Finder + RTFinder RTFinder + GbfsFinder GbfsFinder + Checker Checker + Clock clock.Clock + Secrets []tl.Secret + ValidateLargeFiles bool + Storage string + RTStorage string + Logger zerolog.Logger +} + var finderCtxKey = &contextKey{"finderConfig"} type contextKey struct { @@ -23,29 +36,16 @@ func ForContext(ctx context.Context) Config { return raw } -func WithConfig(ctx context.Context, fs Config) context.Context { - r := context.WithValue(ctx, finderCtxKey, fs) +func WithConfig(ctx context.Context, cfg Config) context.Context { + r := context.WithValue(ctx, finderCtxKey, cfg) return r } -func AddConfig(te Config) func(http.Handler) http.Handler { +func AddConfig(cfg Config) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r = r.WithContext(WithConfig(r.Context(), te)) + r = r.WithContext(WithConfig(r.Context(), cfg)) next.ServeHTTP(w, r) }) } } - -type Config struct { - Finder Finder - RTFinder RTFinder - GbfsFinder GbfsFinder - Checker Checker - Clock clock.Clock - Secrets []tl.Secret - ValidateLargeFiles bool - Storage string - RTStorage string - Logger zerolog.Logger -} diff --git a/server/gql/agency_resolver_test.go b/server/gql/agency_resolver_test.go index 36018c4a..01e9f9e5 100644 --- a/server/gql/agency_resolver_test.go +++ b/server/gql/agency_resolver_test.go @@ -7,7 +7,7 @@ import ( "github.com/99designs/gqlgen/client" "github.com/interline-io/transitland-mw/auth/ancheck" "github.com/interline-io/transitland-mw/auth/authz" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" ) @@ -310,11 +310,10 @@ func TestAgencyResolver_Authz(t *testing.T) { t.Skip(a) return } - teOpts := testfinder.TestFinderOptions{ + te := testconfig.Config(t, testconfig.Options{ FGAModelFile: testutil.RelPath("test/authz/tls.json"), FGAModelTuples: fgaTestTuples, - } - te := testfinder.FindersWithOptions(t, teOpts) + }) srv, _ := NewServer(te) testcases := []testcase{ { diff --git a/server/gql/fvsl_cache_test.go b/server/gql/fvsl_cache_test.go index 716b470a..6220c0ec 100644 --- a/server/gql/fvsl_cache_test.go +++ b/server/gql/fvsl_cache_test.go @@ -3,11 +3,11 @@ package gql import ( "testing" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" ) func TestFvslCache(t *testing.T) { - te := testfinder.Finders(t, nil, nil) + te := testconfig.Config(t, testconfig.Options{}) c := newFvslCache(te.Finder) c.Get(1) } diff --git a/server/gql/loaders.go b/server/gql/loaders.go index 56da6318..60e4e8e8 100644 --- a/server/gql/loaders.go +++ b/server/gql/loaders.go @@ -135,11 +135,11 @@ func NewLoaders(dbf model.Finder) *Loaders { return loaders } -func loaderMiddleware(frs model.Config, next http.Handler) http.Handler { +func loaderMiddleware(cfg model.Config, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // This is per request scoped loaders/cache // Is this OK to use as a long term cache? - loaders := NewLoaders(frs.Finder) + loaders := NewLoaders(cfg.Finder) nextCtx := context.WithValue(r.Context(), loadersKey, loaders) r = r.WithContext(nextCtx) next.ServeHTTP(w, r) diff --git a/server/gql/mutation_resolver_test.go b/server/gql/mutation_resolver_test.go index 90dd2bd8..25fd4cd5 100644 --- a/server/gql/mutation_resolver_test.go +++ b/server/gql/mutation_resolver_test.go @@ -10,7 +10,7 @@ import ( "github.com/99designs/gqlgen/client" "github.com/interline-io/transitland-mw/auth/ancheck" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" @@ -26,7 +26,7 @@ func TestFeedVersionFetchResolver(t *testing.T) { w.Write(buf) })) t.Run("found sha1", func(t *testing.T) { - testfinder.FindersTxRollback(t, nil, nil, func(te model.Config) { + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(te model.Config) { srv, _ := NewServer(te) srv = model.AddConfig(te)(srv) srv = ancheck.AdminDefaultMiddleware("test")(srv) // Run all requests as admin @@ -41,7 +41,7 @@ func TestFeedVersionFetchResolver(t *testing.T) { }) }) // t.Run("requires admin access", func(t *testing.T) { - // testfinder.FindersTxRollback(t, nil, nil, func(te testfinder.TestEnv) { + // testconfig.ConfigTxRollback(t, nil, nil, func(te testconfig.TestEnv) { // srv, _ := NewServer(te.Config, te.Finder, nil, nil, nil) // srv = authn.UserDefaultMiddleware("test")(srv) // Run all requests as regular user // c := client.New(srv) @@ -165,7 +165,7 @@ func TestValidateGtfsResolver(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testfinder.FindersTxRollback(t, nil, nil, func(te model.Config) { + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(te model.Config) { srv, _ := NewServer(te) srv = ancheck.UserDefaultMiddleware("test")(srv) // Run all requests as user c := client.New(srv) @@ -174,7 +174,7 @@ func TestValidateGtfsResolver(t *testing.T) { }) } t.Run("requires user access", func(t *testing.T) { - testfinder.FindersTxRollback(t, nil, nil, func(te model.Config) { + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(te model.Config) { srv, _ := NewServer(te) // all requests run as anonymous context by default c := client.New(srv) resp := make(map[string]interface{}) diff --git a/server/gql/resolver_test.go b/server/gql/resolver_test.go index 25d32a98..95423931 100644 --- a/server/gql/resolver_test.go +++ b/server/gql/resolver_test.go @@ -5,13 +5,11 @@ import ( "log" "os" "testing" - "time" "github.com/99designs/gqlgen/client" "github.com/interline-io/transitland-mw/auth/ancheck" "github.com/interline-io/transitland-mw/auth/authn" - "github.com/interline-io/transitland-server/internal/clock" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" @@ -47,18 +45,20 @@ func TestMain(m *testing.M) { // Test helpers func newTestClient(t testing.TB) (*client.Client, model.Config) { - when, err := time.Parse("2006-01-02T15:04:05", "2022-09-01T00:00:00") - if err != nil { - t.Fatal(err) - } - return newTestClientWithClock(t, &clock.Mock{T: when}, testfinder.DefaultRTJson()) + return newTestClientWithOpts(t, testconfig.Options{ + When: "2022-09-01T00:00:00", + RTJsons: testconfig.DefaultRTJson(), + }) } -func newTestClientWithClock(t testing.TB, cl clock.Clock, rtfiles []testfinder.RTJsonFile) (*client.Client, model.Config) { - te := testfinder.Finders(t, cl, rtfiles) +func newTestClientWithOpts(t testing.TB, opts testconfig.Options) (*client.Client, model.Config) { + te := testconfig.Config(t, opts) srv, _ := NewServer(te) - srvMiddleware := ancheck.NewUserDefaultMiddleware(func() authn.User { return authn.NewCtxUser("testuser", "", "").WithRoles("testrole") }) - return client.New(srvMiddleware(srv)), te + graphqlServer := model.AddConfig(te)(srv) + srvMiddleware := ancheck.NewUserDefaultMiddleware(func() authn.User { + return authn.NewCtxUser("testuser", "", "").WithRoles("testrole") + }) + return client.New(srvMiddleware(graphqlServer)), te } func toJson(m map[string]interface{}) string { diff --git a/server/gql/rt_test.go b/server/gql/rt_test.go index 2ca6181c..165ae730 100644 --- a/server/gql/rt_test.go +++ b/server/gql/rt_test.go @@ -2,12 +2,10 @@ package gql import ( "testing" - "time" "github.com/99designs/gqlgen/client" "github.com/interline-io/transitland-lib/tl/tt" - "github.com/interline-io/transitland-server/internal/clock" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) @@ -113,17 +111,16 @@ type rtTestCase struct { name string query string vars map[string]interface{} - rtfiles []testfinder.RTJsonFile + rtfiles []testconfig.RTJsonFile cb func(t *testing.T, jj string) } func testRt(t *testing.T, tc rtTestCase) { // Create a new RT Finder for each test... - when, err := time.Parse("2006-01-02T15:04:05", "2022-09-01T00:00:00") - if err != nil { - t.Fatal(err) - } - c, _ := newTestClientWithClock(t, &clock.Mock{T: when}, tc.rtfiles) + c, _ := newTestClientWithOpts(t, testconfig.Options{ + When: "2022-09-01T00:00:00", + RTJsons: tc.rtfiles, + }) var resp map[string]interface{} opts := []client.Option{} for k, v := range tc.vars { @@ -144,7 +141,7 @@ func TestStopRTBasic(t *testing.T) { "stop times basic", baseStopQuery, newBaseStopVars(), - testfinder.DefaultRTJson(), + testconfig.DefaultRTJson(), func(t *testing.T, jj string) { // A little more explicit version of the string check test a := gjson.Get(jj, "stops.0.stop_times").Array() @@ -184,7 +181,7 @@ func TestStopRTBasic_ArrivalFallback(t *testing.T) { "arrival will use departure if arrival is not present", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-arrival-fallback.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-arrival-fallback.json"}}, func(t *testing.T, jj string) { a := gjson.Get(jj, "stops.0.stop_times").Array() checkTrip := "1031527WKDY" @@ -209,7 +206,7 @@ func TestStopRTBasic_DepartureFallback(t *testing.T) { "departure will use arrival if departure is not present", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-departure-fallback.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-departure-fallback.json"}}, func(t *testing.T, jj string) { a := gjson.Get(jj, "stops.0.stop_times").Array() checkTrip := "1031527WKDY" @@ -234,7 +231,7 @@ func TestStopRTBasic_StopIDFallback(t *testing.T) { "use stop_id as fallback if no matching stop sequence", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-stop-id-fallback.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-stop-id-fallback.json"}}, func(t *testing.T, jj string) { a := gjson.Get(jj, "stops.0.stop_times").Array() checkTrip := "1031527WKDY" @@ -260,7 +257,7 @@ func TestStopRTBasic_StopIDFallback_NoDoubleVisit(t *testing.T) { "do not use stop_id as fallback if stop is visited twice", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-stop-double-visit.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-stop-double-visit.json"}}, func(t *testing.T, jj string) { a := gjson.Get(jj, "stops.0.stop_times").Array() checkTrip := "1031527WKDY" @@ -285,7 +282,7 @@ func TestStopRTBasic_NoRT(t *testing.T) { "no rt matches for trip 2211533WKDY", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-departure-fallback.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-departure-fallback.json"}}, func(t *testing.T, jj string) { a := gjson.Get(jj, "stops.0.stop_times").Array() checkTrip := "2211533WKDY" @@ -313,7 +310,7 @@ func TestStopRTAddedTrip(t *testing.T) { "stop times added trip", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-added.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-added.json"}}, func(t *testing.T, jj string) { checkTrip := "-123" found := false @@ -346,7 +343,7 @@ func TestStopRTCanceledTrip(t *testing.T) { "stop times canceled trip", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-added.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-added.json"}}, func(t *testing.T, jj string) { checkTrip := "2211533WKDY" found := false @@ -381,7 +378,7 @@ func TestTripAlerts(t *testing.T) { "trip alerts", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -407,7 +404,7 @@ func TestTripAlerts(t *testing.T) { "trip alerts active", baseStopQuery, activeVars, - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -447,7 +444,7 @@ func TestRouteAlerts(t *testing.T) { "stop alerts active", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -474,7 +471,7 @@ func TestRouteAlerts(t *testing.T) { "stop alerts active", baseStopQuery, activeVars, - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -516,7 +513,7 @@ func TestStopAlerts(t *testing.T) { "stop alerts", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -530,7 +527,7 @@ func TestStopAlerts(t *testing.T) { "stop alerts active", baseStopQuery, activeVars, - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -559,7 +556,7 @@ func TestAgencyAlerts(t *testing.T) { "stop alerts", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -586,7 +583,7 @@ func TestAgencyAlerts(t *testing.T) { "stop alerts active", baseStopQuery, activeVars, - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}}, func(t *testing.T, jj string) { checkTrip := "1031527WKDY" sts := gjson.Get(jj, "stops.0.stop_times").Array() diff --git a/server/gql/server.go b/server/gql/server.go index 8e71f14b..a18ef9c0 100644 --- a/server/gql/server.go +++ b/server/gql/server.go @@ -12,9 +12,9 @@ import ( "github.com/interline-io/transitland-server/model" ) -func NewServer(te model.Config) (http.Handler, error) { +func NewServer(cfg model.Config) (http.Handler, error) { c := gqlout.Config{Resolvers: &Resolver{ - fvslCache: newFvslCache(te.Finder), + fvslCache: newFvslCache(cfg.Finder), }} c.Directives.HasRole = func(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (interface{}, error) { user := authn.ForContext(ctx) @@ -25,7 +25,7 @@ func NewServer(te model.Config) (http.Handler, error) { } // Setup server srv := handler.NewDefaultServer(gqlout.NewExecutableSchema(c)) - graphqlServer := loaderMiddleware(te, srv) - graphqlServer = model.AddConfig(te)(graphqlServer) + graphqlServer := loaderMiddleware(cfg, srv) + // graphqlServer = model.AddConfig(cfg)(graphqlServer) return graphqlServer, nil } diff --git a/server/gql/stop_time_resolver_test.go b/server/gql/stop_time_resolver_test.go index 813f2d03..a5f621ec 100644 --- a/server/gql/stop_time_resolver_test.go +++ b/server/gql/stop_time_resolver_test.go @@ -2,10 +2,8 @@ package gql import ( "testing" - "time" - "github.com/interline-io/transitland-server/internal/clock" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" ) func TestStopResolver_StopTimes(t *testing.T) { @@ -363,11 +361,10 @@ func TestStopResolver_StopTimes_Next(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // 2018-05-28 22:00:00 +0000 UTC // 2018-05-28 15:00:00 -0700 PDT - when, err := time.Parse("2006-01-02T15:04:05", tc.when) - if err != nil { - t.Fatal(err) - } - c, _ := newTestClientWithClock(t, &clock.Mock{T: when}, testfinder.DefaultRTJson()) + c, _ := newTestClientWithOpts(t, testconfig.Options{ + When: tc.when, + RTJsons: testconfig.DefaultRTJson(), + }) queryTestcase(t, c, tc.testcase) }) } diff --git a/server/rest/feed_version_download.go b/server/rest/feed_version_download.go index fec9ec09..3c12e7c6 100644 --- a/server/rest/feed_version_download.go +++ b/server/rest/feed_version_download.go @@ -85,8 +85,8 @@ func feedVersionDownloadLatestHandler(graphqlHandler http.Handler, w http.Respon apiMeter.Meter("feed-version-downloads", 1.0, dims) } - frs := model.ForContext(r.Context()) - serveFromStorage(w, r, frs.Storage, fvsha1) + cfg := model.ForContext(r.Context()) + serveFromStorage(w, r, cfg.Storage, fvsha1) } const feedVersionFileQuery = ` @@ -162,8 +162,8 @@ func feedVersionDownloadHandler(graphqlHandler http.Handler, w http.ResponseWrit apiMeter.Meter("feed-version-downloads", 1.0, dims) } - frs := model.ForContext(r.Context()) - serveFromStorage(w, r, frs.Storage, fvsha1) + cfg := model.ForContext(r.Context()) + serveFromStorage(w, r, cfg.Storage, fvsha1) } func serveFromStorage(w http.ResponseWriter, r *http.Request, storage string, fvsha1 string) { diff --git a/server/rest/rest_test.go b/server/rest/rest_test.go index 01e8c1ee..3dfaeb0b 100644 --- a/server/rest/rest_test.go +++ b/server/rest/rest_test.go @@ -7,10 +7,8 @@ import ( "net/http" "os" "testing" - "time" - "github.com/interline-io/transitland-server/internal/clock" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" "github.com/interline-io/transitland-server/model" "github.com/interline-io/transitland-server/server/gql" @@ -30,11 +28,12 @@ func TestMain(m *testing.M) { } func testRestConfig(t testing.TB) (http.Handler, model.Config) { - when, err := time.Parse("2006-01-02T15:04:05", "2018-06-01T00:00:00") - if err != nil { - t.Fatal(err) - } - te := testfinder.Finders(t, &clock.Mock{T: when}, testfinder.DefaultRTJson()) + te := testconfig.Config(t, + testconfig.Options{ + When: "2018-06-01T00:00:00", + RTJsons: testconfig.DefaultRTJson(), + }, + ) srv, err := gql.NewServer(te) if err != nil { panic(err) @@ -61,7 +60,7 @@ type testRest struct { f func(*testing.T, string) } -func testquery(t *testing.T, graphqlHandler http.Handler, te model.Config, tc testRest) { +func testquery(t *testing.T, graphqlHandler http.Handler, cfg model.Config, tc testRest) { data, err := makeRequest(context.TODO(), Config{}, graphqlHandler, tc.h, tc.format, nil) if err != nil { t.Error(err) diff --git a/server/server_cmd.go b/server/server_cmd.go index 29e88191..e4f98f38 100644 --- a/server/server_cmd.go +++ b/server/server_cmd.go @@ -223,6 +223,18 @@ func (cmd *Command) Run() error { jobQueue = jobs.NewLocalJobs() } + // Setup config + cfg := model.Config{ + Finder: dbFinder, + RTFinder: rtFinder, + GbfsFinder: gbfsFinder, + Checker: checker, + Secrets: cmd.secrets, + Storage: cmd.Storage, + RTStorage: cmd.RTStorage, + ValidateLargeFiles: cmd.ValidateLargeFiles, + } + // Setup metrics metricProvider, err := metrics.GetProvider(cmd.metricsConfig) if err != nil { @@ -247,6 +259,7 @@ func (cmd *Command) Run() error { AllowedHeaders: []string{"content-type", "apikey", "authorization"}, AllowCredentials: true, })) + root.Use(model.AddConfig(cfg)) // Setup user middleware for _, k := range cmd.AuthMiddlewares { @@ -274,16 +287,8 @@ func (cmd *Command) Run() error { } // GraphQL API - te := model.Config{ - Finder: dbFinder, - RTFinder: rtFinder, - GbfsFinder: gbfsFinder, - Checker: checker, - Secrets: cmd.secrets, - Storage: cmd.Storage, - RTStorage: cmd.RTStorage, - } - graphqlServer, err := gql.NewServer(te) + + graphqlServer, err := gql.NewServer(cfg) if err != nil { return err } @@ -333,7 +338,7 @@ func (cmd *Command) Run() error { // Start workers/api jobWorkers := 8 jobOptions := jobs.JobOptions{ - Finders: te, + Finders: cfg, Logger: log.Logger, JobQueue: jobQueue, } diff --git a/workers/gbfs_fetch_worker_test.go b/workers/gbfs_fetch_worker_test.go index 5e48ab83..d8d476e5 100644 --- a/workers/gbfs_fetch_worker_test.go +++ b/workers/gbfs_fetch_worker_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/interline-io/transitland-server/internal/gbfs" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" "github.com/interline-io/transitland-server/jobs" "github.com/stretchr/testify/assert" @@ -18,7 +18,7 @@ func TestGbfsFetchWorker(t *testing.T) { ts := httptest.NewServer(&gbfs.TestGbfsServer{Language: "en", Path: testutil.RelPath("test/data/gbfs")}) defer ts.Close() - testfinder.FindersTxRollback(t, nil, nil, func(te model.Config) { + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(te model.Config) { job := jobs.Job{} job.Opts.Finders = te w := GbfsFetchWorker{ From a72c508732b2d4b0d0d87a516e7431a7eb9568c6 Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 18:34:07 -0800 Subject: [PATCH 07/17] More test refactoring --- actions/fetch.go | 3 --- actions/fetch_test.go | 13 +++++++----- server/gql/agency_resolver_test.go | 4 ++-- server/gql/fvsl_cache_test.go | 4 ++-- server/gql/mutation_resolver_test.go | 16 ++++++++------- server/gql/resolver_test.go | 8 ++++---- server/gql/server.go | 1 - server/rest/agency_request_test.go | 18 ++++++++-------- server/rest/feed_request_test.go | 12 +++++------ server/rest/feed_version_download_test.go | 8 ++++---- server/rest/feed_version_request_test.go | 4 ++-- server/rest/operator_request_test.go | 12 +++++------ server/rest/rest_test.go | 9 ++++---- server/rest/route_request_test.go | 18 ++++++++-------- server/rest/stop_departure_request_test.go | 4 ++-- server/rest/stop_request_test.go | 24 +++++++++++----------- server/rest/trip_request_test.go | 16 +++++++-------- server/server_cmd.go | 2 +- workers/gbfs_fetch_worker_test.go | 6 +++--- 19 files changed, 92 insertions(+), 90 deletions(-) diff --git a/actions/fetch.go b/actions/fetch.go index 3d229ab9..35d76caf 100644 --- a/actions/fetch.go +++ b/actions/fetch.go @@ -256,9 +256,6 @@ func chunkBy[T any](items []T, chunkSize int) (chunks [][]T) { func fetchCheckFeed(ctx context.Context, feedId string, urlType string, url string) (*model.Feed, error) { cfg := model.ForContext(ctx) - if cfg.Finder == nil { - panic("no finder") - } checker := cfg.Checker // Check feed exists diff --git a/actions/fetch_test.go b/actions/fetch_test.go index 23360242..d0659fe6 100644 --- a/actions/fetch_test.go +++ b/actions/fetch_test.go @@ -2,9 +2,10 @@ package actions import ( "context" - "io/ioutil" + "fmt" "net/http" "net/http/httptest" + "os" "testing" sq "github.com/Masterminds/squirrel" @@ -99,7 +100,7 @@ func TestStaticFetchWorker(t *testing.T) { return } - buf, err := ioutil.ReadFile(testutil.RelPath(tc.serveFile)) + buf, err := os.ReadFile(testutil.RelPath(tc.serveFile)) if err != nil { http.Error(w, "404", 404) return @@ -110,8 +111,10 @@ func TestStaticFetchWorker(t *testing.T) { // Setup job feedUrl := ts.URL + "/" + tc.serveFile - testconfig.ConfigTxRollback(t, testconfig.Options{}, func(te model.Config) { - ctx := model.WithConfig(context.Background(), te) + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { + fmt.Printf("checker %#v\n", cfg.Checker) + cfg.Checker = nil + ctx := model.WithConfig(context.Background(), cfg) // Run job if result, err := StaticFetch(ctx, tc.feedId, nil, feedUrl); err != nil && !tc.expectError { _ = result @@ -125,7 +128,7 @@ func TestStaticFetchWorker(t *testing.T) { ff := dmfr.FeedFetch{} if err := dbutil.Get( context.Background(), - te.Finder.DBX(), + cfg.Finder.DBX(), sq.StatementBuilder. Select("ff.*"). From("feed_fetches ff"). diff --git a/server/gql/agency_resolver_test.go b/server/gql/agency_resolver_test.go index 01e9f9e5..8992f09b 100644 --- a/server/gql/agency_resolver_test.go +++ b/server/gql/agency_resolver_test.go @@ -310,11 +310,11 @@ func TestAgencyResolver_Authz(t *testing.T) { t.Skip(a) return } - te := testconfig.Config(t, testconfig.Options{ + cfg := testconfig.Config(t, testconfig.Options{ FGAModelFile: testutil.RelPath("test/authz/tls.json"), FGAModelTuples: fgaTestTuples, }) - srv, _ := NewServer(te) + srv, _ := NewServer(cfg) testcases := []testcase{ { name: "basic", diff --git a/server/gql/fvsl_cache_test.go b/server/gql/fvsl_cache_test.go index 6220c0ec..56a42ac1 100644 --- a/server/gql/fvsl_cache_test.go +++ b/server/gql/fvsl_cache_test.go @@ -7,7 +7,7 @@ import ( ) func TestFvslCache(t *testing.T) { - te := testconfig.Config(t, testconfig.Options{}) - c := newFvslCache(te.Finder) + cfg := testconfig.Config(t, testconfig.Options{}) + c := newFvslCache(cfg.Finder) c.Get(1) } diff --git a/server/gql/mutation_resolver_test.go b/server/gql/mutation_resolver_test.go index 25fd4cd5..3d3331ae 100644 --- a/server/gql/mutation_resolver_test.go +++ b/server/gql/mutation_resolver_test.go @@ -26,9 +26,9 @@ func TestFeedVersionFetchResolver(t *testing.T) { w.Write(buf) })) t.Run("found sha1", func(t *testing.T) { - testconfig.ConfigTxRollback(t, testconfig.Options{}, func(te model.Config) { - srv, _ := NewServer(te) - srv = model.AddConfig(te)(srv) + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { + srv, _ := NewServer(cfg) + srv = model.AddConfig(cfg)(srv) srv = ancheck.AdminDefaultMiddleware("test")(srv) // Run all requests as admin // Run all requests as admin c := client.New(srv) @@ -165,8 +165,9 @@ func TestValidateGtfsResolver(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testconfig.ConfigTxRollback(t, testconfig.Options{}, func(te model.Config) { - srv, _ := NewServer(te) + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { + srv, _ := NewServer(cfg) + srv = model.AddConfig(cfg)(srv) srv = ancheck.UserDefaultMiddleware("test")(srv) // Run all requests as user c := client.New(srv) queryTestcase(t, c, tc) @@ -174,8 +175,9 @@ func TestValidateGtfsResolver(t *testing.T) { }) } t.Run("requires user access", func(t *testing.T) { - testconfig.ConfigTxRollback(t, testconfig.Options{}, func(te model.Config) { - srv, _ := NewServer(te) // all requests run as anonymous context by default + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { + srv, _ := NewServer(cfg) // all requests run as anonymous context by default + srv = model.AddConfig(cfg)(srv) c := client.New(srv) resp := make(map[string]interface{}) err := c.Post(`mutation($url:String!) {validate_gtfs(url:$url){success}}`, &resp, client.Var("url", ts200.URL)) diff --git a/server/gql/resolver_test.go b/server/gql/resolver_test.go index 95423931..8bead1cf 100644 --- a/server/gql/resolver_test.go +++ b/server/gql/resolver_test.go @@ -52,13 +52,13 @@ func newTestClient(t testing.TB) (*client.Client, model.Config) { } func newTestClientWithOpts(t testing.TB, opts testconfig.Options) (*client.Client, model.Config) { - te := testconfig.Config(t, opts) - srv, _ := NewServer(te) - graphqlServer := model.AddConfig(te)(srv) + cfg := testconfig.Config(t, opts) + srv, _ := NewServer(cfg) + graphqlServer := model.AddConfig(cfg)(srv) srvMiddleware := ancheck.NewUserDefaultMiddleware(func() authn.User { return authn.NewCtxUser("testuser", "", "").WithRoles("testrole") }) - return client.New(srvMiddleware(graphqlServer)), te + return client.New(srvMiddleware(graphqlServer)), cfg } func toJson(m map[string]interface{}) string { diff --git a/server/gql/server.go b/server/gql/server.go index a18ef9c0..02f4e224 100644 --- a/server/gql/server.go +++ b/server/gql/server.go @@ -26,6 +26,5 @@ func NewServer(cfg model.Config) (http.Handler, error) { // Setup server srv := handler.NewDefaultServer(gqlout.NewExecutableSchema(c)) graphqlServer := loaderMiddleware(cfg, srv) - // graphqlServer = model.AddConfig(cfg)(graphqlServer) return graphqlServer, nil } diff --git a/server/rest/agency_request_test.go b/server/rest/agency_request_test.go index f1569a73..a7440bd0 100644 --- a/server/rest/agency_request_test.go +++ b/server/rest/agency_request_test.go @@ -10,7 +10,7 @@ import ( ) func TestAgencyRequest(t *testing.T) { - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) fv := "e535eb2b3b9ac3ef15d82c56575e914575e732e0" testcases := []testRest{ { @@ -159,7 +159,7 @@ func TestAgencyRequest(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -192,17 +192,17 @@ func TestAgencyRequest_Format(t *testing.T) { }, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } func TestAgencyRequest_Pagination(t *testing.T) { - srv, te := testRestConfig(t) - allEnts, err := te.Finder.FindAgencies(context.Background(), nil, nil, nil, nil) + srv, cfg := testRestConfig(t) + allEnts, err := cfg.Finder.FindAgencies(context.Background(), nil, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -240,7 +240,7 @@ func TestAgencyRequest_Pagination(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -293,10 +293,10 @@ func TestAgencyRequest_License(t *testing.T) { expectSelect: []string{"caltrain-ca-us", ""}, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } diff --git a/server/rest/feed_request_test.go b/server/rest/feed_request_test.go index 0ea1d5b3..46442999 100644 --- a/server/rest/feed_request_test.go +++ b/server/rest/feed_request_test.go @@ -156,10 +156,10 @@ func TestFeedRequest(t *testing.T) { expectLength: 0, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -192,10 +192,10 @@ func TestFeedRequest_Format(t *testing.T) { }, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -248,10 +248,10 @@ func TestFeedRequest_License(t *testing.T) { expectSelect: []string{"CT", "test-gbfs", "HA", "BA~rt", "CT~rt", "test", "EX"}, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } diff --git a/server/rest/feed_version_download_test.go b/server/rest/feed_version_download_test.go index b5bd66e8..02684235 100644 --- a/server/rest/feed_version_download_test.go +++ b/server/rest/feed_version_download_test.go @@ -16,8 +16,8 @@ func TestFeedVersionDownloadRequest(t *testing.T) { t.Skip(a) return } - srv, te := testRestConfig(t) - te.Storage = g + srv, cfg := testRestConfig(t) + cfg.Storage = g restSrv, err := testRestServer(t, Config{}, srv) if err != nil { t.Fatal(err) @@ -117,8 +117,8 @@ func TestFeedDownloadLatestRequest(t *testing.T) { t.Skip(a) return } - srv, te := testRestConfig(t) - te.Storage = g + srv, cfg := testRestConfig(t) + cfg.Storage = g restSrv, err := testRestServer(t, Config{}, srv) if err != nil { t.Fatal(err) diff --git a/server/rest/feed_version_request_test.go b/server/rest/feed_version_request_test.go index 6429073b..1ec80c35 100644 --- a/server/rest/feed_version_request_test.go +++ b/server/rest/feed_version_request_test.go @@ -127,10 +127,10 @@ func TestFeedVersionRequest(t *testing.T) { expectLength: 0, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } diff --git a/server/rest/operator_request_test.go b/server/rest/operator_request_test.go index 0c839f9d..68440016 100644 --- a/server/rest/operator_request_test.go +++ b/server/rest/operator_request_test.go @@ -107,10 +107,10 @@ func TestOperatorRequest(t *testing.T) { expectLength: 0, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -130,10 +130,10 @@ func TestOperatorRequest_Pagination(t *testing.T) { expectLength: 4, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -195,10 +195,10 @@ func TestOperatorRequest_License(t *testing.T) { expectSelect: []string{"o-9q9-caltrain", "o-dhv-hillsborougharearegionaltransit", "o-9qs-demotransitauthority"}, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } diff --git a/server/rest/rest_test.go b/server/rest/rest_test.go index 3dfaeb0b..c4e63717 100644 --- a/server/rest/rest_test.go +++ b/server/rest/rest_test.go @@ -28,17 +28,18 @@ func TestMain(m *testing.M) { } func testRestConfig(t testing.TB) (http.Handler, model.Config) { - te := testconfig.Config(t, + cfg := testconfig.Config(t, testconfig.Options{ When: "2018-06-01T00:00:00", RTJsons: testconfig.DefaultRTJson(), }, ) - srv, err := gql.NewServer(te) + srv, err := gql.NewServer(cfg) if err != nil { panic(err) } - return srv, te + srv = model.AddConfig(cfg)(srv) + return srv, cfg } func testRestServer(t testing.TB, cfg Config, srv http.Handler) (http.Handler, error) { @@ -60,7 +61,7 @@ type testRest struct { f func(*testing.T, string) } -func testquery(t *testing.T, graphqlHandler http.Handler, cfg model.Config, tc testRest) { +func testquery(t *testing.T, graphqlHandler http.Handler, tc testRest) { data, err := makeRequest(context.TODO(), Config{}, graphqlHandler, tc.h, tc.format, nil) if err != nil { t.Error(err) diff --git a/server/rest/route_request_test.go b/server/rest/route_request_test.go index 8b952032..c3c5f438 100644 --- a/server/rest/route_request_test.go +++ b/server/rest/route_request_test.go @@ -116,10 +116,10 @@ func TestRouteRequest(t *testing.T) { }, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -152,17 +152,17 @@ func TestRouteRequest_Format(t *testing.T) { }, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } func TestRouteRequest_Pagination(t *testing.T) { - srv, te := testRestConfig(t) - allEnts, err := te.Finder.FindRoutes(context.Background(), nil, nil, nil, nil) + srv, cfg := testRestConfig(t) + allEnts, err := cfg.Finder.FindRoutes(context.Background(), nil, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -209,7 +209,7 @@ func TestRouteRequest_Pagination(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -262,10 +262,10 @@ func TestRouteRequest_License(t *testing.T) { expectLength: 51, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } diff --git a/server/rest/stop_departure_request_test.go b/server/rest/stop_departure_request_test.go index ed3ddc77..d44b422f 100644 --- a/server/rest/stop_departure_request_test.go +++ b/server/rest/stop_departure_request_test.go @@ -144,10 +144,10 @@ func TestStopDepartureRequest(t *testing.T) { // 0, // }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } diff --git a/server/rest/stop_request_test.go b/server/rest/stop_request_test.go index ad5a8cd3..eba65df4 100644 --- a/server/rest/stop_request_test.go +++ b/server/rest/stop_request_test.go @@ -174,10 +174,10 @@ func TestStopRequest(t *testing.T) { expectLength: 0, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -193,15 +193,15 @@ func TestStopRequest_AdminCache(t *testing.T) { type canLoadAdmins interface { LoadAdmins() error } - srv, te := testRestConfig(t) - if v, ok := te.Finder.(canLoadAdmins); !ok { + srv, cfg := testRestConfig(t) + if v, ok := cfg.Finder.(canLoadAdmins); !ok { t.Fatal("finder cant load admins") } else { if err := v.LoadAdmins(); err != nil { t.Fatal(err) } } - testquery(t, srv, te, tc) + testquery(t, srv, tc) } func TestStopRequest_Format(t *testing.T) { @@ -232,17 +232,17 @@ func TestStopRequest_Format(t *testing.T) { }, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } func TestStopRequest_Pagination(t *testing.T) { - srv, te := testRestConfig(t) - allEnts, err := te.Finder.FindStops(context.Background(), nil, nil, nil, nil) + srv, cfg := testRestConfig(t) + allEnts, err := cfg.Finder.FindStops(context.Background(), nil, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -276,7 +276,7 @@ func TestStopRequest_Pagination(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -363,10 +363,10 @@ func TestStopRequest_License(t *testing.T) { }, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } diff --git a/server/rest/trip_request_test.go b/server/rest/trip_request_test.go index 75f8961e..22a74c96 100644 --- a/server/rest/trip_request_test.go +++ b/server/rest/trip_request_test.go @@ -11,7 +11,7 @@ import ( ) func TestTripRequest(t *testing.T) { - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) d, err := makeGraphQLRequest(context.Background(), srv, `query{routes(where:{feed_onestop_id:"BA",route_id:"11"}) {id onestop_id}}`, nil) if err != nil { t.Error("failed to get route id for tests") @@ -179,7 +179,7 @@ func TestTripRequest(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -212,10 +212,10 @@ func TestTripRequest_Format(t *testing.T) { }, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -251,10 +251,10 @@ func TestTripRequest_Pagination(t *testing.T) { expectLength: 10_000, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } @@ -307,10 +307,10 @@ func TestTripRequest_License(t *testing.T) { expectLength: 14903, }, } - srv, te := testRestConfig(t) + srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + testquery(t, srv, tc) }) } } diff --git a/server/server_cmd.go b/server/server_cmd.go index e4f98f38..af383610 100644 --- a/server/server_cmd.go +++ b/server/server_cmd.go @@ -249,6 +249,7 @@ func (cmd *Command) Run() error { // Setup router root := chi.NewRouter() + root.Use(model.AddConfig(cfg)) root.Use(middleware.RequestID) root.Use(middleware.RealIP) root.Use(middleware.Recoverer) @@ -259,7 +260,6 @@ func (cmd *Command) Run() error { AllowedHeaders: []string{"content-type", "apikey", "authorization"}, AllowCredentials: true, })) - root.Use(model.AddConfig(cfg)) // Setup user middleware for _, k := range cmd.AuthMiddlewares { diff --git a/workers/gbfs_fetch_worker_test.go b/workers/gbfs_fetch_worker_test.go index d8d476e5..148b4c40 100644 --- a/workers/gbfs_fetch_worker_test.go +++ b/workers/gbfs_fetch_worker_test.go @@ -18,9 +18,9 @@ func TestGbfsFetchWorker(t *testing.T) { ts := httptest.NewServer(&gbfs.TestGbfsServer{Language: "en", Path: testutil.RelPath("test/data/gbfs")}) defer ts.Close() - testconfig.ConfigTxRollback(t, testconfig.Options{}, func(te model.Config) { + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { job := jobs.Job{} - job.Opts.Finders = te + job.Opts.Finders = cfg w := GbfsFetchWorker{ Url: ts.URL + "/gbfs.json", FeedID: "test-gbfs", @@ -30,7 +30,7 @@ func TestGbfsFetchWorker(t *testing.T) { t.Fatal(err) } // Test - bikes, err := te.GbfsFinder.FindBikes( + bikes, err := cfg.GbfsFinder.FindBikes( context.Background(), nil, &model.GbfsBikeRequest{ From 207faba848ce125bcecfc8af4de63ce6309afbec Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 18:39:50 -0800 Subject: [PATCH 08/17] Test does not use checker originally --- actions/fetch_test.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/actions/fetch_test.go b/actions/fetch_test.go index d0659fe6..214c8a42 100644 --- a/actions/fetch_test.go +++ b/actions/fetch_test.go @@ -2,7 +2,6 @@ package actions import ( "context" - "fmt" "net/http" "net/http/httptest" "os" @@ -112,8 +111,7 @@ func TestStaticFetchWorker(t *testing.T) { // Setup job feedUrl := ts.URL + "/" + tc.serveFile testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { - fmt.Printf("checker %#v\n", cfg.Checker) - cfg.Checker = nil + cfg.Checker = nil // disable checker for this test ctx := model.WithConfig(context.Background(), cfg) // Run job if result, err := StaticFetch(ctx, tc.feedId, nil, feedUrl); err != nil && !tc.expectError { From e0042a68b2ee75fbb09c5779ad381dd518c9eab8 Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 19:35:36 -0800 Subject: [PATCH 09/17] More cleanup --- finders/dbfinder/finder_test.go | 20 ++++------------ finders/gbfsfinder/finder_test.go | 4 ++-- model/config.go | 2 ++ server/gql/agency_resolver_test.go | 7 +++--- server/gql/feed_resolver_test.go | 4 ++-- server/gql/fvsl_cache.go | 16 ++++++------- server/gql/fvsl_cache_test.go | 6 +++-- server/gql/gbfs_resolver_test.go | 8 +++---- server/gql/loaders.go | 3 ++- server/gql/mutation_resolver_test.go | 6 ++--- server/gql/resolver_test.go | 2 +- server/gql/route_resolver_test.go | 4 ++-- server/gql/server.go | 6 ++--- server/gql/stop_resolver.go | 2 +- server/gql/stop_resolver_test.go | 36 ++++++++++++++-------------- server/rest/rest_test.go | 2 +- server/server_cmd.go | 12 +++++----- 17 files changed, 67 insertions(+), 73 deletions(-) diff --git a/finders/dbfinder/finder_test.go b/finders/dbfinder/finder_test.go index c02b0011..1b1d5682 100644 --- a/finders/dbfinder/finder_test.go +++ b/finders/dbfinder/finder_test.go @@ -2,31 +2,19 @@ package dbfinder import ( "context" - "log" - "os" "testing" "github.com/interline-io/transitland-server/internal/testutil" - "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" ) -var TestFinder model.Finder - -func TestMain(m *testing.M) { - if a, ok := testutil.CheckTestDB(); !ok { - log.Print(a) - return - } +func TestFinder_FindFeedVersionServiceWindow(t *testing.T) { db := testutil.MustOpenTestDB() dbf := NewFinder(db, nil) - TestFinder = dbf - os.Exit(m.Run()) -} + testFinder := dbf -func TestFinder_FindFeedVersionServiceWindow(t *testing.T) { fvm := map[string]int{} - fvs, err := TestFinder.FindFeedVersions(context.TODO(), nil, nil, nil, nil) + fvs, err := testFinder.FindFeedVersions(context.TODO(), nil, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -64,7 +52,7 @@ func TestFinder_FindFeedVersionServiceWindow(t *testing.T) { } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - start, end, best, err := TestFinder.FindFeedVersionServiceWindow(context.TODO(), tc.fvid) + start, end, best, err := testFinder.FindFeedVersionServiceWindow(context.TODO(), tc.fvid) if err != nil { t.Fatal(err) } diff --git a/finders/gbfsfinder/finder_test.go b/finders/gbfsfinder/finder_test.go index c0cf6db4..61d74575 100644 --- a/finders/gbfsfinder/finder_test.go +++ b/finders/gbfsfinder/finder_test.go @@ -20,7 +20,7 @@ func TestGbfsFinder(t *testing.T) { } client := testutil.MustOpenTestRedisClient() gbf := NewFinder(client) - setupGbfs(nil, gbf) + testSetupGbfs(gbf) tcs := []struct { p xy.Point @@ -58,7 +58,7 @@ func TestGbfsFinder(t *testing.T) { } -func setupGbfs(dbf model.Finder, gbf model.GbfsFinder) error { +func testSetupGbfs(gbf model.GbfsFinder) error { // Setup sourceFeedId := "gbfs-test" ts := httptest.NewServer(&gbfs.TestGbfsServer{Language: "en", Path: testutil.RelPath("test/data/gbfs")}) diff --git a/model/config.go b/model/config.go index c1b47eee..ddcd40fb 100644 --- a/model/config.go +++ b/model/config.go @@ -2,6 +2,7 @@ package model import ( "context" + "fmt" "net/http" "github.com/interline-io/transitland-lib/tl" @@ -44,6 +45,7 @@ func WithConfig(ctx context.Context, cfg Config) context.Context { func AddConfig(cfg Config) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Println("CONFIG 1") r = r.WithContext(WithConfig(r.Context(), cfg)) next.ServeHTTP(w, r) }) diff --git a/server/gql/agency_resolver_test.go b/server/gql/agency_resolver_test.go index 8992f09b..d6398612 100644 --- a/server/gql/agency_resolver_test.go +++ b/server/gql/agency_resolver_test.go @@ -242,8 +242,8 @@ func TestAgencyResolver(t *testing.T) { } func TestAgencyResolver_Cursor(t *testing.T) { - c, te := newTestClient(t) - allEnts, err := te.Finder.FindAgencies(context.Background(), nil, nil, nil, nil) + c, cfg := newTestClient(t) + allEnts, err := cfg.Finder.FindAgencies(context.Background(), nil, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -314,7 +314,8 @@ func TestAgencyResolver_Authz(t *testing.T) { FGAModelFile: testutil.RelPath("test/authz/tls.json"), FGAModelTuples: fgaTestTuples, }) - srv, _ := NewServer(cfg) + _ = cfg + srv, _ := NewServer() testcases := []testcase{ { name: "basic", diff --git a/server/gql/feed_resolver_test.go b/server/gql/feed_resolver_test.go index b671adf6..e713b8cc 100644 --- a/server/gql/feed_resolver_test.go +++ b/server/gql/feed_resolver_test.go @@ -249,8 +249,8 @@ func TestFeedResolver(t *testing.T) { } func TestFeedResolver_Cursor(t *testing.T) { - c, te := newTestClient(t) - allEnts, err := te.Finder.FindFeeds(context.Background(), nil, nil, nil, nil) + c, cfg := newTestClient(t) + allEnts, err := cfg.Finder.FindFeeds(context.Background(), nil, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/server/gql/fvsl_cache.go b/server/gql/fvsl_cache.go index 9c21b3b4..40a1d12c 100644 --- a/server/gql/fvsl_cache.go +++ b/server/gql/fvsl_cache.go @@ -23,38 +23,38 @@ type fvslCache struct { fvWindows map[int]fvslWindow } -func newFvslCache(f model.Finder) *fvslCache { +func newFvslCache() *fvslCache { return &fvslCache{ - Finder: f, fvWindows: map[int]fvslWindow{}, } } -func (f *fvslCache) Get(fvid int) (fvslWindow, bool) { +func (f *fvslCache) Get(ctx context.Context, fvid int) (fvslWindow, bool) { f.lock.Lock() a, ok := f.fvWindows[fvid] f.lock.Unlock() if ok { return a, ok } - a, err := f.query(fvid) + a, err := f.query(ctx, fvid) if err != nil { a.Valid = false } - f.Set(fvid, a) + f.Set(ctx, fvid, a) return a, a.Valid } -func (f *fvslCache) Set(fvid int, w fvslWindow) { +func (f *fvslCache) Set(ctx context.Context, fvid int, w fvslWindow) { f.lock.Lock() defer f.lock.Unlock() f.fvWindows[fvid] = w } -func (f *fvslCache) query(fvid int) (fvslWindow, error) { +func (f *fvslCache) query(ctx context.Context, fvid int) (fvslWindow, error) { + cfg := model.ForContext(ctx) var err error w := fvslWindow{} - w.StartDate, w.EndDate, w.BestWeek, err = f.Finder.FindFeedVersionServiceWindow(context.TODO(), fvid) + w.StartDate, w.EndDate, w.BestWeek, err = cfg.Finder.FindFeedVersionServiceWindow(context.TODO(), fvid) log.Trace(). Str("start_date", w.StartDate.Format("2006-01-02")). Str("end_date", w.EndDate.Format("2006-01-02")). diff --git a/server/gql/fvsl_cache_test.go b/server/gql/fvsl_cache_test.go index 56a42ac1..0f39d508 100644 --- a/server/gql/fvsl_cache_test.go +++ b/server/gql/fvsl_cache_test.go @@ -1,13 +1,15 @@ package gql import ( + "context" "testing" "github.com/interline-io/transitland-server/internal/testconfig" + "github.com/interline-io/transitland-server/model" ) func TestFvslCache(t *testing.T) { cfg := testconfig.Config(t, testconfig.Options{}) - c := newFvslCache(cfg.Finder) - c.Get(1) + c := newFvslCache() + c.Get(model.WithConfig(context.Background(), cfg), 1) } diff --git a/server/gql/gbfs_resolver_test.go b/server/gql/gbfs_resolver_test.go index 24f5f3ed..4be34ca5 100644 --- a/server/gql/gbfs_resolver_test.go +++ b/server/gql/gbfs_resolver_test.go @@ -77,8 +77,8 @@ func TestGbfsBikeResolver(t *testing.T) { selectExpect: []string{"0cbf9b08f8b71a6362e20c8173c071a6"}, }, } - c, te := newTestClient(t) - setupGbfs(te.GbfsFinder) + c, cfg := newTestClient(t) + setupGbfs(cfg.GbfsFinder) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { queryTestcase(t, c, tc) @@ -210,8 +210,8 @@ func TestGbfsStationResolver(t *testing.T) { selectExpect: []string{"27045384-791c-4519-8087-fce2f7c48a69"}, }, } - c, te := newTestClient(t) - setupGbfs(te.GbfsFinder) + c, cfg := newTestClient(t) + setupGbfs(cfg.GbfsFinder) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { queryTestcase(t, c, tc) diff --git a/server/gql/loaders.go b/server/gql/loaders.go index 60e4e8e8..e1a7de40 100644 --- a/server/gql/loaders.go +++ b/server/gql/loaders.go @@ -135,10 +135,11 @@ func NewLoaders(dbf model.Finder) *Loaders { return loaders } -func loaderMiddleware(cfg model.Config, next http.Handler) http.Handler { +func loaderMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // This is per request scoped loaders/cache // Is this OK to use as a long term cache? + cfg := model.ForContext(r.Context()) loaders := NewLoaders(cfg.Finder) nextCtx := context.WithValue(r.Context(), loadersKey, loaders) r = r.WithContext(nextCtx) diff --git a/server/gql/mutation_resolver_test.go b/server/gql/mutation_resolver_test.go index 3d3331ae..26841ed7 100644 --- a/server/gql/mutation_resolver_test.go +++ b/server/gql/mutation_resolver_test.go @@ -27,7 +27,7 @@ func TestFeedVersionFetchResolver(t *testing.T) { })) t.Run("found sha1", func(t *testing.T) { testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { - srv, _ := NewServer(cfg) + srv, _ := NewServer() srv = model.AddConfig(cfg)(srv) srv = ancheck.AdminDefaultMiddleware("test")(srv) // Run all requests as admin // Run all requests as admin @@ -166,7 +166,7 @@ func TestValidateGtfsResolver(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { - srv, _ := NewServer(cfg) + srv, _ := NewServer() srv = model.AddConfig(cfg)(srv) srv = ancheck.UserDefaultMiddleware("test")(srv) // Run all requests as user c := client.New(srv) @@ -176,7 +176,7 @@ func TestValidateGtfsResolver(t *testing.T) { } t.Run("requires user access", func(t *testing.T) { testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { - srv, _ := NewServer(cfg) // all requests run as anonymous context by default + srv, _ := NewServer() // all requests run as anonymous context by default srv = model.AddConfig(cfg)(srv) c := client.New(srv) resp := make(map[string]interface{}) diff --git a/server/gql/resolver_test.go b/server/gql/resolver_test.go index 8bead1cf..cf5ee906 100644 --- a/server/gql/resolver_test.go +++ b/server/gql/resolver_test.go @@ -53,7 +53,7 @@ func newTestClient(t testing.TB) (*client.Client, model.Config) { func newTestClientWithOpts(t testing.TB, opts testconfig.Options) (*client.Client, model.Config) { cfg := testconfig.Config(t, opts) - srv, _ := NewServer(cfg) + srv, _ := NewServer() graphqlServer := model.AddConfig(cfg)(srv) srvMiddleware := ancheck.NewUserDefaultMiddleware(func() authn.User { return authn.NewCtxUser("testuser", "", "").WithRoles("testrole") diff --git a/server/gql/route_resolver_test.go b/server/gql/route_resolver_test.go index 32c0595d..568faf0e 100644 --- a/server/gql/route_resolver_test.go +++ b/server/gql/route_resolver_test.go @@ -306,8 +306,8 @@ func TestRouteResolver_PreviousOnestopID(t *testing.T) { } func TestRouteResolver_Cursor(t *testing.T) { - c, te := newTestClient(t) - allEnts, err := te.Finder.FindRoutes(context.Background(), nil, nil, nil, nil) + c, cfg := newTestClient(t) + allEnts, err := cfg.Finder.FindRoutes(context.Background(), nil, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/server/gql/server.go b/server/gql/server.go index 02f4e224..0b5cc71c 100644 --- a/server/gql/server.go +++ b/server/gql/server.go @@ -12,9 +12,9 @@ import ( "github.com/interline-io/transitland-server/model" ) -func NewServer(cfg model.Config) (http.Handler, error) { +func NewServer() (http.Handler, error) { c := gqlout.Config{Resolvers: &Resolver{ - fvslCache: newFvslCache(cfg.Finder), + fvslCache: newFvslCache(), }} c.Directives.HasRole = func(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (interface{}, error) { user := authn.ForContext(ctx) @@ -25,6 +25,6 @@ func NewServer(cfg model.Config) (http.Handler, error) { } // Setup server srv := handler.NewDefaultServer(gqlout.NewExecutableSchema(c)) - graphqlServer := loaderMiddleware(cfg, srv) + graphqlServer := loaderMiddleware(srv) return graphqlServer, nil } diff --git a/server/gql/stop_resolver.go b/server/gql/stop_resolver.go index e7c5bead..e4f2d446 100644 --- a/server/gql/stop_resolver.go +++ b/server/gql/stop_resolver.go @@ -118,7 +118,7 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit } // Check if service date is outside the window for this feed version if where.ServiceDate != nil && (where.UseServiceWindow != nil && *where.UseServiceWindow) { - sl, ok := r.fvslCache.Get(obj.FeedVersionID) + sl, ok := r.fvslCache.Get(ctx, obj.FeedVersionID) if !ok { return nil, errors.New("service level information not available for feed version") } diff --git a/server/gql/stop_resolver_test.go b/server/gql/stop_resolver_test.go index 47c1202a..28ce9b56 100644 --- a/server/gql/stop_resolver_test.go +++ b/server/gql/stop_resolver_test.go @@ -10,31 +10,31 @@ import ( ) func TestStopResolver(t *testing.T) { - c, te := newTestClient(t) - queryTestcases(t, c, stopResolverTestcases(t, te)) + c, cfg := newTestClient(t) + queryTestcases(t, c, stopResolverTestcases(t, cfg)) } func TestStopResolver_Cursor(t *testing.T) { - c, te := newTestClient(t) - queryTestcases(t, c, stopResolverCursorTestcases(t, te)) + c, cfg := newTestClient(t) + queryTestcases(t, c, stopResolverCursorTestcases(t, cfg)) } func TestStopResolver_PreviousOnestopID(t *testing.T) { - c, te := newTestClient(t) - queryTestcases(t, c, stopResolverPreviousOnestopIDTestcases(t, te)) + c, cfg := newTestClient(t) + queryTestcases(t, c, stopResolverPreviousOnestopIDTestcases(t, cfg)) } func TestStopResolver_License(t *testing.T) { - c, te := newTestClient(t) - queryTestcases(t, c, stopResolverLicenseTestcases(t, te)) + c, cfg := newTestClient(t) + queryTestcases(t, c, stopResolverLicenseTestcases(t, cfg)) } func TestStopResolver_AdminCache(t *testing.T) { type canLoadAdmins interface { LoadAdmins() error } - c, te := newTestClient(t) - if v, ok := te.Finder.(canLoadAdmins); !ok { + c, cfg := newTestClient(t) + if v, ok := cfg.Finder.(canLoadAdmins); !ok { t.Fatal("finder cant load admins") } else { if err := v.LoadAdmins(); err != nil { @@ -83,11 +83,11 @@ func TestStopResolver_AdminCache(t *testing.T) { } func BenchmarkStopResolver(b *testing.B) { - c, te := newTestClient(b) - benchmarkTestcases(b, c, stopResolverTestcases(b, te)) + c, cfg := newTestClient(b) + benchmarkTestcases(b, c, stopResolverTestcases(b, cfg)) } -func stopResolverTestcases(t testing.TB, te model.Config) []testcase { +func stopResolverTestcases(t testing.TB, cfg model.Config) []testcase { bartStops := []string{"12TH", "16TH", "19TH", "19TH_N", "24TH", "ANTC", "ASHB", "BALB", "BAYF", "CAST", "CIVC", "COLS", "COLM", "CONC", "DALY", "DBRK", "DUBL", "DELN", "PLZA", "EMBR", "FRMT", "FTVL", "GLEN", "HAYW", "LAFY", "LAKE", "MCAR", "MCAR_S", "MLBR", "MONT", "NBRK", "NCON", "OAKL", "ORIN", "PITT", "PCTR", "PHIL", "POWL", "RICH", "ROCK", "SBRN", "SFIA", "SANL", "SHAY", "SSAN", "UCTY", "WCRK", "WARM", "WDUB", "WOAK"} caltrainRailStops := []string{"70011", "70012", "70021", "70022", "70031", "70032", "70041", "70042", "70051", "70052", "70061", "70062", "70071", "70072", "70081", "70082", "70091", "70092", "70101", "70102", "70111", "70112", "70121", "70122", "70131", "70132", "70141", "70142", "70151", "70152", "70161", "70162", "70171", "70172", "70191", "70192", "70201", "70202", "70211", "70212", "70221", "70222", "70231", "70232", "70241", "70242", "70251", "70252", "70261", "70262", "70271", "70272", "70281", "70282", "70291", "70292", "70301", "70302", "70311", "70312", "70321", "70322"} caltrainBusStops := []string{"777402", "777403"} @@ -99,7 +99,7 @@ func stopResolverTestcases(t testing.TB, te model.Config) []testcase { allStops = append(allStops, caltrainStops...) vars := hw{"stop_id": "MCAR"} stopObsFvid := 0 - if err := te.Finder.DBX().QueryRowx("select feed_version_id from ext_performance_stop_observations limit 1").Scan(&stopObsFvid); err != nil { + if err := cfg.Finder.DBX().QueryRowx("select feed_version_id from ext_performance_stop_observations limit 1").Scan(&stopObsFvid); err != nil { t.Errorf("could not get fvid for stop observation test: %s", err.Error()) } testcases := []testcase{ @@ -510,9 +510,9 @@ func stopResolverTestcases(t testing.TB, te model.Config) []testcase { return testcases } -func stopResolverCursorTestcases(t *testing.T, te model.Config) []testcase { +func stopResolverCursorTestcases(t *testing.T, cfg model.Config) []testcase { // First 1000 stops... - dbf := te.Finder + dbf := cfg.Finder allEnts, err := dbf.FindStops(context.Background(), nil, nil, nil, nil) if err != nil { t.Fatal(err) @@ -562,7 +562,7 @@ func stopResolverCursorTestcases(t *testing.T, te model.Config) []testcase { return testcases } -func stopResolverPreviousOnestopIDTestcases(t testing.TB, te model.Config) []testcase { +func stopResolverPreviousOnestopIDTestcases(t testing.TB, cfg model.Config) []testcase { testcases := []testcase{ { name: "default", @@ -596,7 +596,7 @@ func stopResolverPreviousOnestopIDTestcases(t testing.TB, te model.Config) []tes return testcases } -func stopResolverLicenseTestcases(t testing.TB, te model.Config) []testcase { +func stopResolverLicenseTestcases(t testing.TB, cfg model.Config) []testcase { q := ` query ($lic: LicenseFilter) { stops(limit: 10000, where: {license: $lic}) { diff --git a/server/rest/rest_test.go b/server/rest/rest_test.go index c4e63717..fea08a11 100644 --- a/server/rest/rest_test.go +++ b/server/rest/rest_test.go @@ -34,7 +34,7 @@ func testRestConfig(t testing.TB) (http.Handler, model.Config) { RTJsons: testconfig.DefaultRTJson(), }, ) - srv, err := gql.NewServer(cfg) + srv, err := gql.NewServer() if err != nil { panic(err) } diff --git a/server/server_cmd.go b/server/server_cmd.go index af383610..f238076f 100644 --- a/server/server_cmd.go +++ b/server/server_cmd.go @@ -200,12 +200,10 @@ func (cmd *Command) Run() error { } // Create Finder - var dbFinder model.Finder - f := dbfinder.NewFinder(db, checker) + dbFinder := dbfinder.NewFinder(db, checker) if cmd.LoadAdmins { - f.LoadAdmins() + dbFinder.LoadAdmins() } - dbFinder = f // Create RTFinder, GBFSFinder var rtFinder model.RTFinder @@ -249,7 +247,6 @@ func (cmd *Command) Run() error { // Setup router root := chi.NewRouter() - root.Use(model.AddConfig(cfg)) root.Use(middleware.RequestID) root.Use(middleware.RealIP) root.Use(middleware.Recoverer) @@ -261,6 +258,9 @@ func (cmd *Command) Run() error { AllowCredentials: true, })) + // Finders config + root.Use(model.AddConfig(cfg)) + // Setup user middleware for _, k := range cmd.AuthMiddlewares { if userMiddleware, err := ancheck.GetUserMiddleware(k, cmd.AuthConfig, redisClient); err != nil { @@ -288,7 +288,7 @@ func (cmd *Command) Run() error { // GraphQL API - graphqlServer, err := gql.NewServer(cfg) + graphqlServer, err := gql.NewServer() if err != nil { return err } From e1338b22a3f54147ab27e34d97c9911cbfbd04f5 Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 19:35:49 -0800 Subject: [PATCH 10/17] More cleanup --- model/config.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/model/config.go b/model/config.go index ddcd40fb..c1b47eee 100644 --- a/model/config.go +++ b/model/config.go @@ -2,7 +2,6 @@ package model import ( "context" - "fmt" "net/http" "github.com/interline-io/transitland-lib/tl" @@ -45,7 +44,6 @@ func WithConfig(ctx context.Context, cfg Config) context.Context { func AddConfig(cfg Config) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Println("CONFIG 1") r = r.WithContext(WithConfig(r.Context(), cfg)) next.ServeHTTP(w, r) }) From af1066844aeda1936c2fc1c961afc6a641f0748c Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 20:31:43 -0800 Subject: [PATCH 11/17] More fixes and refactoring --- internal/testconfig/testconfig.go | 12 ++++- server/rest/agency_request_test.go | 22 ++++---- server/rest/feed_request_test.go | 15 +++--- server/rest/feed_version_download.go | 1 + server/rest/feed_version_download_test.go | 19 +++---- server/rest/feed_version_request_test.go | 5 +- server/rest/operator_request_test.go | 15 +++--- server/rest/rest_test.go | 63 +++++++++++++--------- server/rest/route_request_test.go | 22 ++++---- server/rest/stop_departure_request_test.go | 5 +- server/rest/stop_request_test.go | 28 +++++----- server/rest/trip_request_test.go | 26 +++++---- 12 files changed, 116 insertions(+), 117 deletions(-) diff --git a/internal/testconfig/testconfig.go b/internal/testconfig/testconfig.go index 0a585252..b4a03ffd 100644 --- a/internal/testconfig/testconfig.go +++ b/internal/testconfig/testconfig.go @@ -25,6 +25,8 @@ import ( type Options struct { When string + Storage string + RTStorage string RTJsons []RTJsonFile FGAModelFile string FGAModelTuples []authz.TupleKey @@ -124,13 +126,19 @@ func newTestConfig(t testing.TB, db sqlx.Ext, opts Options) model.Config { // Setup GBFS gbf := gbfsfinder.NewFinder(nil) + if opts.Storage == "" { + opts.Storage = t.TempDir() + } + if opts.RTStorage == "" { + opts.RTStorage = t.TempDir() + } return model.Config{ Finder: dbf, RTFinder: rtf, GbfsFinder: gbf, Checker: checker, Clock: cl, - Storage: t.TempDir(), - RTStorage: t.TempDir(), + Storage: opts.Storage, + RTStorage: opts.RTStorage, } } diff --git a/server/rest/agency_request_test.go b/server/rest/agency_request_test.go index a7440bd0..3675441e 100644 --- a/server/rest/agency_request_test.go +++ b/server/rest/agency_request_test.go @@ -5,14 +5,14 @@ import ( "strings" "testing" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) func TestAgencyRequest(t *testing.T) { - srv, _ := testRestConfig(t) fv := "e535eb2b3b9ac3ef15d82c56575e914575e732e0" - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: AgencyRequest{}, @@ -159,13 +159,13 @@ func TestAgencyRequest(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestAgencyRequest_Format(t *testing.T) { - tcs := []testRest{ + tcs := []testCase{ { name: "agency geojson", format: "geojson", @@ -192,16 +192,15 @@ func TestAgencyRequest_Format(t *testing.T) { }, }, } - srv, _ := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestAgencyRequest_Pagination(t *testing.T) { - srv, cfg := testRestConfig(t) + graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{}) allEnts, err := cfg.Finder.FindAgencies(context.Background(), nil, nil, nil, nil) if err != nil { t.Fatal(err) @@ -210,7 +209,7 @@ func TestAgencyRequest_Pagination(t *testing.T) { for _, ent := range allEnts { allIds = append(allIds, ent.AgencyID) } - testcases := []testRest{ + testcases := []testCase{ { name: "limit:1", h: AgencyRequest{WithCursor: WithCursor{Limit: 1}}, @@ -240,13 +239,13 @@ func TestAgencyRequest_Pagination(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCaseWithHandlers(t, tc, graphqlHandler, restHandler) }) } } func TestAgencyRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: AgencyRequest{WithCursor: WithCursor{Limit: 10_000}, LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, selector: "agencies.#.agency_id", @@ -293,10 +292,9 @@ func TestAgencyRequest_License(t *testing.T) { expectSelect: []string{"caltrain-ca-us", ""}, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/feed_request_test.go b/server/rest/feed_request_test.go index 46442999..687da1b1 100644 --- a/server/rest/feed_request_test.go +++ b/server/rest/feed_request_test.go @@ -11,7 +11,7 @@ import ( func TestFeedRequest(t *testing.T) { // fv := "e535eb2b3b9ac3ef15d82c56575e914575e732e0" - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: &FeedRequest{}, @@ -156,16 +156,15 @@ func TestFeedRequest(t *testing.T) { expectLength: 0, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestFeedRequest_Format(t *testing.T) { - tcs := []testRest{ + tcs := []testCase{ { name: "feed geojson", format: "geojson", @@ -192,16 +191,15 @@ func TestFeedRequest_Format(t *testing.T) { }, }, } - srv, _ := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestFeedRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: FeedRequest{LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, selector: "feeds.#.onestop_id", @@ -248,10 +246,9 @@ func TestFeedRequest_License(t *testing.T) { expectSelect: []string{"CT", "test-gbfs", "HA", "BA~rt", "CT~rt", "test", "EX"}, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } diff --git a/server/rest/feed_version_download.go b/server/rest/feed_version_download.go index 3c12e7c6..66fc636f 100644 --- a/server/rest/feed_version_download.go +++ b/server/rest/feed_version_download.go @@ -86,6 +86,7 @@ func feedVersionDownloadLatestHandler(graphqlHandler http.Handler, w http.Respon } cfg := model.ForContext(r.Context()) + fmt.Printf("storage: %#v\n", cfg) serveFromStorage(w, r, cfg.Storage, fvsha1) } diff --git a/server/rest/feed_version_download_test.go b/server/rest/feed_version_download_test.go index 02684235..7a607075 100644 --- a/server/rest/feed_version_download_test.go +++ b/server/rest/feed_version_download_test.go @@ -7,6 +7,7 @@ import ( "github.com/interline-io/transitland-mw/auth/ancheck" "github.com/interline-io/transitland-mw/auth/authn" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" ) @@ -16,12 +17,9 @@ func TestFeedVersionDownloadRequest(t *testing.T) { t.Skip(a) return } - srv, cfg := testRestConfig(t) - cfg.Storage = g - restSrv, err := testRestServer(t, Config{}, srv) - if err != nil { - t.Fatal(err) - } + _, restSrv, _ := testHandlersWithOptions(t, testconfig.Options{ + Storage: g, + }) t.Run("ok", func(t *testing.T) { req, _ := http.NewRequest("GET", "/feed_versions/d2813c293bcfd7a97dde599527ae6c62c98e66c6/download", nil) @@ -117,12 +115,9 @@ func TestFeedDownloadLatestRequest(t *testing.T) { t.Skip(a) return } - srv, cfg := testRestConfig(t) - cfg.Storage = g - restSrv, err := testRestServer(t, Config{}, srv) - if err != nil { - t.Fatal(err) - } + _, restSrv, _ := testHandlersWithOptions(t, testconfig.Options{ + Storage: g, + }) t.Run("ok", func(t *testing.T) { req, _ := http.NewRequest("GET", "/feeds/CT/download_latest_feed_version", nil) diff --git a/server/rest/feed_version_request_test.go b/server/rest/feed_version_request_test.go index 1ec80c35..56192429 100644 --- a/server/rest/feed_version_request_test.go +++ b/server/rest/feed_version_request_test.go @@ -8,7 +8,7 @@ import ( func TestFeedVersionRequest(t *testing.T) { fv := "d2813c293bcfd7a97dde599527ae6c62c98e66c6" - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: FeedVersionRequest{}, @@ -127,10 +127,9 @@ func TestFeedVersionRequest(t *testing.T) { expectLength: 0, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/operator_request_test.go b/server/rest/operator_request_test.go index 68440016..82de0439 100644 --- a/server/rest/operator_request_test.go +++ b/server/rest/operator_request_test.go @@ -5,7 +5,7 @@ import ( ) func TestOperatorRequest(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: OperatorRequest{}, @@ -107,16 +107,15 @@ func TestOperatorRequest(t *testing.T) { expectLength: 0, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestOperatorRequest_Pagination(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "limit:1", h: OperatorRequest{WithCursor: WithCursor{Limit: 1}}, @@ -130,16 +129,15 @@ func TestOperatorRequest_Pagination(t *testing.T) { expectLength: 4, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestOperatorRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: OperatorRequest{WithCursor: WithCursor{Limit: 10_000}, LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, @@ -195,10 +193,9 @@ func TestOperatorRequest_License(t *testing.T) { expectSelect: []string{"o-9q9-caltrain", "o-dhv-hillsborougharearegionaltransit", "o-9qs-demotransitauthority"}, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/rest_test.go b/server/rest/rest_test.go index fea08a11..90fb5e82 100644 --- a/server/rest/rest_test.go +++ b/server/rest/rest_test.go @@ -27,31 +27,7 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func testRestConfig(t testing.TB) (http.Handler, model.Config) { - cfg := testconfig.Config(t, - testconfig.Options{ - When: "2018-06-01T00:00:00", - RTJsons: testconfig.DefaultRTJson(), - }, - ) - srv, err := gql.NewServer() - if err != nil { - panic(err) - } - srv = model.AddConfig(cfg)(srv) - return srv, cfg -} - -func testRestServer(t testing.TB, cfg Config, srv http.Handler) (http.Handler, error) { - return NewServer(cfg, srv) -} - -func toJson(m map[string]interface{}) string { - rr, _ := json.Marshal(&m) - return string(rr) -} - -type testRest struct { +type testCase struct { name string h apiHandler format string @@ -61,7 +37,37 @@ type testRest struct { f func(*testing.T, string) } -func testquery(t *testing.T, graphqlHandler http.Handler, tc testRest) { +func testHandlersWithOptions(t testing.TB, opts testconfig.Options) (http.Handler, http.Handler, model.Config) { + cfg := testconfig.Config(t, opts) + graphqlHandler, err := gql.NewServer() + if err != nil { + t.Fatal(err) + } + restHandler, err := NewServer(Config{}, graphqlHandler) + if err != nil { + t.Fatal(err) + } + return model.AddConfig(cfg)(graphqlHandler), model.AddConfig(cfg)(restHandler), cfg +} + +func checkTestCase(t *testing.T, tc testCase) { + opts := testconfig.Options{ + When: "2018-06-01T00:00:00", + RTJsons: testconfig.DefaultRTJson(), + } + cfg := testconfig.Config(t, opts) + graphqlHandler, err := gql.NewServer() + if err != nil { + t.Fatal(err) + } + restHandler, err := NewServer(Config{}, graphqlHandler) + if err != nil { + t.Fatal(err) + } + checkTestCaseWithHandlers(t, tc, model.AddConfig(cfg)(graphqlHandler), restHandler) +} + +func checkTestCaseWithHandlers(t *testing.T, tc testCase, graphqlHandler http.Handler, restHandler http.Handler) { data, err := makeRequest(context.TODO(), Config{}, graphqlHandler, tc.h, tc.format, nil) if err != nil { t.Error(err) @@ -97,3 +103,8 @@ func testquery(t *testing.T, graphqlHandler http.Handler, tc testRest) { t.Errorf("no test performed, check test case") } } + +func toJson(m map[string]interface{}) string { + rr, _ := json.Marshal(&m) + return string(rr) +} diff --git a/server/rest/route_request_test.go b/server/rest/route_request_test.go index c3c5f438..bb92ed92 100644 --- a/server/rest/route_request_test.go +++ b/server/rest/route_request_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" @@ -13,7 +14,7 @@ import ( func TestRouteRequest(t *testing.T) { routeIds := []string{"1", "12", "14", "15", "16", "17", "19", "20", "24", "25", "275", "30", "31", "32", "33", "34", "35", "36", "360", "37", "38", "39", "400", "42", "45", "46", "48", "5", "51", "6", "60", "7", "75", "8", "9", "96", "97", "570", "571", "572", "573", "574", "800", "PWT", "SKY", "01", "03", "05", "07", "11", "19", "Bu-130", "Li-130", "Lo-130", "TaSj-130", "Gi-130", "Sp-130"} fv := "e535eb2b3b9ac3ef15d82c56575e914575e732e0" - testcases := []testRest{ + testcases := []testCase{ { name: "none", h: RouteRequest{WithCursor: WithCursor{Limit: 1000}}, @@ -116,16 +117,15 @@ func TestRouteRequest(t *testing.T) { }, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestRouteRequest_Format(t *testing.T) { - tcs := []testRest{ + tcs := []testCase{ { name: "route geojson", format: "geojson", @@ -152,16 +152,15 @@ func TestRouteRequest_Format(t *testing.T) { }, }, } - srv, _ := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestRouteRequest_Pagination(t *testing.T) { - srv, cfg := testRestConfig(t) + graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{}) allEnts, err := cfg.Finder.FindRoutes(context.Background(), nil, nil, nil, nil) if err != nil { t.Fatal(err) @@ -170,7 +169,7 @@ func TestRouteRequest_Pagination(t *testing.T) { for _, ent := range allEnts { allIds = append(allIds, ent.RouteID) } - testcases := []testRest{ + testcases := []testCase{ { name: "limit:1", h: RouteRequest{WithCursor: WithCursor{Limit: 1}}, @@ -209,13 +208,13 @@ func TestRouteRequest_Pagination(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCaseWithHandlers(t, tc, graphqlHandler, restHandler) }) } } func TestRouteRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: RouteRequest{WithCursor: WithCursor{Limit: 10_000}, LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, selector: "routes.#.route_id", @@ -262,10 +261,9 @@ func TestRouteRequest_License(t *testing.T) { expectLength: 51, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/stop_departure_request_test.go b/server/rest/stop_departure_request_test.go index d44b422f..dad9d19f 100644 --- a/server/rest/stop_departure_request_test.go +++ b/server/rest/stop_departure_request_test.go @@ -12,7 +12,7 @@ func TestStopDepartureRequest(t *testing.T) { return &v } sid := "s-9q9nfsxn67-fruitvale" - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: StopDepartureRequest{StopKey: sid}, @@ -144,10 +144,9 @@ func TestStopDepartureRequest(t *testing.T) { // 0, // }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/stop_request_test.go b/server/rest/stop_request_test.go index eba65df4..f5849d28 100644 --- a/server/rest/stop_request_test.go +++ b/server/rest/stop_request_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" @@ -18,7 +19,7 @@ func TestStopRequest(t *testing.T) { caltrainBusStops := []string{"777402", "777403"} _ = caltrainRailStops _ = caltrainBusStops - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: StopRequest{}, @@ -174,16 +175,15 @@ func TestStopRequest(t *testing.T) { expectLength: 0, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestStopRequest_AdminCache(t *testing.T) { - tc := testRest{ + tc := testCase{ name: "place", h: StopRequest{StopKey: "BA:FTVL"}, selector: "stops.#.place.adm1_name", @@ -193,7 +193,7 @@ func TestStopRequest_AdminCache(t *testing.T) { type canLoadAdmins interface { LoadAdmins() error } - srv, cfg := testRestConfig(t) + graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{}) if v, ok := cfg.Finder.(canLoadAdmins); !ok { t.Fatal("finder cant load admins") } else { @@ -201,11 +201,11 @@ func TestStopRequest_AdminCache(t *testing.T) { t.Fatal(err) } } - testquery(t, srv, tc) + checkTestCaseWithHandlers(t, tc, graphqlHandler, restHandler) } func TestStopRequest_Format(t *testing.T) { - tcs := []testRest{ + tcs := []testCase{ { name: "stop geojson", format: "geojson", @@ -232,16 +232,15 @@ func TestStopRequest_Format(t *testing.T) { }, }, } - srv, _ := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestStopRequest_Pagination(t *testing.T) { - srv, cfg := testRestConfig(t) + graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{}) allEnts, err := cfg.Finder.FindStops(context.Background(), nil, nil, nil, nil) if err != nil { t.Fatal(err) @@ -250,7 +249,7 @@ func TestStopRequest_Pagination(t *testing.T) { for _, ent := range allEnts { allIds = append(allIds, ent.StopID) } - testcases := []testRest{ + testcases := []testCase{ { name: "pagination exists", h: StopRequest{}, @@ -276,13 +275,13 @@ func TestStopRequest_Pagination(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCaseWithHandlers(t, tc, graphqlHandler, restHandler) }) } } func TestStopRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: StopRequest{WithCursor: WithCursor{Limit: 10_000}, LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, @@ -363,10 +362,9 @@ func TestStopRequest_License(t *testing.T) { }, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/trip_request_test.go b/server/rest/trip_request_test.go index 22a74c96..f791b35c 100644 --- a/server/rest/trip_request_test.go +++ b/server/rest/trip_request_test.go @@ -6,19 +6,20 @@ import ( "strings" "testing" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) func TestTripRequest(t *testing.T) { - srv, _ := testRestConfig(t) - d, err := makeGraphQLRequest(context.Background(), srv, `query{routes(where:{feed_onestop_id:"BA",route_id:"11"}) {id onestop_id}}`, nil) + graphqlHandler, restHandler, _ := testHandlersWithOptions(t, testconfig.Options{}) + d, err := makeGraphQLRequest(context.Background(), graphqlHandler, `query{routes(where:{feed_onestop_id:"BA",route_id:"11"}) {id onestop_id}}`, nil) if err != nil { t.Error("failed to get route id for tests") } routeId := int(gjson.Get(toJson(d), "routes.0.id").Int()) routeOnestopId := gjson.Get(toJson(d), "routes.0.onestop_id").String() - d2, err := makeGraphQLRequest(context.Background(), srv, `query{trips(where:{trip_id:"5132248WKDY"}){id}}`, nil) + d2, err := makeGraphQLRequest(context.Background(), graphqlHandler, `query{trips(where:{trip_id:"5132248WKDY"}){id}}`, nil) if err != nil { t.Error("failed to get route id for tests") } @@ -26,7 +27,7 @@ func TestTripRequest(t *testing.T) { fv := "e535eb2b3b9ac3ef15d82c56575e914575e732e0" ctfv := "d2813c293bcfd7a97dde599527ae6c62c98e66c6" - testcases := []testRest{ + testcases := []testCase{ { name: "none", h: TripRequest{}, @@ -179,13 +180,13 @@ func TestTripRequest(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCaseWithHandlers(t, tc, graphqlHandler, restHandler) }) } } func TestTripRequest_Format(t *testing.T) { - tcs := []testRest{ + tcs := []testCase{ { name: "trip geojson", format: "geojson", @@ -212,16 +213,15 @@ func TestTripRequest_Format(t *testing.T) { }, }, } - srv, _ := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestTripRequest_Pagination(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "limit:1", h: TripRequest{WithCursor: WithCursor{Limit: 1}}, @@ -251,16 +251,15 @@ func TestTripRequest_Pagination(t *testing.T) { expectLength: 10_000, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } func TestTripRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: TripRequest{WithCursor: WithCursor{Limit: 100_000}, LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, selector: "trips.#.trip_id", @@ -307,10 +306,9 @@ func TestTripRequest_License(t *testing.T) { expectLength: 14903, }, } - srv, _ := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, tc) + checkTestCase(t, tc) }) } } From e0efef7484ad141a6a9fc53fd8c41cc6691d7288 Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 20:36:19 -0800 Subject: [PATCH 12/17] Cleanup --- server/rest/feed_version_download.go | 1 - server/rest/trip_request_test.go | 1 - 2 files changed, 2 deletions(-) diff --git a/server/rest/feed_version_download.go b/server/rest/feed_version_download.go index 66fc636f..3c12e7c6 100644 --- a/server/rest/feed_version_download.go +++ b/server/rest/feed_version_download.go @@ -86,7 +86,6 @@ func feedVersionDownloadLatestHandler(graphqlHandler http.Handler, w http.Respon } cfg := model.ForContext(r.Context()) - fmt.Printf("storage: %#v\n", cfg) serveFromStorage(w, r, cfg.Storage, fvsha1) } diff --git a/server/rest/trip_request_test.go b/server/rest/trip_request_test.go index f791b35c..125a16eb 100644 --- a/server/rest/trip_request_test.go +++ b/server/rest/trip_request_test.go @@ -24,7 +24,6 @@ func TestTripRequest(t *testing.T) { t.Error("failed to get route id for tests") } tripId := int(gjson.Get(toJson(d2), "trips.0.id").Int()) - fv := "e535eb2b3b9ac3ef15d82c56575e914575e732e0" ctfv := "d2813c293bcfd7a97dde599527ae6c62c98e66c6" testcases := []testCase{ From 3a8d2f77a58ddbe1a00f592e700b23f56496d1f7 Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 20:38:22 -0800 Subject: [PATCH 13/17] WIP --- server/rest/trip_request_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/rest/trip_request_test.go b/server/rest/trip_request_test.go index 125a16eb..5453dbc7 100644 --- a/server/rest/trip_request_test.go +++ b/server/rest/trip_request_test.go @@ -12,7 +12,9 @@ import ( ) func TestTripRequest(t *testing.T) { - graphqlHandler, restHandler, _ := testHandlersWithOptions(t, testconfig.Options{}) + graphqlHandler, restHandler, _ := testHandlersWithOptions(t, testconfig.Options{ + When: "2018-06-01T00:00:00", + }) d, err := makeGraphQLRequest(context.Background(), graphqlHandler, `query{routes(where:{feed_onestop_id:"BA",route_id:"11"}) {id onestop_id}}`, nil) if err != nil { t.Error("failed to get route id for tests") From 26f7c86817c857733a45f12f48091533d39d9534 Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 20:39:05 -0800 Subject: [PATCH 14/17] Fix test --- server/rest/trip_request_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/rest/trip_request_test.go b/server/rest/trip_request_test.go index 5453dbc7..42bfa63c 100644 --- a/server/rest/trip_request_test.go +++ b/server/rest/trip_request_test.go @@ -13,7 +13,8 @@ import ( func TestTripRequest(t *testing.T) { graphqlHandler, restHandler, _ := testHandlersWithOptions(t, testconfig.Options{ - When: "2018-06-01T00:00:00", + When: "2018-06-01T00:00:00", + RTJsons: testconfig.DefaultRTJson(), }) d, err := makeGraphQLRequest(context.Background(), graphqlHandler, `query{routes(where:{feed_onestop_id:"BA",route_id:"11"}) {id onestop_id}}`, nil) if err != nil { From 2bb19891f1323c7ef11a41b42ea375549a33139d Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 20:41:10 -0800 Subject: [PATCH 15/17] Fewer skipped tests --- server/rest/feed_version_download_test.go | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/server/rest/feed_version_download_test.go b/server/rest/feed_version_download_test.go index 7a607075..8c67312c 100644 --- a/server/rest/feed_version_download_test.go +++ b/server/rest/feed_version_download_test.go @@ -12,13 +12,8 @@ import ( ) func TestFeedVersionDownloadRequest(t *testing.T) { - g, a, ok := testutil.CheckEnv("TL_TEST_STORAGE") - if !ok { - t.Skip(a) - return - } _, restSrv, _ := testHandlersWithOptions(t, testconfig.Options{ - Storage: g, + Storage: testutil.RelPath("tmp"), }) t.Run("ok", func(t *testing.T) { @@ -110,13 +105,8 @@ func TestFeedVersionDownloadRequest(t *testing.T) { } func TestFeedDownloadLatestRequest(t *testing.T) { - g, a, ok := testutil.CheckEnv("TL_TEST_STORAGE") - if !ok { - t.Skip(a) - return - } _, restSrv, _ := testHandlersWithOptions(t, testconfig.Options{ - Storage: g, + Storage: testutil.RelPath("tmp"), }) t.Run("ok", func(t *testing.T) { From 64ce5c2012809debfa968262719c2edfc3f260fe Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 20:43:39 -0800 Subject: [PATCH 16/17] Attempt fix --- .github/workflows/test.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 59583fec..f0708f3e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,6 @@ jobs: PGPASSWORD: for_testing PGHOST: localhost PGPORT: 5432 - TL_TEST_STORAGE: /tmp/tlserver TL_TEST_SERVER_DATABASE_URL: postgres://root:for_testing@localhost:5432/tlv2_test_server?sslmode=disable TL_DATABASE_URL: postgres://root:for_testing@localhost:5432/tlv2_test_server?sslmode=disable services: From 9314dc010e36f02e49c14ca6fecd65f45937ace3 Mon Sep 17 00:00:00 2001 From: Ian Rees Date: Mon, 18 Dec 2023 20:55:00 -0800 Subject: [PATCH 17/17] More ctx improvements --- actions/fetch_test.go | 2 +- jobs/jobs.go | 2 -- server/gql/agency_resolver_test.go | 3 ++- server/gql/feed_resolver_test.go | 4 +++- server/gql/route_resolver_test.go | 3 ++- server/rest/agency_request_test.go | 3 ++- server/rest/route_request_test.go | 2 +- server/rest/stop_request_test.go | 2 +- server/rest/trip_request_test.go | 8 +++++--- server/server_cmd.go | 1 - test_setup.sh | 2 +- workers/fetch_enqueue_worker.go | 5 +++-- workers/gbfs_fetch_worker.go | 7 ++++--- workers/gbfs_fetch_worker_test.go | 6 +++--- 14 files changed, 28 insertions(+), 22 deletions(-) diff --git a/actions/fetch_test.go b/actions/fetch_test.go index 214c8a42..2f3b9eb3 100644 --- a/actions/fetch_test.go +++ b/actions/fetch_test.go @@ -125,7 +125,7 @@ func TestStaticFetchWorker(t *testing.T) { // Check output ff := dmfr.FeedFetch{} if err := dbutil.Get( - context.Background(), + ctx, cfg.Finder.DBX(), sq.StatementBuilder. Select("ff.*"). diff --git a/jobs/jobs.go b/jobs/jobs.go index d56a04a5..4f302dda 100644 --- a/jobs/jobs.go +++ b/jobs/jobs.go @@ -7,7 +7,6 @@ import ( "encoding/json" "github.com/interline-io/transitland-lib/tl" - "github.com/interline-io/transitland-server/model" "github.com/rs/zerolog" ) @@ -44,7 +43,6 @@ func (job *Job) HexKey() (string, error) { // JobOptions is configuration passed to worker. type JobOptions struct { - Finders model.Config JobQueue JobQueue Logger zerolog.Logger Secrets []tl.Secret diff --git a/server/gql/agency_resolver_test.go b/server/gql/agency_resolver_test.go index d6398612..1aca4b8a 100644 --- a/server/gql/agency_resolver_test.go +++ b/server/gql/agency_resolver_test.go @@ -9,6 +9,7 @@ import ( "github.com/interline-io/transitland-mw/auth/authz" "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" + "github.com/interline-io/transitland-server/model" ) func TestAgencyResolver(t *testing.T) { @@ -243,7 +244,7 @@ func TestAgencyResolver(t *testing.T) { func TestAgencyResolver_Cursor(t *testing.T) { c, cfg := newTestClient(t) - allEnts, err := cfg.Finder.FindAgencies(context.Background(), nil, nil, nil, nil) + allEnts, err := cfg.Finder.FindAgencies(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/server/gql/feed_resolver_test.go b/server/gql/feed_resolver_test.go index e713b8cc..dee1ae5f 100644 --- a/server/gql/feed_resolver_test.go +++ b/server/gql/feed_resolver_test.go @@ -3,6 +3,8 @@ package gql import ( "context" "testing" + + "github.com/interline-io/transitland-server/model" ) func TestFeedResolver(t *testing.T) { @@ -250,7 +252,7 @@ func TestFeedResolver(t *testing.T) { func TestFeedResolver_Cursor(t *testing.T) { c, cfg := newTestClient(t) - allEnts, err := cfg.Finder.FindFeeds(context.Background(), nil, nil, nil, nil) + allEnts, err := cfg.Finder.FindFeeds(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/server/gql/route_resolver_test.go b/server/gql/route_resolver_test.go index 568faf0e..ab378771 100644 --- a/server/gql/route_resolver_test.go +++ b/server/gql/route_resolver_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) @@ -307,7 +308,7 @@ func TestRouteResolver_PreviousOnestopID(t *testing.T) { func TestRouteResolver_Cursor(t *testing.T) { c, cfg := newTestClient(t) - allEnts, err := cfg.Finder.FindRoutes(context.Background(), nil, nil, nil, nil) + allEnts, err := cfg.Finder.FindRoutes(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/server/rest/agency_request_test.go b/server/rest/agency_request_test.go index 3675441e..64e203ee 100644 --- a/server/rest/agency_request_test.go +++ b/server/rest/agency_request_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/interline-io/transitland-server/internal/testconfig" + "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) @@ -201,7 +202,7 @@ func TestAgencyRequest_Format(t *testing.T) { func TestAgencyRequest_Pagination(t *testing.T) { graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{}) - allEnts, err := cfg.Finder.FindAgencies(context.Background(), nil, nil, nil, nil) + allEnts, err := cfg.Finder.FindAgencies(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/server/rest/route_request_test.go b/server/rest/route_request_test.go index bb92ed92..ba0a9430 100644 --- a/server/rest/route_request_test.go +++ b/server/rest/route_request_test.go @@ -161,7 +161,7 @@ func TestRouteRequest_Format(t *testing.T) { func TestRouteRequest_Pagination(t *testing.T) { graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{}) - allEnts, err := cfg.Finder.FindRoutes(context.Background(), nil, nil, nil, nil) + allEnts, err := cfg.Finder.FindRoutes(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/server/rest/stop_request_test.go b/server/rest/stop_request_test.go index f5849d28..17961baf 100644 --- a/server/rest/stop_request_test.go +++ b/server/rest/stop_request_test.go @@ -241,7 +241,7 @@ func TestStopRequest_Format(t *testing.T) { func TestStopRequest_Pagination(t *testing.T) { graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{}) - allEnts, err := cfg.Finder.FindStops(context.Background(), nil, nil, nil, nil) + allEnts, err := cfg.Finder.FindStops(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/server/rest/trip_request_test.go b/server/rest/trip_request_test.go index 42bfa63c..facd0dcf 100644 --- a/server/rest/trip_request_test.go +++ b/server/rest/trip_request_test.go @@ -7,22 +7,24 @@ import ( "testing" "github.com/interline-io/transitland-server/internal/testconfig" + "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) func TestTripRequest(t *testing.T) { - graphqlHandler, restHandler, _ := testHandlersWithOptions(t, testconfig.Options{ + graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{ When: "2018-06-01T00:00:00", RTJsons: testconfig.DefaultRTJson(), }) - d, err := makeGraphQLRequest(context.Background(), graphqlHandler, `query{routes(where:{feed_onestop_id:"BA",route_id:"11"}) {id onestop_id}}`, nil) + ctx := model.WithConfig(context.Background(), cfg) + d, err := makeGraphQLRequest(ctx, graphqlHandler, `query{routes(where:{feed_onestop_id:"BA",route_id:"11"}) {id onestop_id}}`, nil) if err != nil { t.Error("failed to get route id for tests") } routeId := int(gjson.Get(toJson(d), "routes.0.id").Int()) routeOnestopId := gjson.Get(toJson(d), "routes.0.onestop_id").String() - d2, err := makeGraphQLRequest(context.Background(), graphqlHandler, `query{trips(where:{trip_id:"5132248WKDY"}){id}}`, nil) + d2, err := makeGraphQLRequest(ctx, graphqlHandler, `query{trips(where:{trip_id:"5132248WKDY"}){id}}`, nil) if err != nil { t.Error("failed to get route id for tests") } diff --git a/server/server_cmd.go b/server/server_cmd.go index f238076f..76e643b0 100644 --- a/server/server_cmd.go +++ b/server/server_cmd.go @@ -338,7 +338,6 @@ func (cmd *Command) Run() error { // Start workers/api jobWorkers := 8 jobOptions := jobs.JobOptions{ - Finders: cfg, Logger: log.Logger, JobQueue: jobQueue, } diff --git a/test_setup.sh b/test_setup.sh index 43c46e26..0229a965 100755 --- a/test_setup.sh +++ b/test_setup.sh @@ -1,6 +1,6 @@ #!/bin/bash # Remove import files -TL_TEST_STORAGE="${TL_TEST_STORAGE:-tmp}" +TL_TEST_STORAGE="${PWD}/tmp" mkdir -p "${TL_TEST_STORAGE}"; rm ${TL_TEST_STORAGE}/*.zip # export TL_LOG=debug (cd cmd/tlserver && go install .) diff --git a/workers/fetch_enqueue_worker.go b/workers/fetch_enqueue_worker.go index 3f28788f..100a109f 100644 --- a/workers/fetch_enqueue_worker.go +++ b/workers/fetch_enqueue_worker.go @@ -16,10 +16,11 @@ type FetchEnqueueWorker struct { } func (w *FetchEnqueueWorker) Run(ctx context.Context, job jobs.Job) error { + cfg := model.ForContext(ctx) + db := cfg.Finder.DBX() opts := job.Opts - db := opts.Finders.Finder.DBX() now := time.Now().In(time.UTC) - feeds, err := job.Opts.Finders.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{}) + feeds, err := cfg.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{}) if err != nil { return err } diff --git a/workers/gbfs_fetch_worker.go b/workers/gbfs_fetch_worker.go index e5e2d08a..e1a473b0 100644 --- a/workers/gbfs_fetch_worker.go +++ b/workers/gbfs_fetch_worker.go @@ -19,8 +19,9 @@ type GbfsFetchWorker struct { } func (w *GbfsFetchWorker) Run(ctx context.Context, job jobs.Job) error { + cfg := model.ForContext(ctx) log := job.Opts.Logger.With().Str("feed_id", w.FeedID).Str("url", w.Url).Logger() - gfeeds, err := job.Opts.Finders.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &w.FeedID}) + gfeeds, err := cfg.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &w.FeedID}) if err != nil { log.Error().Err(err).Msg("gbfsfetch worker: error loading source feed") return err @@ -40,7 +41,7 @@ func (w *GbfsFetchWorker) Run(ctx context.Context, job jobs.Job) error { opts.FeedURL = w.Url } feeds, result, err := gbfs.Fetch( - tldb.NewPostgresAdapterFromDBX(job.Opts.Finders.Finder.DBX()), + tldb.NewPostgresAdapterFromDBX(cfg.Finder.DBX()), opts, ) if err != nil { @@ -54,7 +55,7 @@ func (w *GbfsFetchWorker) Run(ctx context.Context, job jobs.Job) error { for _, feed := range feeds { if feed.SystemInformation != nil { key := fmt.Sprintf("%s:%s", w.FeedID, feed.SystemInformation.Language.Val) - job.Opts.Finders.GbfsFinder.AddData(ctx, key, feed) + cfg.GbfsFinder.AddData(ctx, key, feed) } } log.Info().Msg("gbfs fetch worker: success") diff --git a/workers/gbfs_fetch_worker_test.go b/workers/gbfs_fetch_worker_test.go index 148b4c40..a6936244 100644 --- a/workers/gbfs_fetch_worker_test.go +++ b/workers/gbfs_fetch_worker_test.go @@ -20,18 +20,18 @@ func TestGbfsFetchWorker(t *testing.T) { testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { job := jobs.Job{} - job.Opts.Finders = cfg w := GbfsFetchWorker{ Url: ts.URL + "/gbfs.json", FeedID: "test-gbfs", } - err := w.Run(context.Background(), job) + ctx := model.WithConfig(context.Background(), cfg) + err := w.Run(ctx, job) if err != nil { t.Fatal(err) } // Test bikes, err := cfg.GbfsFinder.FindBikes( - context.Background(), + ctx, nil, &model.GbfsBikeRequest{ Near: &model.PointRadius{