Skip to content

Commit

Permalink
fix timeout locks
Browse files Browse the repository at this point in the history
  • Loading branch information
hgiasac committed Feb 14, 2025
1 parent 68ab3a7 commit 0cafac8
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 32 deletions.
66 changes: 66 additions & 0 deletions example/subscription/subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,69 @@ func TestSubscription_LifeCycleEvents(t *testing.T) {
func TestSubscription_WithSyncMode(t *testing.T) {
testSubscription_LifeCycleEvents(t, true)
}

func TestTransportWS_ConnectionIdleTimeout(t *testing.T) {
server := subscription_setupServer(8081)
_, subscriptionClient := subscription_setupClients(8081)
msg := randomID()
go func() {
if err := server.ListenAndServe(); err != nil {
log.Println(err)
}
}()

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer func() {
_ = server.Shutdown(ctx)
}()
defer cancel()

subscriptionClient.
WithWebsocketConnectionIdleTimeout(2 * time.Second).
OnError(func(sc *gql.SubscriptionClient, err error) error {
return err
})

/*
subscription {
helloSaid {
id
msg
}
}
*/
var sub struct {
HelloSaid struct {
ID gql.String
Message gql.String `graphql:"msg" json:"msg"`
} `graphql:"helloSaid" json:"helloSaid"`
}

_, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error {
if e != nil {
t.Fatalf("got error: %v, want: nil", e)
return nil
}

log.Println("result", string(data))
e = json.Unmarshal(data, &sub)
if e != nil {
t.Fatalf("got error: %v, want: nil", e)
return nil
}

if sub.HelloSaid.Message != gql.String(msg) {
t.Fatalf("subscription message does not match. got: %s, want: %s", sub.HelloSaid.Message, msg)
}

return errors.New("exit")
})

if err != nil {
t.Fatalf("got error: %v, want: nil", err)
}

if err := subscriptionClient.Run(); err == nil || !errors.Is(err, gql.ErrWebsocketConnectionIdleTimeout) {
t.Errorf("got error: %v, want: %s", err, gql.ErrWebsocketConnectionIdleTimeout)
}
}
85 changes: 53 additions & 32 deletions subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,7 @@ type SubscriptionContext struct {

// Log prints condition logging with message type filters
func (sc *SubscriptionContext) Log(message interface{}, source string, opType OperationMessageType) {
if sc.client.log == nil {
return
}

for _, ty := range sc.client.disabledLogTypes {
if ty == opType {
return
}
}

sc.client.log(message, source)
sc.client.printLog(message, source, opType)
}

// OnConnectionAlive executes the OnConnectionAlive callback if exists.
Expand Down Expand Up @@ -437,6 +427,7 @@ func (sc *SubscriptionContext) run() {
if err := conn.ReadJSON(&message); err != nil {
// manual EOF check
if err == io.EOF || strings.Contains(err.Error(), "EOF") || errors.Is(err, net.ErrClosed) || strings.Contains(err.Error(), "connection reset by peer") {
sc.Log(err.Error(), "client", GQLConnectionError)
sc.client.errorChan <- errRestartSubscriptionClient

return
Expand Down Expand Up @@ -1071,6 +1062,42 @@ func (sc *SubscriptionClient) RunWithContext(ctx context.Context) error {

go subContext.run()

if sc.connectionInitialisationTimeout > 0 || sc.websocketConnectionIdleTimeout > 0 {
go func() {
ticker := time.NewTicker(time.Second)

for {
select {
case <-ctx.Done():
return
case <-ticker.C:
session := sc.getCurrentSession()
if session == nil {
continue
}

isAcknowledge := session.GetAcknowledge()
if sc.connectionInitialisationTimeout > 0 && !isAcknowledge &&
time.Since(session.getConnectionInitAt()) > sc.connectionInitialisationTimeout {
sc.printLog("Connection initialisation timeout", "client", GQLInternal)
sc.errorChan <- &websocket.CloseError{
Code: StatusConnectionInitialisationTimeout,
Reason: "Connection initialisation timeout",
}

continue
}

if sc.websocketConnectionIdleTimeout > 0 && isAcknowledge &&
time.Since(session.getLastReceivedMessageAt()) > sc.websocketConnectionIdleTimeout {
sc.printLog(ErrWebsocketConnectionIdleTimeout.Error(), "client", GQLInternal)
sc.errorChan <- ErrWebsocketConnectionIdleTimeout
}
}
}
}()
}

for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -1106,27 +1133,6 @@ func (sc *SubscriptionClient) RunWithContext(ctx context.Context) error {
}

go subContext.run()
default:
session := sc.getCurrentSession()
if session == nil {
continue
}

isAcknowledge := session.GetAcknowledge()
if sc.connectionInitialisationTimeout > 0 && !isAcknowledge &&
time.Since(session.getConnectionInitAt()) > sc.connectionInitialisationTimeout {
sc.errorChan <- &websocket.CloseError{
Code: StatusConnectionInitialisationTimeout,
Reason: "Connection initialisation timeout",
}

continue
}

if sc.websocketConnectionIdleTimeout > 0 && isAcknowledge &&
time.Since(session.getLastReceivedMessageAt()) > sc.websocketConnectionIdleTimeout {
sc.errorChan <- ErrWebsocketConnectionIdleTimeout
}
}
}
}
Expand Down Expand Up @@ -1304,6 +1310,21 @@ func (sc *SubscriptionClient) checkSubscriptionStatuses(session *SubscriptionCon
}
}

// prints condition logging with message type filters
func (sc *SubscriptionClient) printLog(message interface{}, source string, opType OperationMessageType) {
if sc.log == nil {
return
}

for _, ty := range sc.disabledLogTypes {
if ty == opType {
return
}
}

sc.log(message, source)
}

// the reusable function for sending connection init message.
// The payload format of both subscriptions-transport-ws and graphql-ws are the same
func connectionInit(conn *SubscriptionContext, connectionParams map[string]interface{}) error {
Expand Down

0 comments on commit 0cafac8

Please sign in to comment.