Skip to content

Commit

Permalink
add URLParam selfTransports, that default is true, and if you want fa…
Browse files Browse the repository at this point in the history
…lse should pass it with hide value, like ?selfTransports=hide&...
  • Loading branch information
mrpalide committed Jan 13, 2025
1 parent 06b3a09 commit 2604956
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pkg/route-finder/store/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (m *mockStore) GetTransportsByEdge(_ context.Context, edgePK cipher.PubKey)
func (m *mockStore) GetNumberOfTransports(context.Context) (map[network.Type]int, error) {
return nil, nil
}
func (m *mockStore) GetAllTransports(context.Context) ([]*transport.Entry, error) {
func (m *mockStore) GetAllTransports(context.Context, bool) ([]*transport.Entry, error) {
return nil, nil
}
func (m *mockStore) Close() {}
Expand Down
2 changes: 1 addition & 1 deletion pkg/transport-discovery/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ func TestGETAllTransports(t *testing.T) {
require.Len(t, resp, 2)

t.Run("Persistence", func(t *testing.T) {
found, err := mock.GetAllTransports(ctx)
found, err := mock.GetAllTransports(ctx, true)
require.NoError(t, err)
for i, f := range found {
if f.ID == resp[i].ID {
Expand Down
8 changes: 6 additions & 2 deletions pkg/transport-discovery/api/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ func (api *API) getTransportByEdge(w http.ResponseWriter, r *http.Request) {
}

func (api *API) getAllTransports(w http.ResponseWriter, r *http.Request) {

entries, err := api.store.GetAllTransports(r.Context())
selfTransportsParam := chi.URLParam(r, "selfTransports")
selfTransports := true
if selfTransportsParam == "hide" {
selfTransports = false
}
entries, err := api.store.GetAllTransports(r.Context(), selfTransports)
if err != nil {
if err != store.ErrTransportNotFound {
api.log(r).WithError(err).Error("Error getting transports")
Expand Down
7 changes: 6 additions & 1 deletion pkg/transport-discovery/store/memory_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,16 @@ func (s *memStore) GetNumberOfTransports(context.Context) (map[network.Type]int,
return response, nil
}

func (s *memStore) GetAllTransports(context.Context) ([]*transport.Entry, error) {
func (s *memStore) GetAllTransports(_ context.Context, selfTransports bool) ([]*transport.Entry, error) {
s.mu.Lock()
defer s.mu.Unlock()
var response []*transport.Entry
for _, entry := range s.transports {
if !selfTransports {
if entry.Edges[0] == entry.Edges[1] {
continue
}
}
response = append(response, entry)
}
return response, nil
Expand Down
7 changes: 6 additions & 1 deletion pkg/transport-discovery/store/postgres_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (s *postgresStore) GetNumberOfTransports(context.Context) (map[network.Type
return response, nil
}

func (s *postgresStore) GetAllTransports(context.Context) ([]*transport.Entry, error) {
func (s *postgresStore) GetAllTransports(_ context.Context, selfTransports bool) ([]*transport.Entry, error) {
var tpRecords []Transport
if err := s.client.Find(&tpRecords).Error; err != nil {
return nil, ErrTransportNotFound
Expand All @@ -115,6 +115,11 @@ func (s *postgresStore) GetAllTransports(context.Context) ([]*transport.Entry, e
if err != nil {
return nil, err
}
if !selfTransports {
if entry.Edges[0] == entry.Edges[1] {
continue
}
}
entries = append(entries, &entry)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/transport-discovery/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type TransportStore interface {
GetTransportByID(context.Context, uuid.UUID) (*transport.Entry, error)
GetTransportsByEdge(context.Context, cipher.PubKey) ([]*transport.Entry, error)
GetNumberOfTransports(context.Context) (map[network.Type]int, error)
GetAllTransports(context.Context) ([]*transport.Entry, error)
GetAllTransports(context.Context, bool) ([]*transport.Entry, error)
Close()
}

Expand Down

0 comments on commit 2604956

Please sign in to comment.