Skip to content

Commit

Permalink
Allow dynamic configuration of tls client config
Browse files Browse the repository at this point in the history
  • Loading branch information
everesio committed Jan 4, 2024
1 parent 27a859a commit 2903d43
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
4 changes: 3 additions & 1 deletion handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,9 @@ func newHttpServer(ctx context.Context, codec HttpCodec, httpAddr string, proxyA
// forward all WS events to http server
mux.ServeHTTP(w, r)
}, codec), slog.Default())
client := ws.NewClient(ctx, u.String(), wsHandler, codec.MessageCodec(), ws.WithClientTLSConfig(clientTLSConfig))
client := ws.NewClient(ctx, u.String(), wsHandler, codec.MessageCodec(), ws.WithClientTLSConfigFunc(func() *tls.Config {
return clientTLSConfig
}))
client.Start()

mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
Expand Down
31 changes: 16 additions & 15 deletions ws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ type Client struct {

codec Codec[*message.Message]

logger *slog.Logger
clientID string
tlsConfig *tls.Config
logger *slog.Logger
clientID string
tlsConfigFunc func() *tls.Config
}

type ClientOption func(*Client)
Expand All @@ -45,22 +45,22 @@ func WithClientLogger(logger *slog.Logger) ClientOption {
}
}

func WithClientTLSConfig(tlsConfig *tls.Config) ClientOption {
func WithClientTLSConfigFunc(tlsConfigFunc func() *tls.Config) ClientOption {
return func(c *Client) {
c.tlsConfig = tlsConfig
c.tlsConfigFunc = tlsConfigFunc
}
}

func NewClient(parent context.Context, urlStr string, handler EventHandler, codec Codec[*message.Message], opts ...ClientOption) *Client {
client := &Client{
pool: NewPool(),
parent: parent,
urlStr: urlStr,
handler: handler,
codec: codec,
logger: slog.Default(),
clientID: "",
tlsConfig: nil,
pool: NewPool(),
parent: parent,
urlStr: urlStr,
handler: handler,
codec: codec,
logger: slog.Default(),
clientID: "",
tlsConfigFunc: nil,
}
for _, opt := range opts {
opt(client)
Expand Down Expand Up @@ -114,8 +114,9 @@ func (c *Client) connect() (*Conn, error) {
c.logger.Info("connecting ws to " + c.urlStr)

dialer := *websocket.DefaultDialer
dialer.TLSClientConfig = c.tlsConfig

if c.tlsConfigFunc != nil {
dialer.TLSClientConfig = c.tlsConfigFunc()
}
requestHeader := make(http.Header)
requestHeader.Add(HeaderClientId, c.clientID)

Expand Down

0 comments on commit 2903d43

Please sign in to comment.