Skip to content

Commit

Permalink
Merge branch 'main' into eric/ics-compat
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric-Warehime authored Apr 10, 2024
2 parents 9392b2b + e2eb670 commit ebbd8b0
Show file tree
Hide file tree
Showing 8 changed files with 305 additions and 8 deletions.
48 changes: 48 additions & 0 deletions oracle/config/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,48 @@ type APIConfig struct {
type Endpoint struct {
// URL is the URL that is used to fetch data from the API.
URL string `json:"url"`

// Authentication holds all data necessary for an API provider to authenticate with
// an endpoint.
Authentication Authentication `json:"authentication"`
}

// ValidateBasic performs basic validation of the API endpoint.
func (e Endpoint) ValidateBasic() error {
if len(e.URL) == 0 {
return fmt.Errorf("endpoint url cannot be empty")
}

return e.Authentication.ValidateBasic()
}

// Authentication holds all data necessary for an API provider to authenticate with an
// endpoint.
type Authentication struct {
// HTTPHeaderAPIKey is the API-key that will be set under the X-Api-Key header
APIKey string `json:"apiKey"`

// APIKeyHeader is the header that will be used to set the API key.
APIKeyHeader string `json:"apiKeyHeader"`
}

// Enabled returns true if the authentication is enabled.
func (a Authentication) Enabled() bool {
return a.APIKey != "" && a.APIKeyHeader != ""
}

// ValidateBasic performs basic validation of the API authentication. Specifically, the APIKey + APIKeyHeader
// must be set atomically.
func (a Authentication) ValidateBasic() error {
if a.APIKey != "" && a.APIKeyHeader == "" {
return fmt.Errorf("api key header cannot be empty when api key is set")
}

if a.APIKey == "" && a.APIKeyHeader != "" {
return fmt.Errorf("api key cannot be empty when api key header is set")
}

return nil
}

// ValidateBasic performs basic validation of the API config.
Expand Down Expand Up @@ -78,5 +120,11 @@ func (c *APIConfig) ValidateBasic() error {
return fmt.Errorf("batch size cannot be set for atomic providers")
}

for _, e := range c.Endpoints {
if err := e.ValidateBasic(); err != nil {
return err
}
}

return nil
}
81 changes: 81 additions & 0 deletions oracle/config/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,87 @@ func TestAPIConfig(t *testing.T) {
},
expectedErr: false,
},
{
name: "bad config with invalid endpoint (no url)",
config: config.APIConfig{
Enabled: true,
Timeout: time.Second,
Interval: time.Second,
ReconnectTimeout: time.Second,
MaxQueries: 1,
Name: "test",
Endpoints: []config.Endpoint{
{
URL: "",
},
},
BatchSize: 1,
},
expectedErr: true,
},
{
name: "bad config with invalid endpoint authentication (HTTP header field missing)",
config: config.APIConfig{
Enabled: true,
Timeout: time.Second,
Interval: time.Second,
ReconnectTimeout: time.Second,
MaxQueries: 1,
Name: "test",
Endpoints: []config.Endpoint{
{
URL: "http://test.com",
Authentication: config.Authentication{
APIKey: "test",
},
},
},
BatchSize: 1,
},
expectedErr: true,
},
{
name: "bad config with invalid endpoint authentication (API-key field missing)",
config: config.APIConfig{
Enabled: true,
Timeout: time.Second,
Interval: time.Second,
ReconnectTimeout: time.Second,
MaxQueries: 1,
Name: "test",
Endpoints: []config.Endpoint{
{
URL: "http://test.com",
Authentication: config.Authentication{
APIKeyHeader: "test",
},
},
},
BatchSize: 1,
},
expectedErr: true,
},
{
name: "good config with valid endpoint",
config: config.APIConfig{
Enabled: true,
Timeout: time.Second,
Interval: time.Second,
ReconnectTimeout: time.Second,
MaxQueries: 1,
Name: "test",
Endpoints: []config.Endpoint{
{
URL: "http://test.com",
Authentication: config.Authentication{
APIKey: "test",
APIKeyHeader: "X-API-KEY",
},
},
},
BatchSize: 1,
},
},
}

for _, tc := range testCases {
Expand Down
31 changes: 31 additions & 0 deletions pkg/http/round_tripper_with_headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package http

import (
"net/http"
)

// RoundTripperWithHeaders is a round tripper that adds headers to the request.
type RoundTripperWithHeaders struct {
// Headers is the map of headers to add to the request.
headers map[string]string

// Next is the next round tripper in the chain.
next http.RoundTripper
}

// NewRoundTripperWithHeaders creates a new RoundTripperWithHeaders.
func NewRoundTripperWithHeaders(headers map[string]string, next http.RoundTripper) *RoundTripperWithHeaders {
return &RoundTripperWithHeaders{
headers: headers,
next: next,
}
}

// RoundTrip updates the Requests' headers with the headers specified in the constructor, and calls the underlying RoundTripper.
func (r *RoundTripperWithHeaders) RoundTrip(req *http.Request) (*http.Response, error) {
for k, v := range r.headers {
req.Header.Set(k, v)
}

return r.next.RoundTrip(req)
}
47 changes: 47 additions & 0 deletions pkg/http/round_tripper_with_headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package http_test

import (
"fmt"
"net/http"
"testing"

"github.com/stretchr/testify/require"

slinkyhttp "github.com/skip-mev/slinky/pkg/http"
)

func TestRoundTripperWithHeaders(t *testing.T) {
expectedHeaderFields := map[string]string{
"X-Api-Key": "test",
}

rt := &customRoundTripper{
expectedHeaderFields: expectedHeaderFields,
}

rtWithHeaders := slinkyhttp.NewRoundTripperWithHeaders(expectedHeaderFields, rt)

client := &http.Client{
Transport: rtWithHeaders,
}

req, err := http.NewRequest(http.MethodGet, "http://test.com", nil)
require.NoError(t, err)

// Make the request
_, err = client.Do(req)
require.NoError(t, err)
}

type customRoundTripper struct {
expectedHeaderFields map[string]string
}

func (c *customRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
for k, v := range c.expectedHeaderFields {
if req.Header.Get(k) != v {
return nil, fmt.Errorf("expected header %s to be %s, got %s", k, v, req.Header.Get(k))
}
}
return &http.Response{}, nil
}
40 changes: 36 additions & 4 deletions providers/apis/defi/raydium/multi_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package raydium
import (
"context"
"fmt"
"net/http"
"sync"

"github.com/gagliardetto/solana-go"
"github.com/gagliardetto/solana-go/rpc"
"github.com/gagliardetto/solana-go/rpc/jsonrpc"
"go.uber.org/zap"

oracleconfig "github.com/skip-mev/slinky/oracle/config"
slinkyhttp "github.com/skip-mev/slinky/pkg/http"
)

// MultiJSONRPCClient is an implementation of the SolanaJSONRPCClient interface that delegates
Expand All @@ -30,13 +33,42 @@ func NewMultiJSONRPCClient(clients []SolanaJSONRPCClient, logger *zap.Logger) *M
}

// NewMultiJSONRPCClientFromEndpoints creates a new MultiJSONRPCClient from a list of endpoints.
func NewMultiJSONRPCClientFromEndpoints(endpoints []oracleconfig.Endpoint, logger *zap.Logger) *MultiJSONRPCClient {
func NewMultiJSONRPCClientFromEndpoints(endpoints []oracleconfig.Endpoint, logger *zap.Logger) (*MultiJSONRPCClient, error) {
clients := make([]SolanaJSONRPCClient, len(endpoints))

var err error
for i := range endpoints {
client := rpc.New(endpoints[i].URL)
clients[i] = client
clients[i], err = solanaClientFromEndpoint(endpoints[i])
if err != nil {
return nil, fmt.Errorf("failed to create solana client from endpoint: %w", err)
}
}

return NewMultiJSONRPCClient(clients, logger), nil
}

// solanaClientFromEndpoint creates a new SolanaJSONRPCClient from an endpoint.
func solanaClientFromEndpoint(endpoint oracleconfig.Endpoint) (SolanaJSONRPCClient, error) {
// fail if the endpoint is invalid
if err := endpoint.ValidateBasic(); err != nil {
return nil, fmt.Errorf("invalid endpoint %v: %w", endpoint, err)
}

// if authentication is enabled
if endpoint.Authentication.Enabled() {
transport := slinkyhttp.NewRoundTripperWithHeaders(map[string]string{
endpoint.Authentication.APIKeyHeader: endpoint.Authentication.APIKey,
}, http.DefaultTransport)

client := rpc.NewWithCustomRPCClient(jsonrpc.NewClientWithOpts(endpoint.URL, &jsonrpc.RPCClientOpts{
HTTPClient: &http.Client{
Transport: transport,
},
}))

return client, nil
}
return NewMultiJSONRPCClient(clients, logger)
return rpc.New(endpoint.URL), nil
}

// GetMultipleAccountsWithOpts delegates the request to all underlying clients and applies a filter
Expand Down
30 changes: 30 additions & 0 deletions providers/apis/defi/raydium/multi_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package raydium_test
import (
"context"
"fmt"
"strings"
"testing"
"time"

Expand All @@ -11,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
"go.uber.org/zap"

oracleconfig "github.com/skip-mev/slinky/oracle/config"
"github.com/skip-mev/slinky/providers/apis/defi/raydium"
"github.com/skip-mev/slinky/providers/apis/defi/raydium/mocks"
)
Expand All @@ -22,6 +24,34 @@ func TestMultiJSONRPCClient(t *testing.T) {
client3 := mocks.NewSolanaJSONRPCClient(t)
client := raydium.NewMultiJSONRPCClient([]raydium.SolanaJSONRPCClient{client1, client2, client3}, zap.NewNop())

t.Run("test MultiJSONRPCClient From endpoints", func(t *testing.T) {
t.Run("invalid endpoint", func(t *testing.T) {
endpoint := oracleconfig.Endpoint{}

_, err := raydium.NewMultiJSONRPCClientFromEndpoints([]oracleconfig.Endpoint{endpoint}, zap.NewNop())
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "invalid endpoint"))
})

t.Run("endpoints with / wo authentication", func(t *testing.T) {
endpoints := []oracleconfig.Endpoint{
{
URL: "http://localhost:8899",
},
{
URL: "http://localhost:8899/",
Authentication: oracleconfig.Authentication{
APIKey: "test",
APIKeyHeader: "X-API-Key",
},
},
}

_, err := raydium.NewMultiJSONRPCClientFromEndpoints(endpoints, zap.NewNop())
require.NoError(t, err)
})
})

// test adherence to the context
t.Run("test failures in underlying client", func(t *testing.T) {
accounts := []solana.PublicKey{{}}
Expand Down
13 changes: 9 additions & 4 deletions providers/apis/defi/raydium/price_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,16 @@ func NewAPIPriceFetcher(
// use a multi-client if multiple endpoints are provided
if len(config.Endpoints) > 0 {
if len(config.Endpoints) > 1 {
client, err := NewMultiJSONRPCClientFromEndpoints(
config.Endpoints,
logger.With(zap.String("raydium_multi_client", Name)),
)
if err != nil {
return nil, fmt.Errorf("error creating multi-client: %w", err)
}

opts = append(opts, WithSolanaClient(
NewMultiJSONRPCClientFromEndpoints(
config.Endpoints,
logger.With(zap.String("raydium_multi_client", Name)),
),
client,
))
} else {
config.URL = config.Endpoints[0].URL
Expand Down
23 changes: 23 additions & 0 deletions providers/apis/defi/raydium/price_fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ func TestProviderInit(t *testing.T) {
require.True(t, strings.Contains(err.Error(), "config for raydium is invalid"))
})

t.Run("config has invalid endpoints", func(t *testing.T) {
cfg := oracleconfig.APIConfig{
Enabled: true,
MaxQueries: 0,
Endpoints: []oracleconfig.Endpoint{
{
URL: "", // invalid url
},
{
URL: "https://raydium.io",
},
},
}

_, err := raydium.NewAPIPriceFetcher(
oracletypes.ProviderMarketMap{},
cfg,
zap.NewNop(),
)

require.True(t, strings.Contains(err.Error(), "error creating multi-client"))
})

t.Run("market config fails validate basic", func(t *testing.T) {
// valid config
cfg := oracleconfig.APIConfig{
Expand Down

0 comments on commit ebbd8b0

Please sign in to comment.