diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 59583fec..f0708f3e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,6 @@ jobs: PGPASSWORD: for_testing PGHOST: localhost PGPORT: 5432 - TL_TEST_STORAGE: /tmp/tlserver TL_TEST_SERVER_DATABASE_URL: postgres://root:for_testing@localhost:5432/tlv2_test_server?sslmode=disable TL_DATABASE_URL: postgres://root:for_testing@localhost:5432/tlv2_test_server?sslmode=disable services: diff --git a/actions/fetch.go b/actions/fetch.go index 25406e25..35d76caf 100644 --- a/actions/fetch.go +++ b/actions/fetch.go @@ -18,16 +18,18 @@ import ( "github.com/interline-io/transitland-lib/tldb" "github.com/interline-io/transitland-mw/auth/authn" "github.com/interline-io/transitland-mw/auth/authz" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/dbutil" "github.com/interline-io/transitland-server/model" "github.com/jmoiron/sqlx" "google.golang.org/protobuf/proto" ) -func StaticFetch(ctx context.Context, cfg config.Config, dbf model.Finder, feedId string, feedSrc io.Reader, feedUrl string, checker model.Checker) (*model.FeedVersionFetchResult, error) { +func StaticFetch(ctx context.Context, feedId string, feedSrc io.Reader, feedUrl string) (*model.FeedVersionFetchResult, error) { + cfg := model.ForContext(ctx) + dbf := cfg.Finder + urlType := "static_current" - feed, err := fetchCheckFeed(ctx, dbf, checker, feedId, urlType, feedUrl) + feed, err := fetchCheckFeed(ctx, feedId, urlType, feedUrl) if err != nil { return nil, err } @@ -86,8 +88,10 @@ func StaticFetch(ctx context.Context, cfg config.Config, dbf model.Finder, feedI return &mr, nil } -func RTFetch(ctx context.Context, cfg config.Config, dbf model.Finder, rtf model.RTFinder, target string, feedId string, feedUrl string, urlType string, checker model.Checker) error { - feed, err := fetchCheckFeed(ctx, dbf, checker, feedId, urlType, feedUrl) +func RTFetch(ctx context.Context, target string, feedId string, feedUrl string, urlType string) error { + cfg := model.ForContext(ctx) + + feed, err := fetchCheckFeed(ctx, feedId, urlType, feedUrl) if err != nil { return err } @@ -108,7 +112,7 @@ func RTFetch(ctx context.Context, cfg config.Config, dbf model.Finder, rtf model // Make request var rtMsg *pb.FeedMessage var fetchErr error - if err := tldb.NewPostgresAdapterFromDBX(dbf.DBX()).Tx(func(atx tldb.Adapter) error { + if err := tldb.NewPostgresAdapterFromDBX(cfg.Finder.DBX()).Tx(func(atx tldb.Adapter) error { m, fr, err := fetch.RTFetch(atx, fetchOpts) if err != nil { return err @@ -129,7 +133,7 @@ func RTFetch(ctx context.Context, cfg config.Config, dbf model.Finder, rtf model return errors.New("invalid rt data") } key := fmt.Sprintf("rtdata:%s:%s", target, urlType) - return rtf.AddData(key, rtdata) + return cfg.RTFinder.AddData(key, rtdata) } type CheckFetchWaitResult struct { @@ -216,7 +220,7 @@ func CheckFetchWaitBatch(ctx context.Context, db sqlx.Ext, feedIds []int, urlTyp return nil, err } for _, fetch := range lastFetches { - a, _ := checks[fetch.ID] + a := checks[fetch.ID] a.CheckedAt = now a.ID = fetch.ID a.OnestopID = fetch.OnestopID @@ -250,9 +254,12 @@ func chunkBy[T any](items []T, chunkSize int) (chunks [][]T) { return append(chunks, items) } -func fetchCheckFeed(ctx context.Context, dbf model.Finder, checker model.Checker, feedId string, urlType string, url string) (*model.Feed, error) { +func fetchCheckFeed(ctx context.Context, feedId string, urlType string, url string) (*model.Feed, error) { + cfg := model.ForContext(ctx) + checker := cfg.Checker + // Check feed exists - feeds, err := dbf.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &feedId}) + feeds, err := cfg.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &feedId}) if err != nil { return nil, err } diff --git a/actions/fetch_test.go b/actions/fetch_test.go index c1b0d98b..2f3b9eb3 100644 --- a/actions/fetch_test.go +++ b/actions/fetch_test.go @@ -2,16 +2,16 @@ package actions import ( "context" - "io/ioutil" "net/http" "net/http/httptest" + "os" "testing" sq "github.com/Masterminds/squirrel" "github.com/interline-io/transitland-lib/dmfr" "github.com/interline-io/transitland-server/internal/dbutil" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" @@ -99,7 +99,7 @@ func TestStaticFetchWorker(t *testing.T) { return } - buf, err := ioutil.ReadFile(testutil.RelPath(tc.serveFile)) + buf, err := os.ReadFile(testutil.RelPath(tc.serveFile)) if err != nil { http.Error(w, "404", 404) return @@ -110,9 +110,11 @@ func TestStaticFetchWorker(t *testing.T) { // Setup job feedUrl := ts.URL + "/" + tc.serveFile - testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { + cfg.Checker = nil // disable checker for this test + ctx := model.WithConfig(context.Background(), cfg) // Run job - if result, err := StaticFetch(context.Background(), te.Config, te.Finder, tc.feedId, nil, feedUrl, nil); err != nil && !tc.expectError { + if result, err := StaticFetch(ctx, tc.feedId, nil, feedUrl); err != nil && !tc.expectError { _ = result t.Fatal("unexpected error", err) } else if err == nil && tc.expectError { @@ -123,8 +125,8 @@ func TestStaticFetchWorker(t *testing.T) { // Check output ff := dmfr.FeedFetch{} if err := dbutil.Get( - context.Background(), - te.Finder.DBX(), + ctx, + cfg.Finder.DBX(), sq.StatementBuilder. Select("ff.*"). From("feed_fetches ff"). diff --git a/actions/fv.go b/actions/fv.go index a5956d5a..a0db32bd 100644 --- a/actions/fv.go +++ b/actions/fv.go @@ -10,11 +10,13 @@ import ( "github.com/interline-io/transitland-lib/tl/tt" "github.com/interline-io/transitland-lib/tldb" "github.com/interline-io/transitland-mw/auth/authz" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/model" ) -func FeedVersionImport(ctx context.Context, cfg config.Config, dbf model.Finder, checker model.Checker, fvid int) (*model.FeedVersionImportResult, error) { +func FeedVersionImport(ctx context.Context, fvid int) (*model.FeedVersionImportResult, error) { + cfg := model.ForContext(ctx) + checker := cfg.Checker + dbf := cfg.Finder if checker == nil { return nil, authz.ErrUnauthorized } @@ -38,7 +40,10 @@ func FeedVersionImport(ctx context.Context, cfg config.Config, dbf model.Finder, return &mr, nil } -func FeedVersionUnimport(ctx context.Context, cfg config.Config, dbf model.Finder, checker model.Checker, fvid int) (*model.FeedVersionUnimportResult, error) { +func FeedVersionUnimport(ctx context.Context, fvid int) (*model.FeedVersionUnimportResult, error) { + cfg := model.ForContext(ctx) + checker := cfg.Checker + dbf := cfg.Finder if checker == nil { return nil, authz.ErrUnauthorized } @@ -59,7 +64,10 @@ func FeedVersionUnimport(ctx context.Context, cfg config.Config, dbf model.Finde return &mr, nil } -func FeedVersionUpdate(ctx context.Context, cfg config.Config, dbf model.Finder, checker model.Checker, fvid int, values model.FeedVersionSetInput) error { +func FeedVersionUpdate(ctx context.Context, fvid int, values model.FeedVersionSetInput) error { + cfg := model.ForContext(ctx) + checker := cfg.Checker + dbf := cfg.Finder if checker == nil { return authz.ErrUnauthorized } @@ -93,7 +101,9 @@ func FeedVersionUpdate(ctx context.Context, cfg config.Config, dbf model.Finder, return nil } -func FeedVersionDelete(ctx context.Context, cfg config.Config, dbf model.Finder, checker model.Checker, fvid int) (*model.FeedVersionDeleteResult, error) { +func FeedVersionDelete(ctx context.Context, fvid int) (*model.FeedVersionDeleteResult, error) { + cfg := model.ForContext(ctx) + checker := cfg.Checker if checker == nil { return nil, authz.ErrUnauthorized } diff --git a/actions/validate.go b/actions/validate.go index bdf8826f..b8eba9e1 100644 --- a/actions/validate.go +++ b/actions/validate.go @@ -12,7 +12,6 @@ import ( "github.com/interline-io/transitland-lib/tl/tt" "github.com/interline-io/transitland-lib/tlcsv" "github.com/interline-io/transitland-lib/validator" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/model" ) @@ -21,7 +20,9 @@ type hasGeometries interface { } // ValidateUpload takes a file Reader and produces a validation package containing errors, warnings, file infos, service levels, etc. -func ValidateUpload(ctx context.Context, cfg config.Config, src io.Reader, feedURL *string, rturls []string) (*model.ValidationResult, error) { +func ValidateUpload(ctx context.Context, src io.Reader, feedURL *string, rturls []string) (*model.ValidationResult, error) { + cfg := model.ForContext(ctx) + // Check inputs rturlsok := []string{} for _, rturl := range rturls { diff --git a/config/config.go b/config/config.go deleted file mode 100644 index b5b4bdb8..00000000 --- a/config/config.go +++ /dev/null @@ -1,20 +0,0 @@ -package config - -import ( - "github.com/interline-io/transitland-lib/tl" - "github.com/interline-io/transitland-server/internal/clock" -) - -// Config is in a separate package to avoid import cycles. - -type Config struct { - Storage string - RTStorage string - ValidateLargeFiles bool - DisableImage bool - RestPrefix string - DBURL string - RedisURL string - Clock clock.Clock - Secrets []tl.Secret -} diff --git a/finders/dbfinder/finder_test.go b/finders/dbfinder/finder_test.go index c02b0011..1b1d5682 100644 --- a/finders/dbfinder/finder_test.go +++ b/finders/dbfinder/finder_test.go @@ -2,31 +2,19 @@ package dbfinder import ( "context" - "log" - "os" "testing" "github.com/interline-io/transitland-server/internal/testutil" - "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" ) -var TestFinder model.Finder - -func TestMain(m *testing.M) { - if a, ok := testutil.CheckTestDB(); !ok { - log.Print(a) - return - } +func TestFinder_FindFeedVersionServiceWindow(t *testing.T) { db := testutil.MustOpenTestDB() dbf := NewFinder(db, nil) - TestFinder = dbf - os.Exit(m.Run()) -} + testFinder := dbf -func TestFinder_FindFeedVersionServiceWindow(t *testing.T) { fvm := map[string]int{} - fvs, err := TestFinder.FindFeedVersions(context.TODO(), nil, nil, nil, nil) + fvs, err := testFinder.FindFeedVersions(context.TODO(), nil, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -64,7 +52,7 @@ func TestFinder_FindFeedVersionServiceWindow(t *testing.T) { } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - start, end, best, err := TestFinder.FindFeedVersionServiceWindow(context.TODO(), tc.fvid) + start, end, best, err := testFinder.FindFeedVersionServiceWindow(context.TODO(), tc.fvid) if err != nil { t.Fatal(err) } diff --git a/finders/gbfsfinder/finder_test.go b/finders/gbfsfinder/finder_test.go index c0cf6db4..61d74575 100644 --- a/finders/gbfsfinder/finder_test.go +++ b/finders/gbfsfinder/finder_test.go @@ -20,7 +20,7 @@ func TestGbfsFinder(t *testing.T) { } client := testutil.MustOpenTestRedisClient() gbf := NewFinder(client) - setupGbfs(nil, gbf) + testSetupGbfs(gbf) tcs := []struct { p xy.Point @@ -58,7 +58,7 @@ func TestGbfsFinder(t *testing.T) { } -func setupGbfs(dbf model.Finder, gbf model.GbfsFinder) error { +func testSetupGbfs(gbf model.GbfsFinder) error { // Setup sourceFeedId := "gbfs-test" ts := httptest.NewServer(&gbfs.TestGbfsServer{Language: "en", Path: testutil.RelPath("test/data/gbfs")}) diff --git a/internal/testfinder/testfinder.go b/internal/testconfig/testconfig.go similarity index 67% rename from internal/testfinder/testfinder.go rename to internal/testconfig/testconfig.go index 5107052b..b4a03ffd 100644 --- a/internal/testfinder/testfinder.go +++ b/internal/testconfig/testconfig.go @@ -1,4 +1,4 @@ -package testfinder +package testconfig import ( "context" @@ -6,11 +6,11 @@ import ( "fmt" "os" "testing" + "time" "github.com/interline-io/transitland-lib/rt" "github.com/interline-io/transitland-mw/auth/authz" "github.com/interline-io/transitland-mw/auth/azcheck" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/finders/dbfinder" "github.com/interline-io/transitland-server/finders/gbfsfinder" "github.com/interline-io/transitland-server/finders/rtfinder" @@ -23,23 +23,72 @@ import ( // Test helpers -type TestFinderOptions struct { - Clock clock.Clock +type Options struct { + When string + Storage string + RTStorage string RTJsons []RTJsonFile FGAModelFile string FGAModelTuples []authz.TupleKey } -func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Finders { - if opts.Clock == nil { - opts.Clock = &clock.Real{} +func Config(t testing.TB, opts Options) model.Config { + db := testutil.MustOpenTestDB() + return newTestConfig(t, db, opts) +} + +func ConfigTx(t testing.TB, opts Options, cb func(model.Config) error) { + // Check open DB + db := testutil.MustOpenTestDB() + + // Start Txn + tx := db.MustBeginTx(context.Background(), nil) + defer tx.Rollback() + + // Get finders + testEnv := newTestConfig(t, tx, opts) + + // Commit or rollback + if err := cb(testEnv); err != nil { + //tx.Rollback() + } else { + tx.Commit() } - cfg := config.Config{ - Clock: opts.Clock, - Storage: t.TempDir(), - RTStorage: t.TempDir(), +} + +func ConfigTxRollback(t testing.TB, opts Options, cb func(model.Config)) { + ConfigTx(t, opts, func(c model.Config) error { + cb(c) + return errors.New("rollback") + }) +} + +type RTJsonFile struct { + Feed string + Ftype string + Fname string +} + +func DefaultRTJson() []RTJsonFile { + return []RTJsonFile{ + {"BA", "realtime_trip_updates", "BA.json"}, + {"BA", "realtime_alerts", "BA-alerts.json"}, + {"CT", "realtime_trip_updates", "CT.json"}, + } +} + +func newTestConfig(t testing.TB, db sqlx.Ext, opts Options) model.Config { + // Default time + if opts.When == "" { + opts.When = "2022-09-01T00:00:00" } + when, err := time.Parse("2006-01-02T15:04:05", opts.When) + if err != nil { + t.Fatal(err) + } + cl := &clock.Mock{T: when} + // Setup Checker checkerCfg := azcheck.CheckerConfig{ FGAEndpoint: os.Getenv("TL_TEST_FGA_ENDPOINT"), @@ -53,11 +102,11 @@ func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Finders // Setup DB dbf := dbfinder.NewFinder(db, checker) - dbf.Clock = opts.Clock + dbf.Clock = cl // Setup RT rtf := rtfinder.NewFinder(rtfinder.NewLocalCache(), db) - rtf.Clock = opts.Clock + rtf.Clock = cl for _, rtj := range opts.RTJsons { fn := testutil.RelPath("test", "data", "rt", rtj.Fname) msg, err := rt.ReadFile(fn) @@ -77,60 +126,19 @@ func newFinders(t testing.TB, db sqlx.Ext, opts TestFinderOptions) model.Finders // Setup GBFS gbf := gbfsfinder.NewFinder(nil) - return model.Finders{ - Config: cfg, + if opts.Storage == "" { + opts.Storage = t.TempDir() + } + if opts.RTStorage == "" { + opts.RTStorage = t.TempDir() + } + return model.Config{ Finder: dbf, RTFinder: rtf, GbfsFinder: gbf, Checker: checker, - } -} - -func Finders(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile) model.Finders { - db := testutil.MustOpenTestDB() - return newFinders(t, db, TestFinderOptions{Clock: cl, RTJsons: rtJsons}) -} - -func FindersWithOptions(t testing.TB, opts TestFinderOptions) model.Finders { - db := testutil.MustOpenTestDB() - return newFinders(t, db, opts) -} - -func FindersTx(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile, cb func(model.Finders) error) { - // Check open DB - db := testutil.MustOpenTestDB() - // Start Txn - tx := db.MustBeginTx(context.Background(), nil) - defer tx.Rollback() - - // Get finders - testEnv := newFinders(t, tx, TestFinderOptions{Clock: cl, RTJsons: rtJsons}) - - // Commit or rollback - if err := cb(testEnv); err != nil { - //tx.Rollback() - } else { - tx.Commit() - } -} - -func FindersTxRollback(t testing.TB, cl clock.Clock, rtJsons []RTJsonFile, cb func(model.Finders)) { - FindersTx(t, cl, rtJsons, func(c model.Finders) error { - cb(c) - return errors.New("rollback") - }) -} - -type RTJsonFile struct { - Feed string - Ftype string - Fname string -} - -func DefaultRTJson() []RTJsonFile { - return []RTJsonFile{ - {"BA", "realtime_trip_updates", "BA.json"}, - {"BA", "realtime_alerts", "BA-alerts.json"}, - {"CT", "realtime_trip_updates", "CT.json"}, + Clock: cl, + Storage: opts.Storage, + RTStorage: opts.RTStorage, } } diff --git a/jobs/jobs.go b/jobs/jobs.go index 1c0ba54c..4f302dda 100644 --- a/jobs/jobs.go +++ b/jobs/jobs.go @@ -7,8 +7,6 @@ import ( "encoding/json" "github.com/interline-io/transitland-lib/tl" - "github.com/interline-io/transitland-server/config" - "github.com/interline-io/transitland-server/model" "github.com/rs/zerolog" ) @@ -45,13 +43,9 @@ func (job *Job) HexKey() (string, error) { // JobOptions is configuration passed to worker. type JobOptions struct { - Finder model.Finder - RTFinder model.RTFinder - GbfsFinder model.GbfsFinder - JobQueue JobQueue - Logger zerolog.Logger - Config config.Config - Secrets []tl.Secret + JobQueue JobQueue + Logger zerolog.Logger + Secrets []tl.Secret } // GetWorker returns a new worker for this job type diff --git a/model/config.go b/model/config.go new file mode 100644 index 00000000..c1b47eee --- /dev/null +++ b/model/config.go @@ -0,0 +1,51 @@ +package model + +import ( + "context" + "net/http" + + "github.com/interline-io/transitland-lib/tl" + "github.com/interline-io/transitland-server/internal/clock" + "github.com/rs/zerolog" +) + +type Config struct { + Finder Finder + RTFinder RTFinder + GbfsFinder GbfsFinder + Checker Checker + Clock clock.Clock + Secrets []tl.Secret + ValidateLargeFiles bool + Storage string + RTStorage string + Logger zerolog.Logger +} + +var finderCtxKey = &contextKey{"finderConfig"} + +type contextKey struct { + name string +} + +func ForContext(ctx context.Context) Config { + raw, ok := ctx.Value(finderCtxKey).(Config) + if !ok { + return Config{} + } + return raw +} + +func WithConfig(ctx context.Context, cfg Config) context.Context { + r := context.WithValue(ctx, finderCtxKey, cfg) + return r +} + +func AddConfig(cfg Config) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(WithConfig(r.Context(), cfg)) + next.ServeHTTP(w, r) + }) + } +} diff --git a/model/finders.go b/model/finders.go index 280e71ab..dc6b409e 100644 --- a/model/finders.go +++ b/model/finders.go @@ -7,20 +7,11 @@ import ( "github.com/interline-io/transitland-lib/rt/pb" "github.com/interline-io/transitland-lib/tl/tt" "github.com/interline-io/transitland-mw/auth/authz" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/gbfs" "github.com/jmoiron/sqlx" ) -type Finders struct { - Config config.Config - Finder Finder - RTFinder RTFinder - GbfsFinder GbfsFinder - Checker Checker -} - // Finder provides all necessary database methods type Finder interface { PermFinder diff --git a/server/gql/agency_resolver.go b/server/gql/agency_resolver.go index e8b903bf..766ab272 100644 --- a/server/gql/agency_resolver.go +++ b/server/gql/agency_resolver.go @@ -35,6 +35,6 @@ func (r *agencyResolver) Operator(ctx context.Context, obj *model.Agency) (*mode } func (r *agencyResolver) Alerts(ctx context.Context, obj *model.Agency, active *bool, limit *int) ([]*model.Alert, error) { - rtAlerts := r.rtfinder.FindAlertsForAgency(obj, checkLimit(limit), active) + rtAlerts := model.ForContext(ctx).RTFinder.FindAlertsForAgency(obj, checkLimit(limit), active) return rtAlerts, nil } diff --git a/server/gql/agency_resolver_test.go b/server/gql/agency_resolver_test.go index c911da33..1aca4b8a 100644 --- a/server/gql/agency_resolver_test.go +++ b/server/gql/agency_resolver_test.go @@ -7,8 +7,9 @@ import ( "github.com/99designs/gqlgen/client" "github.com/interline-io/transitland-mw/auth/ancheck" "github.com/interline-io/transitland-mw/auth/authz" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" + "github.com/interline-io/transitland-server/model" ) func TestAgencyResolver(t *testing.T) { @@ -242,8 +243,8 @@ func TestAgencyResolver(t *testing.T) { } func TestAgencyResolver_Cursor(t *testing.T) { - c, te := newTestClient(t) - allEnts, err := te.Finder.FindAgencies(context.Background(), nil, nil, nil, nil) + c, cfg := newTestClient(t) + allEnts, err := cfg.Finder.FindAgencies(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -310,12 +311,12 @@ func TestAgencyResolver_Authz(t *testing.T) { t.Skip(a) return } - teOpts := testfinder.TestFinderOptions{ + cfg := testconfig.Config(t, testconfig.Options{ FGAModelFile: testutil.RelPath("test/authz/tls.json"), FGAModelTuples: fgaTestTuples, - } - te := testfinder.FindersWithOptions(t, teOpts) - srv, _ := NewServer(te.Config, te.Finder, te.RTFinder, te.GbfsFinder, te.Checker) + }) + _ = cfg + srv, _ := NewServer() testcases := []testcase{ { name: "basic", diff --git a/server/gql/feed_resolver_test.go b/server/gql/feed_resolver_test.go index b671adf6..dee1ae5f 100644 --- a/server/gql/feed_resolver_test.go +++ b/server/gql/feed_resolver_test.go @@ -3,6 +3,8 @@ package gql import ( "context" "testing" + + "github.com/interline-io/transitland-server/model" ) func TestFeedResolver(t *testing.T) { @@ -249,8 +251,8 @@ func TestFeedResolver(t *testing.T) { } func TestFeedResolver_Cursor(t *testing.T) { - c, te := newTestClient(t) - allEnts, err := te.Finder.FindFeeds(context.Background(), nil, nil, nil, nil) + c, cfg := newTestClient(t) + allEnts, err := cfg.Finder.FindFeeds(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/server/gql/fvsl_cache.go b/server/gql/fvsl_cache.go index 9c21b3b4..40a1d12c 100644 --- a/server/gql/fvsl_cache.go +++ b/server/gql/fvsl_cache.go @@ -23,38 +23,38 @@ type fvslCache struct { fvWindows map[int]fvslWindow } -func newFvslCache(f model.Finder) *fvslCache { +func newFvslCache() *fvslCache { return &fvslCache{ - Finder: f, fvWindows: map[int]fvslWindow{}, } } -func (f *fvslCache) Get(fvid int) (fvslWindow, bool) { +func (f *fvslCache) Get(ctx context.Context, fvid int) (fvslWindow, bool) { f.lock.Lock() a, ok := f.fvWindows[fvid] f.lock.Unlock() if ok { return a, ok } - a, err := f.query(fvid) + a, err := f.query(ctx, fvid) if err != nil { a.Valid = false } - f.Set(fvid, a) + f.Set(ctx, fvid, a) return a, a.Valid } -func (f *fvslCache) Set(fvid int, w fvslWindow) { +func (f *fvslCache) Set(ctx context.Context, fvid int, w fvslWindow) { f.lock.Lock() defer f.lock.Unlock() f.fvWindows[fvid] = w } -func (f *fvslCache) query(fvid int) (fvslWindow, error) { +func (f *fvslCache) query(ctx context.Context, fvid int) (fvslWindow, error) { + cfg := model.ForContext(ctx) var err error w := fvslWindow{} - w.StartDate, w.EndDate, w.BestWeek, err = f.Finder.FindFeedVersionServiceWindow(context.TODO(), fvid) + w.StartDate, w.EndDate, w.BestWeek, err = cfg.Finder.FindFeedVersionServiceWindow(context.TODO(), fvid) log.Trace(). Str("start_date", w.StartDate.Format("2006-01-02")). Str("end_date", w.EndDate.Format("2006-01-02")). diff --git a/server/gql/fvsl_cache_test.go b/server/gql/fvsl_cache_test.go index 716b470a..0f39d508 100644 --- a/server/gql/fvsl_cache_test.go +++ b/server/gql/fvsl_cache_test.go @@ -1,13 +1,15 @@ package gql import ( + "context" "testing" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" + "github.com/interline-io/transitland-server/model" ) func TestFvslCache(t *testing.T) { - te := testfinder.Finders(t, nil, nil) - c := newFvslCache(te.Finder) - c.Get(1) + cfg := testconfig.Config(t, testconfig.Options{}) + c := newFvslCache() + c.Get(model.WithConfig(context.Background(), cfg), 1) } diff --git a/server/gql/gbfs_resolver.go b/server/gql/gbfs_resolver.go index e74490e1..cd5252a6 100644 --- a/server/gql/gbfs_resolver.go +++ b/server/gql/gbfs_resolver.go @@ -7,9 +7,9 @@ import ( ) func (r *queryResolver) Bikes(ctx context.Context, limit *int, where *model.GbfsBikeRequest) ([]*model.GbfsFreeBikeStatus, error) { - return r.gbfsFinder.FindBikes(ctx, checkLimit(limit), where) + return model.ForContext(ctx).GbfsFinder.FindBikes(ctx, checkLimit(limit), where) } func (r *queryResolver) Docks(ctx context.Context, limit *int, where *model.GbfsDockRequest) ([]*model.GbfsStationInformation, error) { - return r.gbfsFinder.FindDocks(ctx, checkLimit(limit), where) + return model.ForContext(ctx).GbfsFinder.FindDocks(ctx, checkLimit(limit), where) } diff --git a/server/gql/gbfs_resolver_test.go b/server/gql/gbfs_resolver_test.go index 24f5f3ed..4be34ca5 100644 --- a/server/gql/gbfs_resolver_test.go +++ b/server/gql/gbfs_resolver_test.go @@ -77,8 +77,8 @@ func TestGbfsBikeResolver(t *testing.T) { selectExpect: []string{"0cbf9b08f8b71a6362e20c8173c071a6"}, }, } - c, te := newTestClient(t) - setupGbfs(te.GbfsFinder) + c, cfg := newTestClient(t) + setupGbfs(cfg.GbfsFinder) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { queryTestcase(t, c, tc) @@ -210,8 +210,8 @@ func TestGbfsStationResolver(t *testing.T) { selectExpect: []string{"27045384-791c-4519-8087-fce2f7c48a69"}, }, } - c, te := newTestClient(t) - setupGbfs(te.GbfsFinder) + c, cfg := newTestClient(t) + setupGbfs(cfg.GbfsFinder) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { queryTestcase(t, c, tc) diff --git a/server/gql/loaders.go b/server/gql/loaders.go index 2a6013ec..e1a7de40 100644 --- a/server/gql/loaders.go +++ b/server/gql/loaders.go @@ -9,7 +9,6 @@ import ( dataloader "github.com/graph-gophers/dataloader/v7" "github.com/interline-io/log" "github.com/interline-io/transitland-lib/tl/tt" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/model" ) @@ -136,11 +135,12 @@ func NewLoaders(dbf model.Finder) *Loaders { return loaders } -func loaderMiddleware(cfg config.Config, finder model.Finder, next http.Handler) http.Handler { +func loaderMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // This is per request scoped loaders/cache // Is this OK to use as a long term cache? - loaders := NewLoaders(finder) + cfg := model.ForContext(r.Context()) + loaders := NewLoaders(cfg.Finder) nextCtx := context.WithValue(r.Context(), loadersKey, loaders) r = r.WithContext(nextCtx) next.ServeHTTP(w, r) diff --git a/server/gql/mutation_resolver.go b/server/gql/mutation_resolver.go index 35f50f1a..ffe0f222 100644 --- a/server/gql/mutation_resolver.go +++ b/server/gql/mutation_resolver.go @@ -20,7 +20,7 @@ func (r *mutationResolver) ValidateGtfs(ctx context.Context, file *graphql.Uploa if file != nil { src = file.File } - return actions.ValidateUpload(ctx, r.cfg, src, url, rturls) + return actions.ValidateUpload(ctx, src, url, rturls) } func (r *mutationResolver) FeedVersionFetch(ctx context.Context, file *graphql.Upload, url *string, feedId string) (*model.FeedVersionFetchResult, error) { @@ -32,19 +32,19 @@ func (r *mutationResolver) FeedVersionFetch(ctx context.Context, file *graphql.U if url != nil { feedUrl = *url } - return actions.StaticFetch(ctx, r.cfg, r.finder, feedId, feedSrc, feedUrl, r.authzChecker) + return actions.StaticFetch(ctx, feedId, feedSrc, feedUrl) } func (r *mutationResolver) FeedVersionImport(ctx context.Context, fvid int) (*model.FeedVersionImportResult, error) { - return actions.FeedVersionImport(ctx, r.cfg, r.finder, r.authzChecker, fvid) + return actions.FeedVersionImport(ctx, fvid) } func (r *mutationResolver) FeedVersionUnimport(ctx context.Context, fvid int) (*model.FeedVersionUnimportResult, error) { - return actions.FeedVersionUnimport(ctx, r.cfg, r.finder, r.authzChecker, fvid) + return actions.FeedVersionUnimport(ctx, fvid) } func (r *mutationResolver) FeedVersionUpdate(ctx context.Context, fvid int, values model.FeedVersionSetInput) (*model.FeedVersion, error) { - err := actions.FeedVersionUpdate(ctx, r.cfg, r.finder, r.authzChecker, fvid, values) + err := actions.FeedVersionUpdate(ctx, fvid, values) return nil, err } diff --git a/server/gql/mutation_resolver_test.go b/server/gql/mutation_resolver_test.go index eb0bce0d..26841ed7 100644 --- a/server/gql/mutation_resolver_test.go +++ b/server/gql/mutation_resolver_test.go @@ -10,7 +10,7 @@ import ( "github.com/99designs/gqlgen/client" "github.com/interline-io/transitland-mw/auth/ancheck" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" @@ -26,8 +26,9 @@ func TestFeedVersionFetchResolver(t *testing.T) { w.Write(buf) })) t.Run("found sha1", func(t *testing.T) { - testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { - srv, _ := NewServer(te.Config, te.Finder, nil, nil, nil) + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { + srv, _ := NewServer() + srv = model.AddConfig(cfg)(srv) srv = ancheck.AdminDefaultMiddleware("test")(srv) // Run all requests as admin // Run all requests as admin c := client.New(srv) @@ -40,7 +41,7 @@ func TestFeedVersionFetchResolver(t *testing.T) { }) }) // t.Run("requires admin access", func(t *testing.T) { - // testfinder.FindersTxRollback(t, nil, nil, func(te testfinder.TestEnv) { + // testconfig.ConfigTxRollback(t, nil, nil, func(te testconfig.TestEnv) { // srv, _ := NewServer(te.Config, te.Finder, nil, nil, nil) // srv = authn.UserDefaultMiddleware("test")(srv) // Run all requests as regular user // c := client.New(srv) @@ -164,8 +165,9 @@ func TestValidateGtfsResolver(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { - srv, _ := NewServer(te.Config, te.Finder, nil, nil, nil) + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { + srv, _ := NewServer() + srv = model.AddConfig(cfg)(srv) srv = ancheck.UserDefaultMiddleware("test")(srv) // Run all requests as user c := client.New(srv) queryTestcase(t, c, tc) @@ -173,8 +175,9 @@ func TestValidateGtfsResolver(t *testing.T) { }) } t.Run("requires user access", func(t *testing.T) { - testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { - srv, _ := NewServer(te.Config, te.Finder, nil, nil, nil) // all requests run as anonymous context by default + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { + srv, _ := NewServer() // all requests run as anonymous context by default + srv = model.AddConfig(cfg)(srv) c := client.New(srv) resp := make(map[string]interface{}) err := c.Post(`mutation($url:String!) {validate_gtfs(url:$url){success}}`, &resp, client.Var("url", ts200.URL)) diff --git a/server/gql/query_resolver.go b/server/gql/query_resolver.go index 41175fc0..a84186b6 100644 --- a/server/gql/query_resolver.go +++ b/server/gql/query_resolver.go @@ -16,7 +16,7 @@ const MAX_RADIUS = 100_000 type queryResolver struct{ *Resolver } func (r *queryResolver) Me(ctx context.Context) (*model.Me, error) { - me, err := r.authzChecker.Me(ctx, &authz.MeRequest{}) + me, err := model.ForContext(ctx).Checker.Me(ctx, &authz.MeRequest{}) if err != nil { return nil, err } @@ -43,7 +43,7 @@ func (r *queryResolver) Agencies(ctx context.Context, limit *int, after *int, id return nil, errors.New("bbox too large") } } - return r.finder.FindAgencies(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindAgencies(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Routes(ctx context.Context, limit *int, after *int, ids []int, where *model.RouteFilter) ([]*model.Route, error) { @@ -56,7 +56,7 @@ func (r *queryResolver) Routes(ctx context.Context, limit *int, after *int, ids return nil, errors.New("bbox too large") } } - return r.finder.FindRoutes(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindRoutes(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Stops(ctx context.Context, limit *int, after *int, ids []int, where *model.StopFilter) ([]*model.Stop, error) { @@ -69,12 +69,12 @@ func (r *queryResolver) Stops(ctx context.Context, limit *int, after *int, ids [ return nil, errors.New("bbox too large") } } - return r.finder.FindStops(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindStops(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Trips(ctx context.Context, limit *int, after *int, ids []int, where *model.TripFilter) ([]*model.Trip, error) { addMetric(ctx, "trips") - return r.finder.FindTrips(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindTrips(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) FeedVersions(ctx context.Context, limit *int, after *int, ids []int, where *model.FeedVersionFilter) ([]*model.FeedVersion, error) { @@ -87,7 +87,7 @@ func (r *queryResolver) FeedVersions(ctx context.Context, limit *int, after *int return nil, errors.New("bbox too large") } } - return r.finder.FindFeedVersions(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindFeedVersions(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Feeds(ctx context.Context, limit *int, after *int, ids []int, where *model.FeedFilter) ([]*model.Feed, error) { @@ -100,7 +100,7 @@ func (r *queryResolver) Feeds(ctx context.Context, limit *int, after *int, ids [ return nil, errors.New("bbox too large") } } - return r.finder.FindFeeds(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindFeeds(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Operators(ctx context.Context, limit *int, after *int, ids []int, where *model.OperatorFilter) ([]*model.Operator, error) { @@ -113,11 +113,11 @@ func (r *queryResolver) Operators(ctx context.Context, limit *int, after *int, i return nil, errors.New("bbox too large") } } - return r.finder.FindOperators(ctx, checkLimit(limit), checkCursor(after), ids, where) + return model.ForContext(ctx).Finder.FindOperators(ctx, checkLimit(limit), checkCursor(after), ids, where) } func (r *queryResolver) Places(ctx context.Context, limit *int, after *int, level *model.PlaceAggregationLevel, where *model.PlaceFilter) ([]*model.Place, error) { - return r.finder.FindPlaces(ctx, checkLimit(limit), checkCursor(after), nil, level, where) + return model.ForContext(ctx).Finder.FindPlaces(ctx, checkLimit(limit), checkCursor(after), nil, level, where) } func addMetric(ctx context.Context, resolverName string) { diff --git a/server/gql/resolver.go b/server/gql/resolver.go index 3f9b0951..4015b40f 100644 --- a/server/gql/resolver.go +++ b/server/gql/resolver.go @@ -4,7 +4,6 @@ import ( "context" "strconv" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/generated/gqlout" "github.com/interline-io/transitland-server/internal/xy" "github.com/interline-io/transitland-server/model" @@ -54,12 +53,7 @@ func atoi(v string) int { // Resolver . type Resolver struct { - cfg config.Config - rtfinder model.RTFinder - finder model.Finder - gbfsFinder model.GbfsFinder - authzChecker model.Checker - fvslCache *fvslCache + fvslCache *fvslCache } // Query . diff --git a/server/gql/resolver_test.go b/server/gql/resolver_test.go index a5e7be86..cf5ee906 100644 --- a/server/gql/resolver_test.go +++ b/server/gql/resolver_test.go @@ -5,13 +5,11 @@ import ( "log" "os" "testing" - "time" "github.com/99designs/gqlgen/client" "github.com/interline-io/transitland-mw/auth/ancheck" "github.com/interline-io/transitland-mw/auth/authn" - "github.com/interline-io/transitland-server/internal/clock" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" @@ -46,19 +44,21 @@ func TestMain(m *testing.M) { // Test helpers -func newTestClient(t testing.TB) (*client.Client, model.Finders) { - when, err := time.Parse("2006-01-02T15:04:05", "2022-09-01T00:00:00") - if err != nil { - t.Fatal(err) - } - return newTestClientWithClock(t, &clock.Mock{T: when}, testfinder.DefaultRTJson()) +func newTestClient(t testing.TB) (*client.Client, model.Config) { + return newTestClientWithOpts(t, testconfig.Options{ + When: "2022-09-01T00:00:00", + RTJsons: testconfig.DefaultRTJson(), + }) } -func newTestClientWithClock(t testing.TB, cl clock.Clock, rtfiles []testfinder.RTJsonFile) (*client.Client, model.Finders) { - te := testfinder.Finders(t, cl, rtfiles) - srv, _ := NewServer(te.Config, te.Finder, te.RTFinder, te.GbfsFinder, te.Checker) - srvMiddleware := ancheck.NewUserDefaultMiddleware(func() authn.User { return authn.NewCtxUser("testuser", "", "").WithRoles("testrole") }) - return client.New(srvMiddleware(srv)), te +func newTestClientWithOpts(t testing.TB, opts testconfig.Options) (*client.Client, model.Config) { + cfg := testconfig.Config(t, opts) + srv, _ := NewServer() + graphqlServer := model.AddConfig(cfg)(srv) + srvMiddleware := ancheck.NewUserDefaultMiddleware(func() authn.User { + return authn.NewCtxUser("testuser", "", "").WithRoles("testrole") + }) + return client.New(srvMiddleware(graphqlServer)), cfg } func toJson(m map[string]interface{}) string { diff --git a/server/gql/route_resolver.go b/server/gql/route_resolver.go index f47bca87..c1077052 100644 --- a/server/gql/route_resolver.go +++ b/server/gql/route_resolver.go @@ -63,7 +63,7 @@ func (r *routeResolver) Headways(ctx context.Context, obj *model.Route, limit *i func (r *routeResolver) RouteStopBuffer(ctx context.Context, obj *model.Route, radius *float64) (*model.RouteStopBuffer, error) { // TODO: remove n+1 (which is tricky, what if multiple radius specified in different parts of query) p := model.RouteStopBufferParam{Radius: radius, EntityID: obj.ID} - ents, err := r.finder.RouteStopBuffer(ctx, &p) + ents, err := model.ForContext(ctx).Finder.RouteStopBuffer(ctx, &p) if err != nil { return nil, err } @@ -74,7 +74,7 @@ func (r *routeResolver) RouteStopBuffer(ctx context.Context, obj *model.Route, r } func (r *routeResolver) Alerts(ctx context.Context, obj *model.Route, active *bool, limit *int) ([]*model.Alert, error) { - return r.rtfinder.FindAlertsForRoute(obj, checkLimit(limit), active), nil + return model.ForContext(ctx).RTFinder.FindAlertsForRoute(obj, checkLimit(limit), active), nil } func (r *routeResolver) Patterns(ctx context.Context, obj *model.Route) ([]*model.RouteStopPattern, error) { @@ -124,7 +124,7 @@ type routePatternResolver struct{ *Resolver } func (r *routePatternResolver) Trips(ctx context.Context, obj *model.RouteStopPattern, limit *int) ([]*model.Trip, error) { // TODO: N+1 query - trips, err := r.finder.FindTrips(ctx, checkLimit(limit), nil, nil, &model.TripFilter{StopPatternID: &obj.StopPatternID, RouteIds: []int{obj.RouteID}}) + trips, err := model.ForContext(ctx).Finder.FindTrips(ctx, checkLimit(limit), nil, nil, &model.TripFilter{StopPatternID: &obj.StopPatternID, RouteIds: []int{obj.RouteID}}) return trips, err } diff --git a/server/gql/route_resolver_test.go b/server/gql/route_resolver_test.go index 32c0595d..ab378771 100644 --- a/server/gql/route_resolver_test.go +++ b/server/gql/route_resolver_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) @@ -306,8 +307,8 @@ func TestRouteResolver_PreviousOnestopID(t *testing.T) { } func TestRouteResolver_Cursor(t *testing.T) { - c, te := newTestClient(t) - allEnts, err := te.Finder.FindRoutes(context.Background(), nil, nil, nil, nil) + c, cfg := newTestClient(t) + allEnts, err := cfg.Finder.FindRoutes(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/server/gql/rt_test.go b/server/gql/rt_test.go index 2ca6181c..165ae730 100644 --- a/server/gql/rt_test.go +++ b/server/gql/rt_test.go @@ -2,12 +2,10 @@ package gql import ( "testing" - "time" "github.com/99designs/gqlgen/client" "github.com/interline-io/transitland-lib/tl/tt" - "github.com/interline-io/transitland-server/internal/clock" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) @@ -113,17 +111,16 @@ type rtTestCase struct { name string query string vars map[string]interface{} - rtfiles []testfinder.RTJsonFile + rtfiles []testconfig.RTJsonFile cb func(t *testing.T, jj string) } func testRt(t *testing.T, tc rtTestCase) { // Create a new RT Finder for each test... - when, err := time.Parse("2006-01-02T15:04:05", "2022-09-01T00:00:00") - if err != nil { - t.Fatal(err) - } - c, _ := newTestClientWithClock(t, &clock.Mock{T: when}, tc.rtfiles) + c, _ := newTestClientWithOpts(t, testconfig.Options{ + When: "2022-09-01T00:00:00", + RTJsons: tc.rtfiles, + }) var resp map[string]interface{} opts := []client.Option{} for k, v := range tc.vars { @@ -144,7 +141,7 @@ func TestStopRTBasic(t *testing.T) { "stop times basic", baseStopQuery, newBaseStopVars(), - testfinder.DefaultRTJson(), + testconfig.DefaultRTJson(), func(t *testing.T, jj string) { // A little more explicit version of the string check test a := gjson.Get(jj, "stops.0.stop_times").Array() @@ -184,7 +181,7 @@ func TestStopRTBasic_ArrivalFallback(t *testing.T) { "arrival will use departure if arrival is not present", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-arrival-fallback.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-arrival-fallback.json"}}, func(t *testing.T, jj string) { a := gjson.Get(jj, "stops.0.stop_times").Array() checkTrip := "1031527WKDY" @@ -209,7 +206,7 @@ func TestStopRTBasic_DepartureFallback(t *testing.T) { "departure will use arrival if departure is not present", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-departure-fallback.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-departure-fallback.json"}}, func(t *testing.T, jj string) { a := gjson.Get(jj, "stops.0.stop_times").Array() checkTrip := "1031527WKDY" @@ -234,7 +231,7 @@ func TestStopRTBasic_StopIDFallback(t *testing.T) { "use stop_id as fallback if no matching stop sequence", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-stop-id-fallback.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-stop-id-fallback.json"}}, func(t *testing.T, jj string) { a := gjson.Get(jj, "stops.0.stop_times").Array() checkTrip := "1031527WKDY" @@ -260,7 +257,7 @@ func TestStopRTBasic_StopIDFallback_NoDoubleVisit(t *testing.T) { "do not use stop_id as fallback if stop is visited twice", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-stop-double-visit.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-stop-double-visit.json"}}, func(t *testing.T, jj string) { a := gjson.Get(jj, "stops.0.stop_times").Array() checkTrip := "1031527WKDY" @@ -285,7 +282,7 @@ func TestStopRTBasic_NoRT(t *testing.T) { "no rt matches for trip 2211533WKDY", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-departure-fallback.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-departure-fallback.json"}}, func(t *testing.T, jj string) { a := gjson.Get(jj, "stops.0.stop_times").Array() checkTrip := "2211533WKDY" @@ -313,7 +310,7 @@ func TestStopRTAddedTrip(t *testing.T) { "stop times added trip", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-added.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-added.json"}}, func(t *testing.T, jj string) { checkTrip := "-123" found := false @@ -346,7 +343,7 @@ func TestStopRTCanceledTrip(t *testing.T) { "stop times canceled trip", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-added.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_trip_updates", Fname: "BA-added.json"}}, func(t *testing.T, jj string) { checkTrip := "2211533WKDY" found := false @@ -381,7 +378,7 @@ func TestTripAlerts(t *testing.T) { "trip alerts", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -407,7 +404,7 @@ func TestTripAlerts(t *testing.T) { "trip alerts active", baseStopQuery, activeVars, - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -447,7 +444,7 @@ func TestRouteAlerts(t *testing.T) { "stop alerts active", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -474,7 +471,7 @@ func TestRouteAlerts(t *testing.T) { "stop alerts active", baseStopQuery, activeVars, - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -516,7 +513,7 @@ func TestStopAlerts(t *testing.T) { "stop alerts", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -530,7 +527,7 @@ func TestStopAlerts(t *testing.T) { "stop alerts active", baseStopQuery, activeVars, - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -559,7 +556,7 @@ func TestAgencyAlerts(t *testing.T) { "stop alerts", baseStopQuery, newBaseStopVars(), - []testfinder.RTJsonFile{ + []testconfig.RTJsonFile{ {Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}, }, func(t *testing.T, jj string) { @@ -586,7 +583,7 @@ func TestAgencyAlerts(t *testing.T) { "stop alerts active", baseStopQuery, activeVars, - []testfinder.RTJsonFile{{Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}}, + []testconfig.RTJsonFile{{Feed: "BA", Ftype: "realtime_alerts", Fname: "BA-alerts.json"}}, func(t *testing.T, jj string) { checkTrip := "1031527WKDY" sts := gjson.Get(jj, "stops.0.stop_times").Array() diff --git a/server/gql/server.go b/server/gql/server.go index 1a05dec3..0b5cc71c 100644 --- a/server/gql/server.go +++ b/server/gql/server.go @@ -8,19 +8,13 @@ import ( "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/handler" "github.com/interline-io/transitland-mw/auth/authn" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/generated/gqlout" "github.com/interline-io/transitland-server/model" ) -func NewServer(cfg config.Config, dbfinder model.Finder, rtfinder model.RTFinder, gbfsFinder model.GbfsFinder, checker model.Checker) (http.Handler, error) { +func NewServer() (http.Handler, error) { c := gqlout.Config{Resolvers: &Resolver{ - cfg: cfg, - finder: dbfinder, - rtfinder: rtfinder, - gbfsFinder: gbfsFinder, - fvslCache: newFvslCache(dbfinder), - authzChecker: checker, + fvslCache: newFvslCache(), }} c.Directives.HasRole = func(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (interface{}, error) { user := authn.ForContext(ctx) @@ -31,6 +25,6 @@ func NewServer(cfg config.Config, dbfinder model.Finder, rtfinder model.RTFinder } // Setup server srv := handler.NewDefaultServer(gqlout.NewExecutableSchema(c)) - graphqlServer := loaderMiddleware(cfg, dbfinder, srv) + graphqlServer := loaderMiddleware(srv) return graphqlServer, nil } diff --git a/server/gql/stop_resolver.go b/server/gql/stop_resolver.go index 0a8b2c4b..e4f2d446 100644 --- a/server/gql/stop_resolver.go +++ b/server/gql/stop_resolver.go @@ -99,13 +99,13 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit if where != nil { // Convert where.Next into departure date and time window if where.Next != nil { - loc, ok := r.rtfinder.StopTimezone(obj.ID, obj.StopTimezone) + loc, ok := model.ForContext(ctx).RTFinder.StopTimezone(obj.ID, obj.StopTimezone) if !ok { return nil, errors.New("timezone not available for stop") } serviceDate := time.Now().In(loc) - if r.cfg.Clock != nil { - serviceDate = r.cfg.Clock.Now().In(loc) + if model.ForContext(ctx).Clock != nil { + serviceDate = model.ForContext(ctx).Clock.Now().In(loc) } st, et := 0, 0 st = serviceDate.Hour()*3600 + serviceDate.Minute()*60 + serviceDate.Second() @@ -118,7 +118,7 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit } // Check if service date is outside the window for this feed version if where.ServiceDate != nil && (where.UseServiceWindow != nil && *where.UseServiceWindow) { - sl, ok := r.fvslCache.Get(obj.FeedVersionID) + sl, ok := r.fvslCache.Get(ctx, obj.FeedVersionID) if !ok { return nil, errors.New("service level information not available for feed version") } @@ -163,13 +163,13 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit for _, st := range sts { ft := model.Trip{} ft.FeedVersionID = obj.FeedVersionID - ft.TripID, _ = r.rtfinder.GetGtfsTripID(atoi(st.TripID)) // TODO! - if ste, ok := r.rtfinder.FindStopTimeUpdate(&ft, st); ok { + ft.TripID, _ = model.ForContext(ctx).RTFinder.GetGtfsTripID(atoi(st.TripID)) // TODO! + if ste, ok := model.ForContext(ctx).RTFinder.FindStopTimeUpdate(&ft, st); ok { st.RTStopTimeUpdate = ste } } // Handle added trips; these must specify stop_id in StopTimeUpdates - for _, rtTrip := range r.rtfinder.GetAddedTripsForStop(obj) { + for _, rtTrip := range model.ForContext(ctx).RTFinder.GetAddedTripsForStop(obj) { for _, stu := range rtTrip.StopTimeUpdate { if stu.GetStopId() != obj.StopID { continue @@ -196,7 +196,7 @@ func (r *stopResolver) getStopTimes(ctx context.Context, obj *model.Stop, limit } func (r *stopResolver) Alerts(ctx context.Context, obj *model.Stop, active *bool, limit *int) ([]*model.Alert, error) { - rtAlerts := r.rtfinder.FindAlertsForStop(obj, checkLimit(limit), active) + rtAlerts := model.ForContext(ctx).RTFinder.FindAlertsForStop(obj, checkLimit(limit), active) return rtAlerts, nil } @@ -223,7 +223,7 @@ func (r *stopResolver) Directions(ctx context.Context, obj *model.Stop, from *mo func (r *stopResolver) NearbyStops(ctx context.Context, obj *model.Stop, limit *int, radius *float64) ([]*model.Stop, error) { c := obj.Coordinates() - nearbyStops, err := r.finder.FindStops(ctx, checkLimit(limit), nil, nil, &model.StopFilter{Near: &model.PointRadius{Lon: c[0], Lat: c[1], Radius: checkFloat(radius, 0, MAX_RADIUS)}}) + nearbyStops, err := model.ForContext(ctx).Finder.FindStops(ctx, checkLimit(limit), nil, nil, &model.StopFilter{Near: &model.PointRadius{Lon: c[0], Lat: c[1], Radius: checkFloat(radius, 0, MAX_RADIUS)}}) return nearbyStops, err } diff --git a/server/gql/stop_resolver_test.go b/server/gql/stop_resolver_test.go index 5be3f844..28ce9b56 100644 --- a/server/gql/stop_resolver_test.go +++ b/server/gql/stop_resolver_test.go @@ -10,31 +10,31 @@ import ( ) func TestStopResolver(t *testing.T) { - c, te := newTestClient(t) - queryTestcases(t, c, stopResolverTestcases(t, te)) + c, cfg := newTestClient(t) + queryTestcases(t, c, stopResolverTestcases(t, cfg)) } func TestStopResolver_Cursor(t *testing.T) { - c, te := newTestClient(t) - queryTestcases(t, c, stopResolverCursorTestcases(t, te)) + c, cfg := newTestClient(t) + queryTestcases(t, c, stopResolverCursorTestcases(t, cfg)) } func TestStopResolver_PreviousOnestopID(t *testing.T) { - c, te := newTestClient(t) - queryTestcases(t, c, stopResolverPreviousOnestopIDTestcases(t, te)) + c, cfg := newTestClient(t) + queryTestcases(t, c, stopResolverPreviousOnestopIDTestcases(t, cfg)) } func TestStopResolver_License(t *testing.T) { - c, te := newTestClient(t) - queryTestcases(t, c, stopResolverLicenseTestcases(t, te)) + c, cfg := newTestClient(t) + queryTestcases(t, c, stopResolverLicenseTestcases(t, cfg)) } func TestStopResolver_AdminCache(t *testing.T) { type canLoadAdmins interface { LoadAdmins() error } - c, te := newTestClient(t) - if v, ok := te.Finder.(canLoadAdmins); !ok { + c, cfg := newTestClient(t) + if v, ok := cfg.Finder.(canLoadAdmins); !ok { t.Fatal("finder cant load admins") } else { if err := v.LoadAdmins(); err != nil { @@ -83,11 +83,11 @@ func TestStopResolver_AdminCache(t *testing.T) { } func BenchmarkStopResolver(b *testing.B) { - c, te := newTestClient(b) - benchmarkTestcases(b, c, stopResolverTestcases(b, te)) + c, cfg := newTestClient(b) + benchmarkTestcases(b, c, stopResolverTestcases(b, cfg)) } -func stopResolverTestcases(t testing.TB, te model.Finders) []testcase { +func stopResolverTestcases(t testing.TB, cfg model.Config) []testcase { bartStops := []string{"12TH", "16TH", "19TH", "19TH_N", "24TH", "ANTC", "ASHB", "BALB", "BAYF", "CAST", "CIVC", "COLS", "COLM", "CONC", "DALY", "DBRK", "DUBL", "DELN", "PLZA", "EMBR", "FRMT", "FTVL", "GLEN", "HAYW", "LAFY", "LAKE", "MCAR", "MCAR_S", "MLBR", "MONT", "NBRK", "NCON", "OAKL", "ORIN", "PITT", "PCTR", "PHIL", "POWL", "RICH", "ROCK", "SBRN", "SFIA", "SANL", "SHAY", "SSAN", "UCTY", "WCRK", "WARM", "WDUB", "WOAK"} caltrainRailStops := []string{"70011", "70012", "70021", "70022", "70031", "70032", "70041", "70042", "70051", "70052", "70061", "70062", "70071", "70072", "70081", "70082", "70091", "70092", "70101", "70102", "70111", "70112", "70121", "70122", "70131", "70132", "70141", "70142", "70151", "70152", "70161", "70162", "70171", "70172", "70191", "70192", "70201", "70202", "70211", "70212", "70221", "70222", "70231", "70232", "70241", "70242", "70251", "70252", "70261", "70262", "70271", "70272", "70281", "70282", "70291", "70292", "70301", "70302", "70311", "70312", "70321", "70322"} caltrainBusStops := []string{"777402", "777403"} @@ -99,7 +99,7 @@ func stopResolverTestcases(t testing.TB, te model.Finders) []testcase { allStops = append(allStops, caltrainStops...) vars := hw{"stop_id": "MCAR"} stopObsFvid := 0 - if err := te.Finder.DBX().QueryRowx("select feed_version_id from ext_performance_stop_observations limit 1").Scan(&stopObsFvid); err != nil { + if err := cfg.Finder.DBX().QueryRowx("select feed_version_id from ext_performance_stop_observations limit 1").Scan(&stopObsFvid); err != nil { t.Errorf("could not get fvid for stop observation test: %s", err.Error()) } testcases := []testcase{ @@ -510,9 +510,9 @@ func stopResolverTestcases(t testing.TB, te model.Finders) []testcase { return testcases } -func stopResolverCursorTestcases(t *testing.T, te model.Finders) []testcase { +func stopResolverCursorTestcases(t *testing.T, cfg model.Config) []testcase { // First 1000 stops... - dbf := te.Finder + dbf := cfg.Finder allEnts, err := dbf.FindStops(context.Background(), nil, nil, nil, nil) if err != nil { t.Fatal(err) @@ -562,7 +562,7 @@ func stopResolverCursorTestcases(t *testing.T, te model.Finders) []testcase { return testcases } -func stopResolverPreviousOnestopIDTestcases(t testing.TB, te model.Finders) []testcase { +func stopResolverPreviousOnestopIDTestcases(t testing.TB, cfg model.Config) []testcase { testcases := []testcase{ { name: "default", @@ -596,7 +596,7 @@ func stopResolverPreviousOnestopIDTestcases(t testing.TB, te model.Finders) []te return testcases } -func stopResolverLicenseTestcases(t testing.TB, te model.Finders) []testcase { +func stopResolverLicenseTestcases(t testing.TB, cfg model.Config) []testcase { q := ` query ($lic: LicenseFilter) { stops(limit: 10000, where: {license: $lic}) { diff --git a/server/gql/stop_time_resolver.go b/server/gql/stop_time_resolver.go index e76688f7..0f6f70c9 100644 --- a/server/gql/stop_time_resolver.go +++ b/server/gql/stop_time_resolver.go @@ -26,7 +26,7 @@ func (r *stopTimeResolver) Trip(ctx context.Context, obj *model.StopTime) (*mode t := model.Trip{} t.FeedVersionID = obj.FeedVersionID t.TripID = obj.RTTripID - a, err := r.rtfinder.MakeTrip(&t) + a, err := model.ForContext(ctx).RTFinder.MakeTrip(&t) return a, err } return For(ctx).TripsByID.Load(ctx, atoi(obj.TripID))() @@ -34,7 +34,7 @@ func (r *stopTimeResolver) Trip(ctx context.Context, obj *model.StopTime) (*mode func (r *stopTimeResolver) Arrival(ctx context.Context, obj *model.StopTime) (*model.StopTimeEvent, error) { // lookup timezone - loc, ok := r.rtfinder.StopTimezone(atoi(obj.StopID), "") + loc, ok := model.ForContext(ctx).RTFinder.StopTimezone(atoi(obj.StopID), "") if !ok { return nil, errors.New("timezone not available for stop") } @@ -54,7 +54,7 @@ func (r *stopTimeResolver) Arrival(ctx context.Context, obj *model.StopTime) (*m func (r *stopTimeResolver) Departure(ctx context.Context, obj *model.StopTime) (*model.StopTimeEvent, error) { // lookup timezone - loc, ok := r.rtfinder.StopTimezone(atoi(obj.StopID), "") + loc, ok := model.ForContext(ctx).RTFinder.StopTimezone(atoi(obj.StopID), "") if !ok { return nil, errors.New("timezone not available for stop") } diff --git a/server/gql/stop_time_resolver_test.go b/server/gql/stop_time_resolver_test.go index 813f2d03..a5f621ec 100644 --- a/server/gql/stop_time_resolver_test.go +++ b/server/gql/stop_time_resolver_test.go @@ -2,10 +2,8 @@ package gql import ( "testing" - "time" - "github.com/interline-io/transitland-server/internal/clock" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" ) func TestStopResolver_StopTimes(t *testing.T) { @@ -363,11 +361,10 @@ func TestStopResolver_StopTimes_Next(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // 2018-05-28 22:00:00 +0000 UTC // 2018-05-28 15:00:00 -0700 PDT - when, err := time.Parse("2006-01-02T15:04:05", tc.when) - if err != nil { - t.Fatal(err) - } - c, _ := newTestClientWithClock(t, &clock.Mock{T: when}, testfinder.DefaultRTJson()) + c, _ := newTestClientWithOpts(t, testconfig.Options{ + When: tc.when, + RTJsons: testconfig.DefaultRTJson(), + }) queryTestcase(t, c, tc.testcase) }) } diff --git a/server/gql/trip_resolver.go b/server/gql/trip_resolver.go index 30c285f6..f6b27990 100644 --- a/server/gql/trip_resolver.go +++ b/server/gql/trip_resolver.go @@ -45,7 +45,7 @@ func (r *tripResolver) Frequencies(ctx context.Context, obj *model.Trip, limit * func (r *tripResolver) ScheduleRelationship(ctx context.Context, obj *model.Trip) (*model.ScheduleRelationship, error) { msr := model.ScheduleRelationshipScheduled - if rtt := r.rtfinder.FindTrip(obj); rtt != nil { + if rtt := model.ForContext(ctx).RTFinder.FindTrip(obj); rtt != nil { sr := rtt.GetTrip().GetScheduleRelationship().String() switch sr { case "SCHEDULED": @@ -64,7 +64,7 @@ func (r *tripResolver) ScheduleRelationship(ctx context.Context, obj *model.Trip } func (r *tripResolver) Timestamp(ctx context.Context, obj *model.Trip) (*time.Time, error) { - if rtt := r.rtfinder.FindTrip(obj); rtt != nil { + if rtt := model.ForContext(ctx).RTFinder.FindTrip(obj); rtt != nil { t := time.Unix(int64(rtt.GetTimestamp()), 0).In(time.UTC) return &t, nil } @@ -72,6 +72,6 @@ func (r *tripResolver) Timestamp(ctx context.Context, obj *model.Trip) (*time.Ti } func (r *tripResolver) Alerts(ctx context.Context, obj *model.Trip, active *bool, limit *int) ([]*model.Alert, error) { - rtAlerts := r.rtfinder.FindAlertsForTrip(obj, checkLimit(limit), active) + rtAlerts := model.ForContext(ctx).RTFinder.FindAlertsForTrip(obj, checkLimit(limit), active) return rtAlerts, nil } diff --git a/server/rest/agency_request_test.go b/server/rest/agency_request_test.go index f1569a73..64e203ee 100644 --- a/server/rest/agency_request_test.go +++ b/server/rest/agency_request_test.go @@ -5,14 +5,15 @@ import ( "strings" "testing" + "github.com/interline-io/transitland-server/internal/testconfig" + "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) func TestAgencyRequest(t *testing.T) { - srv, te := testRestConfig(t) fv := "e535eb2b3b9ac3ef15d82c56575e914575e732e0" - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: AgencyRequest{}, @@ -159,13 +160,13 @@ func TestAgencyRequest(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestAgencyRequest_Format(t *testing.T) { - tcs := []testRest{ + tcs := []testCase{ { name: "agency geojson", format: "geojson", @@ -192,17 +193,16 @@ func TestAgencyRequest_Format(t *testing.T) { }, }, } - srv, te := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestAgencyRequest_Pagination(t *testing.T) { - srv, te := testRestConfig(t) - allEnts, err := te.Finder.FindAgencies(context.Background(), nil, nil, nil, nil) + graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{}) + allEnts, err := cfg.Finder.FindAgencies(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -210,7 +210,7 @@ func TestAgencyRequest_Pagination(t *testing.T) { for _, ent := range allEnts { allIds = append(allIds, ent.AgencyID) } - testcases := []testRest{ + testcases := []testCase{ { name: "limit:1", h: AgencyRequest{WithCursor: WithCursor{Limit: 1}}, @@ -240,13 +240,13 @@ func TestAgencyRequest_Pagination(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCaseWithHandlers(t, tc, graphqlHandler, restHandler) }) } } func TestAgencyRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: AgencyRequest{WithCursor: WithCursor{Limit: 10_000}, LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, selector: "agencies.#.agency_id", @@ -293,10 +293,9 @@ func TestAgencyRequest_License(t *testing.T) { expectSelect: []string{"caltrain-ca-us", ""}, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/feed_request_test.go b/server/rest/feed_request_test.go index 0ea1d5b3..687da1b1 100644 --- a/server/rest/feed_request_test.go +++ b/server/rest/feed_request_test.go @@ -11,7 +11,7 @@ import ( func TestFeedRequest(t *testing.T) { // fv := "e535eb2b3b9ac3ef15d82c56575e914575e732e0" - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: &FeedRequest{}, @@ -156,16 +156,15 @@ func TestFeedRequest(t *testing.T) { expectLength: 0, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestFeedRequest_Format(t *testing.T) { - tcs := []testRest{ + tcs := []testCase{ { name: "feed geojson", format: "geojson", @@ -192,16 +191,15 @@ func TestFeedRequest_Format(t *testing.T) { }, }, } - srv, te := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestFeedRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: FeedRequest{LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, selector: "feeds.#.onestop_id", @@ -248,10 +246,9 @@ func TestFeedRequest_License(t *testing.T) { expectSelect: []string{"CT", "test-gbfs", "HA", "BA~rt", "CT~rt", "test", "EX"}, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } diff --git a/server/rest/feed_version_download.go b/server/rest/feed_version_download.go index 4bf60751..3c12e7c6 100644 --- a/server/rest/feed_version_download.go +++ b/server/rest/feed_version_download.go @@ -13,6 +13,7 @@ import ( "github.com/interline-io/transitland-lib/tl/request" "github.com/interline-io/transitland-mw/meters" "github.com/interline-io/transitland-server/internal/util" + "github.com/interline-io/transitland-server/model" "github.com/tidwall/gjson" ) @@ -32,7 +33,7 @@ query($feed_onestop_id: String!, $ids: [Int!]) { // Query redirects user to download the given fv from S3 public URL // assuming that redistribution is allowed for the feed. -func feedVersionDownloadLatestHandler(cfg restConfig, w http.ResponseWriter, r *http.Request) { +func feedVersionDownloadLatestHandler(graphqlHandler http.Handler, w http.ResponseWriter, r *http.Request) { key := chi.URLParam(r, "feed_key") gvars := hw{} if key == "" { @@ -45,7 +46,7 @@ func feedVersionDownloadLatestHandler(cfg restConfig, w http.ResponseWriter, r * } // Check if we're allowed to redistribute feed and look up latest feed version - feedResponse, err := makeGraphQLRequest(r.Context(), cfg.srv, latestFeedVersionQuery, gvars) + feedResponse, err := makeGraphQLRequest(r.Context(), graphqlHandler, latestFeedVersionQuery, gvars) if err != nil { http.Error(w, util.MakeJsonError("server error"), http.StatusInternalServerError) return @@ -83,6 +84,8 @@ func feedVersionDownloadLatestHandler(cfg restConfig, w http.ResponseWriter, r * } apiMeter.Meter("feed-version-downloads", 1.0, dims) } + + cfg := model.ForContext(r.Context()) serveFromStorage(w, r, cfg.Storage, fvsha1) } @@ -102,7 +105,7 @@ query($feed_version_sha1:String!, $ids: [Int!]) { // Query redirects user to download the given fv from S3 public URL // assuming that redistribution is allowed for the feed. -func feedVersionDownloadHandler(cfg restConfig, w http.ResponseWriter, r *http.Request) { +func feedVersionDownloadHandler(graphqlHandler http.Handler, w http.ResponseWriter, r *http.Request) { gvars := hw{} key := chi.URLParam(r, "feed_version_key") if key == "" { @@ -114,7 +117,7 @@ func feedVersionDownloadHandler(cfg restConfig, w http.ResponseWriter, r *http.R gvars["feed_version_sha1"] = key } // Check if we're allowed to redistribute feed - checkfv, err := makeGraphQLRequest(r.Context(), cfg.srv, feedVersionFileQuery, gvars) + checkfv, err := makeGraphQLRequest(r.Context(), graphqlHandler, feedVersionFileQuery, gvars) if err != nil { http.Error(w, util.MakeJsonError("server error"), http.StatusInternalServerError) return @@ -159,6 +162,7 @@ func feedVersionDownloadHandler(cfg restConfig, w http.ResponseWriter, r *http.R apiMeter.Meter("feed-version-downloads", 1.0, dims) } + cfg := model.ForContext(r.Context()) serveFromStorage(w, r, cfg.Storage, fvsha1) } diff --git a/server/rest/feed_version_download_test.go b/server/rest/feed_version_download_test.go index dacf6881..8c67312c 100644 --- a/server/rest/feed_version_download_test.go +++ b/server/rest/feed_version_download_test.go @@ -7,21 +7,14 @@ import ( "github.com/interline-io/transitland-mw/auth/ancheck" "github.com/interline-io/transitland-mw/auth/authn" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" ) func TestFeedVersionDownloadRequest(t *testing.T) { - g, a, ok := testutil.CheckEnv("TL_TEST_STORAGE") - if !ok { - t.Skip(a) - return - } - srv, te := testRestConfig(t) - te.Config.Storage = g - restSrv, err := testRestServer(t, te.Config, srv) - if err != nil { - t.Fatal(err) - } + _, restSrv, _ := testHandlersWithOptions(t, testconfig.Options{ + Storage: testutil.RelPath("tmp"), + }) t.Run("ok", func(t *testing.T) { req, _ := http.NewRequest("GET", "/feed_versions/d2813c293bcfd7a97dde599527ae6c62c98e66c6/download", nil) @@ -112,17 +105,9 @@ func TestFeedVersionDownloadRequest(t *testing.T) { } func TestFeedDownloadLatestRequest(t *testing.T) { - g, a, ok := testutil.CheckEnv("TL_TEST_STORAGE") - if !ok { - t.Skip(a) - return - } - srv, te := testRestConfig(t) - te.Config.Storage = g - restSrv, err := testRestServer(t, te.Config, srv) - if err != nil { - t.Fatal(err) - } + _, restSrv, _ := testHandlersWithOptions(t, testconfig.Options{ + Storage: testutil.RelPath("tmp"), + }) t.Run("ok", func(t *testing.T) { req, _ := http.NewRequest("GET", "/feeds/CT/download_latest_feed_version", nil) diff --git a/server/rest/feed_version_request_test.go b/server/rest/feed_version_request_test.go index 6429073b..56192429 100644 --- a/server/rest/feed_version_request_test.go +++ b/server/rest/feed_version_request_test.go @@ -8,7 +8,7 @@ import ( func TestFeedVersionRequest(t *testing.T) { fv := "d2813c293bcfd7a97dde599527ae6c62c98e66c6" - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: FeedVersionRequest{}, @@ -127,10 +127,9 @@ func TestFeedVersionRequest(t *testing.T) { expectLength: 0, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/operator_request_test.go b/server/rest/operator_request_test.go index 0c839f9d..82de0439 100644 --- a/server/rest/operator_request_test.go +++ b/server/rest/operator_request_test.go @@ -5,7 +5,7 @@ import ( ) func TestOperatorRequest(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: OperatorRequest{}, @@ -107,16 +107,15 @@ func TestOperatorRequest(t *testing.T) { expectLength: 0, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestOperatorRequest_Pagination(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "limit:1", h: OperatorRequest{WithCursor: WithCursor{Limit: 1}}, @@ -130,16 +129,15 @@ func TestOperatorRequest_Pagination(t *testing.T) { expectLength: 4, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestOperatorRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: OperatorRequest{WithCursor: WithCursor{Limit: 10_000}, LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, @@ -195,10 +193,9 @@ func TestOperatorRequest_License(t *testing.T) { expectSelect: []string{"o-9q9-caltrain", "o-dhv-hillsborougharearegionaltransit", "o-9qs-demotransitauthority"}, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/rest.go b/server/rest/rest.go index 8838ae2e..2a3290d5 100644 --- a/server/rest/rest.go +++ b/server/rest/rest.go @@ -17,7 +17,6 @@ import ( "github.com/interline-io/log" "github.com/interline-io/transitland-mw/auth/ancheck" "github.com/interline-io/transitland-mw/meters" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/util" "github.com/interline-io/transitland-server/model" ) @@ -31,38 +30,36 @@ var MAXLIMIT = 1_000 // MAXRADIUS is the maximum point search radius const MAXRADIUS = 100 * 1000.0 -// restConfig holds the base config and the graphql handler -type restConfig struct { - config.Config - srv http.Handler +type Config struct { + DisableImage bool + RestPrefix string } // NewServer . -func NewServer(cfg config.Config, srv http.Handler) (http.Handler, error) { - restcfg := restConfig{Config: cfg, srv: srv} +func NewServer(cfg Config, graphqlHandler http.Handler) (http.Handler, error) { r := chi.NewRouter() - feedHandler := makeHandler(restcfg, "feeds", func() apiHandler { return &FeedRequest{} }) - feedVersionHandler := makeHandler(restcfg, "feedVersions", func() apiHandler { return &FeedVersionRequest{} }) - agencyHandler := makeHandler(restcfg, "agencies", func() apiHandler { return &AgencyRequest{} }) - routeHandler := makeHandler(restcfg, "routes", func() apiHandler { return &RouteRequest{} }) - tripHandler := makeHandler(restcfg, "trips", func() apiHandler { return &TripRequest{} }) - stopHandler := makeHandler(restcfg, "stops", func() apiHandler { return &StopRequest{} }) - stopDepartureHandler := makeHandler(restcfg, "stopDepartures", func() apiHandler { return &StopDepartureRequest{} }) - operatorHandler := makeHandler(restcfg, "operators", func() apiHandler { return &OperatorRequest{} }) + feedHandler := makeHandler(cfg, graphqlHandler, "feeds", func() apiHandler { return &FeedRequest{} }) + feedVersionHandler := makeHandler(cfg, graphqlHandler, "feedVersions", func() apiHandler { return &FeedVersionRequest{} }) + agencyHandler := makeHandler(cfg, graphqlHandler, "agencies", func() apiHandler { return &AgencyRequest{} }) + routeHandler := makeHandler(cfg, graphqlHandler, "routes", func() apiHandler { return &RouteRequest{} }) + tripHandler := makeHandler(cfg, graphqlHandler, "trips", func() apiHandler { return &TripRequest{} }) + stopHandler := makeHandler(cfg, graphqlHandler, "stops", func() apiHandler { return &StopRequest{} }) + stopDepartureHandler := makeHandler(cfg, graphqlHandler, "stopDepartures", func() apiHandler { return &StopDepartureRequest{} }) + operatorHandler := makeHandler(cfg, graphqlHandler, "operators", func() apiHandler { return &OperatorRequest{} }) r.HandleFunc("/feeds.{format}", feedHandler) r.HandleFunc("/feeds", feedHandler) r.HandleFunc("/feeds/{feed_key}.{format}", feedHandler) r.HandleFunc("/feeds/{feed_key}", feedHandler) - r.Handle("/feeds/{feed_key}/download_latest_feed_version", ancheck.RoleRequired("tl_download_fv_current")(makeHandlerFunc(restcfg, "feedVersionDownloadLatest", feedVersionDownloadLatestHandler))) + r.Handle("/feeds/{feed_key}/download_latest_feed_version", ancheck.RoleRequired("tl_download_fv_current")(makeHandlerFunc(graphqlHandler, "feedVersionDownloadLatest", feedVersionDownloadLatestHandler))) r.HandleFunc("/feed_versions.{format}", feedVersionHandler) r.HandleFunc("/feed_versions", feedVersionHandler) r.HandleFunc("/feed_versions/{feed_version_key}.{format}", feedVersionHandler) r.HandleFunc("/feed_versions/{feed_version_key}", feedVersionHandler) r.HandleFunc("/feeds/{feed_key}/feed_versions", feedVersionHandler) - r.Handle("/feed_versions/{feed_version_key}/download", ancheck.RoleRequired("tl_download_fv_historic")(makeHandlerFunc(restcfg, "feedVersionDownload", feedVersionDownloadHandler))) + r.Handle("/feed_versions/{feed_version_key}/download", ancheck.RoleRequired("tl_download_fv_historic")(makeHandlerFunc(graphqlHandler, "feedVersionDownload", feedVersionDownloadHandler))) r.HandleFunc("/agencies.{format}", agencyHandler) r.HandleFunc("/agencies", agencyHandler) @@ -186,7 +183,7 @@ func queryToMap(vars url.Values) map[string]string { } // makeHandler wraps an apiHandler into an HandlerFunc and performs common checks. -func makeHandler(cfg restConfig, handlerName string, f func() apiHandler) http.HandlerFunc { +func makeHandler(cfg Config, graphqlHandler http.Handler, handlerName string, f func() apiHandler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ent := f() opts := queryToMap(r.URL.Query()) @@ -239,7 +236,7 @@ func makeHandler(cfg restConfig, handlerName string, f func() apiHandler) http.H } // Make the request - response, err := makeRequest(r.Context(), cfg, ent, format, r.URL) + response, err := makeRequest(r.Context(), cfg, graphqlHandler, ent, format, r.URL) if err != nil { http.Error(w, util.MakeJsonError(err.Error()), http.StatusInternalServerError) return @@ -298,9 +295,9 @@ func makeGraphQLRequest(ctx context.Context, srv http.Handler, query string, var } // makeRequest prepares an apiHandler and makes the request. -func makeRequest(ctx context.Context, cfg restConfig, ent apiHandler, format string, u *url.URL) ([]byte, error) { +func makeRequest(ctx context.Context, cfg Config, graphqlHandler http.Handler, ent apiHandler, format string, u *url.URL) ([]byte, error) { query, vars := ent.Query() - response, err := makeGraphQLRequest(ctx, cfg.srv, query, vars) + response, err := makeGraphQLRequest(ctx, graphqlHandler, query, vars) if err != nil { vjson, _ := json.Marshal(vars) log.Error().Err(err).Str("query", query).Str("vars", string(vjson)).Msgf("graphql request failed") @@ -375,12 +372,12 @@ func renderGeojsonl(response map[string]any) ([]byte, error) { return ret, nil } -func makeHandlerFunc(cfg restConfig, handlerName string, f func(restConfig, http.ResponseWriter, *http.Request)) http.HandlerFunc { +func makeHandlerFunc(graphqlHandler http.Handler, handlerName string, f func(http.Handler, http.ResponseWriter, *http.Request)) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if apiMeter := meters.ForContext(r.Context()); apiMeter != nil { apiMeter.AddDimension("rest", "handler", handlerName) } - f(cfg, w, r) + f(graphqlHandler, w, r) } } diff --git a/server/rest/rest_test.go b/server/rest/rest_test.go index 5c47786d..90fb5e82 100644 --- a/server/rest/rest_test.go +++ b/server/rest/rest_test.go @@ -7,11 +7,8 @@ import ( "net/http" "os" "testing" - "time" - "github.com/interline-io/transitland-server/config" - "github.com/interline-io/transitland-server/internal/clock" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" "github.com/interline-io/transitland-server/model" "github.com/interline-io/transitland-server/server/gql" @@ -30,29 +27,7 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func testRestConfig(t testing.TB) (http.Handler, model.Finders) { - when, err := time.Parse("2006-01-02T15:04:05", "2018-06-01T00:00:00") - if err != nil { - t.Fatal(err) - } - te := testfinder.Finders(t, &clock.Mock{T: when}, testfinder.DefaultRTJson()) - srv, err := gql.NewServer(te.Config, te.Finder, te.RTFinder, te.GbfsFinder, te.Checker) - if err != nil { - panic(err) - } - return srv, te -} - -func testRestServer(t testing.TB, cfg config.Config, srv http.Handler) (http.Handler, error) { - return NewServer(cfg, srv) -} - -func toJson(m map[string]interface{}) string { - rr, _ := json.Marshal(&m) - return string(rr) -} - -type testRest struct { +type testCase struct { name string h apiHandler format string @@ -62,8 +37,38 @@ type testRest struct { f func(*testing.T, string) } -func testquery(t *testing.T, srv http.Handler, te model.Finders, tc testRest) { - data, err := makeRequest(context.TODO(), restConfig{srv: srv, Config: te.Config}, tc.h, tc.format, nil) +func testHandlersWithOptions(t testing.TB, opts testconfig.Options) (http.Handler, http.Handler, model.Config) { + cfg := testconfig.Config(t, opts) + graphqlHandler, err := gql.NewServer() + if err != nil { + t.Fatal(err) + } + restHandler, err := NewServer(Config{}, graphqlHandler) + if err != nil { + t.Fatal(err) + } + return model.AddConfig(cfg)(graphqlHandler), model.AddConfig(cfg)(restHandler), cfg +} + +func checkTestCase(t *testing.T, tc testCase) { + opts := testconfig.Options{ + When: "2018-06-01T00:00:00", + RTJsons: testconfig.DefaultRTJson(), + } + cfg := testconfig.Config(t, opts) + graphqlHandler, err := gql.NewServer() + if err != nil { + t.Fatal(err) + } + restHandler, err := NewServer(Config{}, graphqlHandler) + if err != nil { + t.Fatal(err) + } + checkTestCaseWithHandlers(t, tc, model.AddConfig(cfg)(graphqlHandler), restHandler) +} + +func checkTestCaseWithHandlers(t *testing.T, tc testCase, graphqlHandler http.Handler, restHandler http.Handler) { + data, err := makeRequest(context.TODO(), Config{}, graphqlHandler, tc.h, tc.format, nil) if err != nil { t.Error(err) return @@ -98,3 +103,8 @@ func testquery(t *testing.T, srv http.Handler, te model.Finders, tc testRest) { t.Errorf("no test performed, check test case") } } + +func toJson(m map[string]interface{}) string { + rr, _ := json.Marshal(&m) + return string(rr) +} diff --git a/server/rest/route_request_test.go b/server/rest/route_request_test.go index 8b952032..ba0a9430 100644 --- a/server/rest/route_request_test.go +++ b/server/rest/route_request_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" @@ -13,7 +14,7 @@ import ( func TestRouteRequest(t *testing.T) { routeIds := []string{"1", "12", "14", "15", "16", "17", "19", "20", "24", "25", "275", "30", "31", "32", "33", "34", "35", "36", "360", "37", "38", "39", "400", "42", "45", "46", "48", "5", "51", "6", "60", "7", "75", "8", "9", "96", "97", "570", "571", "572", "573", "574", "800", "PWT", "SKY", "01", "03", "05", "07", "11", "19", "Bu-130", "Li-130", "Lo-130", "TaSj-130", "Gi-130", "Sp-130"} fv := "e535eb2b3b9ac3ef15d82c56575e914575e732e0" - testcases := []testRest{ + testcases := []testCase{ { name: "none", h: RouteRequest{WithCursor: WithCursor{Limit: 1000}}, @@ -116,16 +117,15 @@ func TestRouteRequest(t *testing.T) { }, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestRouteRequest_Format(t *testing.T) { - tcs := []testRest{ + tcs := []testCase{ { name: "route geojson", format: "geojson", @@ -152,17 +152,16 @@ func TestRouteRequest_Format(t *testing.T) { }, }, } - srv, te := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestRouteRequest_Pagination(t *testing.T) { - srv, te := testRestConfig(t) - allEnts, err := te.Finder.FindRoutes(context.Background(), nil, nil, nil, nil) + graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{}) + allEnts, err := cfg.Finder.FindRoutes(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -170,7 +169,7 @@ func TestRouteRequest_Pagination(t *testing.T) { for _, ent := range allEnts { allIds = append(allIds, ent.RouteID) } - testcases := []testRest{ + testcases := []testCase{ { name: "limit:1", h: RouteRequest{WithCursor: WithCursor{Limit: 1}}, @@ -209,13 +208,13 @@ func TestRouteRequest_Pagination(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCaseWithHandlers(t, tc, graphqlHandler, restHandler) }) } } func TestRouteRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: RouteRequest{WithCursor: WithCursor{Limit: 10_000}, LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, selector: "routes.#.route_id", @@ -262,10 +261,9 @@ func TestRouteRequest_License(t *testing.T) { expectLength: 51, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/stop_departure_request_test.go b/server/rest/stop_departure_request_test.go index ed3ddc77..dad9d19f 100644 --- a/server/rest/stop_departure_request_test.go +++ b/server/rest/stop_departure_request_test.go @@ -12,7 +12,7 @@ func TestStopDepartureRequest(t *testing.T) { return &v } sid := "s-9q9nfsxn67-fruitvale" - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: StopDepartureRequest{StopKey: sid}, @@ -144,10 +144,9 @@ func TestStopDepartureRequest(t *testing.T) { // 0, // }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/stop_request_test.go b/server/rest/stop_request_test.go index ad5a8cd3..17961baf 100644 --- a/server/rest/stop_request_test.go +++ b/server/rest/stop_request_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" @@ -18,7 +19,7 @@ func TestStopRequest(t *testing.T) { caltrainBusStops := []string{"777402", "777403"} _ = caltrainRailStops _ = caltrainBusStops - testcases := []testRest{ + testcases := []testCase{ { name: "basic", h: StopRequest{}, @@ -174,16 +175,15 @@ func TestStopRequest(t *testing.T) { expectLength: 0, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestStopRequest_AdminCache(t *testing.T) { - tc := testRest{ + tc := testCase{ name: "place", h: StopRequest{StopKey: "BA:FTVL"}, selector: "stops.#.place.adm1_name", @@ -193,19 +193,19 @@ func TestStopRequest_AdminCache(t *testing.T) { type canLoadAdmins interface { LoadAdmins() error } - srv, te := testRestConfig(t) - if v, ok := te.Finder.(canLoadAdmins); !ok { + graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{}) + if v, ok := cfg.Finder.(canLoadAdmins); !ok { t.Fatal("finder cant load admins") } else { if err := v.LoadAdmins(); err != nil { t.Fatal(err) } } - testquery(t, srv, te, tc) + checkTestCaseWithHandlers(t, tc, graphqlHandler, restHandler) } func TestStopRequest_Format(t *testing.T) { - tcs := []testRest{ + tcs := []testCase{ { name: "stop geojson", format: "geojson", @@ -232,17 +232,16 @@ func TestStopRequest_Format(t *testing.T) { }, }, } - srv, te := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestStopRequest_Pagination(t *testing.T) { - srv, te := testRestConfig(t) - allEnts, err := te.Finder.FindStops(context.Background(), nil, nil, nil, nil) + graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{}) + allEnts, err := cfg.Finder.FindStops(model.WithConfig(context.Background(), cfg), nil, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -250,7 +249,7 @@ func TestStopRequest_Pagination(t *testing.T) { for _, ent := range allEnts { allIds = append(allIds, ent.StopID) } - testcases := []testRest{ + testcases := []testCase{ { name: "pagination exists", h: StopRequest{}, @@ -276,13 +275,13 @@ func TestStopRequest_Pagination(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCaseWithHandlers(t, tc, graphqlHandler, restHandler) }) } } func TestStopRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: StopRequest{WithCursor: WithCursor{Limit: 10_000}, LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, @@ -363,10 +362,9 @@ func TestStopRequest_License(t *testing.T) { }, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } diff --git a/server/rest/trip_request_test.go b/server/rest/trip_request_test.go index 75f8961e..facd0dcf 100644 --- a/server/rest/trip_request_test.go +++ b/server/rest/trip_request_test.go @@ -6,27 +6,32 @@ import ( "strings" "testing" + "github.com/interline-io/transitland-server/internal/testconfig" + "github.com/interline-io/transitland-server/model" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) func TestTripRequest(t *testing.T) { - srv, te := testRestConfig(t) - d, err := makeGraphQLRequest(context.Background(), srv, `query{routes(where:{feed_onestop_id:"BA",route_id:"11"}) {id onestop_id}}`, nil) + graphqlHandler, restHandler, cfg := testHandlersWithOptions(t, testconfig.Options{ + When: "2018-06-01T00:00:00", + RTJsons: testconfig.DefaultRTJson(), + }) + ctx := model.WithConfig(context.Background(), cfg) + d, err := makeGraphQLRequest(ctx, graphqlHandler, `query{routes(where:{feed_onestop_id:"BA",route_id:"11"}) {id onestop_id}}`, nil) if err != nil { t.Error("failed to get route id for tests") } routeId := int(gjson.Get(toJson(d), "routes.0.id").Int()) routeOnestopId := gjson.Get(toJson(d), "routes.0.onestop_id").String() - d2, err := makeGraphQLRequest(context.Background(), srv, `query{trips(where:{trip_id:"5132248WKDY"}){id}}`, nil) + d2, err := makeGraphQLRequest(ctx, graphqlHandler, `query{trips(where:{trip_id:"5132248WKDY"}){id}}`, nil) if err != nil { t.Error("failed to get route id for tests") } tripId := int(gjson.Get(toJson(d2), "trips.0.id").Int()) - fv := "e535eb2b3b9ac3ef15d82c56575e914575e732e0" ctfv := "d2813c293bcfd7a97dde599527ae6c62c98e66c6" - testcases := []testRest{ + testcases := []testCase{ { name: "none", h: TripRequest{}, @@ -179,13 +184,13 @@ func TestTripRequest(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCaseWithHandlers(t, tc, graphqlHandler, restHandler) }) } } func TestTripRequest_Format(t *testing.T) { - tcs := []testRest{ + tcs := []testCase{ { name: "trip geojson", format: "geojson", @@ -212,16 +217,15 @@ func TestTripRequest_Format(t *testing.T) { }, }, } - srv, te := testRestConfig(t) for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestTripRequest_Pagination(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "limit:1", h: TripRequest{WithCursor: WithCursor{Limit: 1}}, @@ -251,16 +255,15 @@ func TestTripRequest_Pagination(t *testing.T) { expectLength: 10_000, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } func TestTripRequest_License(t *testing.T) { - testcases := []testRest{ + testcases := []testCase{ { name: "license:share_alike_optional yes", h: TripRequest{WithCursor: WithCursor{Limit: 100_000}, LicenseFilter: LicenseFilter{LicenseShareAlikeOptional: "yes"}}, selector: "trips.#.trip_id", @@ -307,10 +310,9 @@ func TestTripRequest_License(t *testing.T) { expectLength: 14903, }, } - srv, te := testRestConfig(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - testquery(t, srv, te, tc) + checkTestCase(t, tc) }) } } diff --git a/server/server_cmd.go b/server/server_cmd.go index f8911b78..76e643b0 100644 --- a/server/server_cmd.go +++ b/server/server_cmd.go @@ -28,7 +28,6 @@ import ( "github.com/interline-io/transitland-mw/lmw" "github.com/interline-io/transitland-mw/meters" "github.com/interline-io/transitland-mw/metrics" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/finders/dbfinder" "github.com/interline-io/transitland-server/finders/gbfsfinder" "github.com/interline-io/transitland-server/finders/rtfinder" @@ -44,26 +43,32 @@ import ( ) type Command struct { - Timeout int - Port string - LongQueryDuration int - DisableGraphql bool - DisableRest bool - EnablePlayground bool - EnableAdminApi bool - EnableJobsApi bool - EnableWorkers bool - EnableProfiler bool - EnableRateLimits bool - LoadAdmins bool - QueuePrefix string - SecretsFile string - AuthMiddlewares arrayFlags - metersConfig meters.Config - metricsConfig metrics.Config - AuthConfig ancheck.AuthConfig - CheckerConfig azcheck.CheckerConfig - config.Config + Timeout int + Port string + LongQueryDuration int + DisableGraphql bool + DisableRest bool + EnablePlayground bool + EnableAdminApi bool + EnableJobsApi bool + EnableWorkers bool + EnableProfiler bool + EnableRateLimits bool + LoadAdmins bool + QueuePrefix string + SecretsFile string + Storage string + RTStorage string + ValidateLargeFiles bool + DBURL string + RedisURL string + AuthMiddlewares arrayFlags + metersConfig meters.Config + metricsConfig metrics.Config + AuthConfig ancheck.AuthConfig + CheckerConfig azcheck.CheckerConfig + RestConfig rest.Config + secrets []tl.Secret } func (cmd *Command) Parse(args []string) error { @@ -78,9 +83,9 @@ func (cmd *Command) Parse(args []string) error { fl.StringVar(&cmd.RedisURL, "redisurl", "", "Redis URL (default: $TL_REDIS_URL)") fl.StringVar(&cmd.Storage, "storage", "", "Static storage backend") fl.StringVar(&cmd.RTStorage, "rt-storage", "", "RT storage backend") - fl.StringVar(&cmd.RestPrefix, "rest-prefix", "", "REST prefix for generating pagination links") fl.BoolVar(&cmd.ValidateLargeFiles, "validate-large-files", false, "Allow validation of large files") - fl.BoolVar(&cmd.DisableImage, "disable-image", false, "Disable image generation") + fl.StringVar(&cmd.RestConfig.RestPrefix, "rest-prefix", "", "REST prefix for generating pagination links") + fl.BoolVar(&cmd.RestConfig.DisableImage, "disable-image", false, "Disable image generation") // Server config fl.StringVar(&cmd.Port, "port", "8080", "") @@ -161,16 +166,14 @@ func (cmd *Command) Parse(args []string) error { } secrets = rr.Secrets } - cmd.Config.Secrets = secrets + cmd.secrets = secrets return nil } func (cmd *Command) Run() error { - cfg := cmd.Config - // Open database var db sqlx.Ext - dbx, err := dbutil.OpenDB(cfg.DBURL) + dbx, err := dbutil.OpenDB(cmd.DBURL) if err != nil { return err } @@ -182,7 +185,7 @@ func (cmd *Command) Run() error { // Open redis var redisClient *redis.Client if cmd.RedisURL != "" { - rOpts, err := getRedisOpts(cfg.RedisURL) + rOpts, err := getRedisOpts(cmd.RedisURL) if err != nil { return err } @@ -197,12 +200,10 @@ func (cmd *Command) Run() error { } // Create Finder - var dbFinder model.Finder - f := dbfinder.NewFinder(db, checker) + dbFinder := dbfinder.NewFinder(db, checker) if cmd.LoadAdmins { - f.LoadAdmins() + dbFinder.LoadAdmins() } - dbFinder = f // Create RTFinder, GBFSFinder var rtFinder model.RTFinder @@ -220,6 +221,18 @@ func (cmd *Command) Run() error { jobQueue = jobs.NewLocalJobs() } + // Setup config + cfg := model.Config{ + Finder: dbFinder, + RTFinder: rtFinder, + GbfsFinder: gbfsFinder, + Checker: checker, + Secrets: cmd.secrets, + Storage: cmd.Storage, + RTStorage: cmd.RTStorage, + ValidateLargeFiles: cmd.ValidateLargeFiles, + } + // Setup metrics metricProvider, err := metrics.GetProvider(cmd.metricsConfig) if err != nil { @@ -245,6 +258,9 @@ func (cmd *Command) Run() error { AllowCredentials: true, })) + // Finders config + root.Use(model.AddConfig(cfg)) + // Setup user middleware for _, k := range cmd.AuthMiddlewares { if userMiddleware, err := ancheck.GetUserMiddleware(k, cmd.AuthConfig, redisClient); err != nil { @@ -271,7 +287,8 @@ func (cmd *Command) Run() error { } // GraphQL API - graphqlServer, err := gql.NewServer(cfg, dbFinder, rtFinder, gbfsFinder, checker) + + graphqlServer, err := gql.NewServer() if err != nil { return err } @@ -287,7 +304,7 @@ func (cmd *Command) Run() error { // REST API if !cmd.DisableRest { - restServer, err := rest.NewServer(cfg, graphqlServer) + restServer, err := rest.NewServer(cmd.RestConfig, graphqlServer) if err != nil { return err } @@ -321,12 +338,8 @@ func (cmd *Command) Run() error { // Start workers/api jobWorkers := 8 jobOptions := jobs.JobOptions{ - Logger: log.Logger, - JobQueue: jobQueue, - Finder: dbFinder, - RTFinder: rtFinder, - GbfsFinder: gbfsFinder, - Config: cfg, + Logger: log.Logger, + JobQueue: jobQueue, } // Add metrics // jobQueue.Use(metrics.NewJobMiddleware("", metricProvider.NewJobMetric("default"))) @@ -340,7 +353,7 @@ func (cmd *Command) Run() error { } if cmd.EnableJobsApi { log.Infof("Enabling job api") - jobServer, err := workers.NewServer(cfg, "", jobWorkers, jobOptions) + jobServer, err := workers.NewServer("", jobWorkers, jobOptions) if err != nil { return err } diff --git a/test_setup.sh b/test_setup.sh index 43c46e26..0229a965 100755 --- a/test_setup.sh +++ b/test_setup.sh @@ -1,6 +1,6 @@ #!/bin/bash # Remove import files -TL_TEST_STORAGE="${TL_TEST_STORAGE:-tmp}" +TL_TEST_STORAGE="${PWD}/tmp" mkdir -p "${TL_TEST_STORAGE}"; rm ${TL_TEST_STORAGE}/*.zip # export TL_LOG=debug (cd cmd/tlserver && go install .) diff --git a/workers/fetch_enqueue_worker.go b/workers/fetch_enqueue_worker.go index 75e96686..100a109f 100644 --- a/workers/fetch_enqueue_worker.go +++ b/workers/fetch_enqueue_worker.go @@ -16,10 +16,11 @@ type FetchEnqueueWorker struct { } func (w *FetchEnqueueWorker) Run(ctx context.Context, job jobs.Job) error { + cfg := model.ForContext(ctx) + db := cfg.Finder.DBX() opts := job.Opts - db := opts.Finder.DBX() now := time.Now().In(time.UTC) - feeds, err := job.Opts.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{}) + feeds, err := cfg.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{}) if err != nil { return err } diff --git a/workers/gbfs_fetch_worker.go b/workers/gbfs_fetch_worker.go index afd82dd8..e1a473b0 100644 --- a/workers/gbfs_fetch_worker.go +++ b/workers/gbfs_fetch_worker.go @@ -19,8 +19,9 @@ type GbfsFetchWorker struct { } func (w *GbfsFetchWorker) Run(ctx context.Context, job jobs.Job) error { + cfg := model.ForContext(ctx) log := job.Opts.Logger.With().Str("feed_id", w.FeedID).Str("url", w.Url).Logger() - gfeeds, err := job.Opts.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &w.FeedID}) + gfeeds, err := cfg.Finder.FindFeeds(ctx, nil, nil, nil, &model.FeedFilter{OnestopID: &w.FeedID}) if err != nil { log.Error().Err(err).Msg("gbfsfetch worker: error loading source feed") return err @@ -40,7 +41,7 @@ func (w *GbfsFetchWorker) Run(ctx context.Context, job jobs.Job) error { opts.FeedURL = w.Url } feeds, result, err := gbfs.Fetch( - tldb.NewPostgresAdapterFromDBX(job.Opts.Finder.DBX()), + tldb.NewPostgresAdapterFromDBX(cfg.Finder.DBX()), opts, ) if err != nil { @@ -54,7 +55,7 @@ func (w *GbfsFetchWorker) Run(ctx context.Context, job jobs.Job) error { for _, feed := range feeds { if feed.SystemInformation != nil { key := fmt.Sprintf("%s:%s", w.FeedID, feed.SystemInformation.Language.Val) - job.Opts.GbfsFinder.AddData(ctx, key, feed) + cfg.GbfsFinder.AddData(ctx, key, feed) } } log.Info().Msg("gbfs fetch worker: success") diff --git a/workers/gbfs_fetch_worker_test.go b/workers/gbfs_fetch_worker_test.go index 091e967f..a6936244 100644 --- a/workers/gbfs_fetch_worker_test.go +++ b/workers/gbfs_fetch_worker_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/interline-io/transitland-server/internal/gbfs" - "github.com/interline-io/transitland-server/internal/testfinder" + "github.com/interline-io/transitland-server/internal/testconfig" "github.com/interline-io/transitland-server/internal/testutil" "github.com/interline-io/transitland-server/jobs" "github.com/stretchr/testify/assert" @@ -18,22 +18,20 @@ func TestGbfsFetchWorker(t *testing.T) { ts := httptest.NewServer(&gbfs.TestGbfsServer{Language: "en", Path: testutil.RelPath("test/data/gbfs")}) defer ts.Close() - testfinder.FindersTxRollback(t, nil, nil, func(te model.Finders) { + testconfig.ConfigTxRollback(t, testconfig.Options{}, func(cfg model.Config) { job := jobs.Job{} - job.Opts.Finder = te.Finder - job.Opts.RTFinder = te.RTFinder - job.Opts.GbfsFinder = te.GbfsFinder w := GbfsFetchWorker{ Url: ts.URL + "/gbfs.json", FeedID: "test-gbfs", } - err := w.Run(context.Background(), job) + ctx := model.WithConfig(context.Background(), cfg) + err := w.Run(ctx, job) if err != nil { t.Fatal(err) } // Test - bikes, err := te.GbfsFinder.FindBikes( - context.Background(), + bikes, err := cfg.GbfsFinder.FindBikes( + ctx, nil, &model.GbfsBikeRequest{ Near: &model.PointRadius{ diff --git a/workers/rt_fetch_worker.go b/workers/rt_fetch_worker.go index f2833feb..2e515beb 100644 --- a/workers/rt_fetch_worker.go +++ b/workers/rt_fetch_worker.go @@ -17,7 +17,7 @@ type RTFetchWorker struct { func (w *RTFetchWorker) Run(ctx context.Context, job jobs.Job) error { log := job.Opts.Logger.With().Str("target", w.Target).Str("source_feed_id", w.SourceFeedID).Str("source_type", w.SourceType).Str("url", w.Url).Logger() - err := actions.RTFetch(ctx, job.Opts.Config, job.Opts.Finder, job.Opts.RTFinder, w.Target, w.SourceFeedID, w.Url, w.SourceType, nil) + err := actions.RTFetch(ctx, w.Target, w.SourceFeedID, w.Url, w.SourceType) if err != nil { log.Error().Err(err).Msg("rtfetch worker: request failed") return err diff --git a/workers/static_fetch_worker.go b/workers/static_fetch_worker.go index 08db35f5..b0ba4b9c 100644 --- a/workers/static_fetch_worker.go +++ b/workers/static_fetch_worker.go @@ -16,7 +16,7 @@ type StaticFetchWorker struct { func (w *StaticFetchWorker) Run(ctx context.Context, job jobs.Job) error { log := job.Opts.Logger.With().Str("feed_id", w.FeedID).Str("feed_url", w.FeedUrl).Logger() - if result, err := actions.StaticFetch(ctx, job.Opts.Config, job.Opts.Finder, w.FeedID, nil, w.FeedUrl, nil); err != nil { + if result, err := actions.StaticFetch(ctx, w.FeedID, nil, w.FeedUrl); err != nil { log.Error().Err(err).Msg("staticfetch worker: request failed") return err } else if result.FetchError != nil { diff --git a/workers/workers.go b/workers/workers.go index 055d5644..9f06bf18 100644 --- a/workers/workers.go +++ b/workers/workers.go @@ -7,7 +7,6 @@ import ( "net/http" "github.com/go-chi/chi/v5" - "github.com/interline-io/transitland-server/config" "github.com/interline-io/transitland-server/internal/util" "github.com/interline-io/transitland-server/jobs" ) @@ -44,7 +43,7 @@ func GetWorker(job jobs.Job) (jobs.JobWorker, error) { } // NewServer creates a simple api for submitting and running jobs. -func NewServer(cfg config.Config, queueName string, workers int, jo jobs.JobOptions) (http.Handler, error) { +func NewServer(queueName string, workers int, jo jobs.JobOptions) (http.Handler, error) { r := chi.NewRouter() r.HandleFunc("/add", wrapHandler(addJobRequest, jo)) r.HandleFunc("/run", wrapHandler(runJobRequest, jo))