Skip to content

Commit

Permalink
Fix pli missed cause by two goroutine compete rtcp reader (#376)
Browse files Browse the repository at this point in the history
There were two goroutine to read rtcp when publishing a
LocalSampleTrack, one is publication to calculate rtt
and the other is LocalSampleTrack itself, cause pli hander
rely on LocalSampleTrack's rtcp callback might miss pli
request.
  • Loading branch information
cnderrauber authored Jan 8, 2024
1 parent 12e24f8 commit ddded83
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 39 deletions.
5 changes: 1 addition & 4 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,6 @@ func (e *RTCEngine) TrackPublishedChan() <-chan *livekit.TrackPublishedResponse
}

func (e *RTCEngine) setRTT(rtt uint32) {
if pc := e.publisher; pc != nil {
pc.SetRTT(rtt)
}

if pc := e.subscriber; pc != nil {
pc.SetRTT(rtt)
}
Expand All @@ -193,6 +189,7 @@ func (e *RTCEngine) configure(res *livekit.JoinResponse) error {
Configuration: configuration,
RetransmitBufferSize: e.connParams.RetransmitBufferSize,
Pacer: e.connParams.Pacer,
OnRTTUpdate: e.setRTT,
}); err != nil {
return err
}
Expand Down
9 changes: 4 additions & 5 deletions localparticipant.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ func (p *LocalParticipant) PublishTrack(track webrtc.TrackLocal, opts *TrackPubl
}

pub := NewLocalTrackPublication(kind, track, *opts, p.engine.client)
pub.OnRttUpdate(func(rtt uint32) {
p.engine.setRTT(rtt)
})
pub.onMuteChanged = p.onTrackMuted

req := &livekit.AddTrackRequest{
Expand Down Expand Up @@ -107,7 +104,9 @@ func (p *LocalParticipant) PublishTrack(track webrtc.TrackLocal, opts *TrackPubl
return nil, err
}

pub.setSender(transceiver.Sender())
// LocalSampleTrack will consume rtcp packets so we don't need to consume again
_, isSampleTrack := track.(*LocalSampleTrack)
pub.setSender(transceiver.Sender(), !isSampleTrack)

pub.updateInfo(pubRes.Track)
p.addPublication(pub)
Expand Down Expand Up @@ -196,7 +195,7 @@ func (p *LocalParticipant) PublishSimulcastTrack(tracks []*LocalSampleTrack, opt
return nil, err
}
sender = transceiver.Sender()
pub.setSender(sender)
pub.setSender(sender, false)
} else {
if err = sender.AddEncoding(st); err != nil {
return nil, err
Expand Down
1 change: 1 addition & 0 deletions localsampletrack.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ func (s *LocalSampleTrack) rtcpWorker(rtcpReader interceptor.RTCPReader) {

pkts, err := rtcp.Unmarshal(b[:i])
if err != nil {
logger.Warnw("could not unmarshal rtcp", err)
return
}
for _, packet := range pkts {
Expand Down
66 changes: 66 additions & 0 deletions pkg/interceptor/rttinteceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package interceptor

import (
"github.com/livekit/mediatransportutil"
"github.com/pion/interceptor"
"github.com/pion/rtcp"
)

type RTTInterceptorFactory struct {
onRttUpdate func(rtt uint32)
}

func NewRTTInterceptorFactory(onRttUpdate func(rtt uint32)) *RTTInterceptorFactory {
return &RTTInterceptorFactory{
onRttUpdate: onRttUpdate,
}
}

func (r *RTTInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
return NewRTTInterceptor(r.onRttUpdate), nil
}

type RTTInterceptor struct {
interceptor.NoOp

onRttUpdate func(rtt uint32)
}

func NewRTTInterceptor(onRttUpdate func(rtt uint32)) *RTTInterceptor {
return &RTTInterceptor{
onRttUpdate: onRttUpdate,
}
}

func (r *RTTInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader {
return interceptor.RTCPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
i, attr, err := reader.Read(b, a)
if err != nil {
return 0, nil, err
}

if attr == nil {
attr = make(interceptor.Attributes)
}
pkts, err := attr.GetRTCPPackets(b[:i])
if err != nil {
return 0, nil, err
}

rttCaculate:
for _, packet := range pkts {
if rr, ok := packet.(*rtcp.ReceiverReport); ok {
for _, report := range rr.Reports {
rtt, err := mediatransportutil.GetRttMsFromReceiverReportOnly(&report)
if err == nil && rtt != 0 {
r.onRttUpdate(rtt)
}

break rttCaculate
}
}
}

return i, attr, err
})
}
32 changes: 7 additions & 25 deletions publication.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"go.uber.org/atomic"
"google.golang.org/protobuf/proto"

"github.com/livekit/mediatransportutil"
"github.com/livekit/protocol/livekit"
)

Expand Down Expand Up @@ -240,7 +239,6 @@ type LocalTrackPublication struct {
sender *webrtc.RTPSender
// set for simulcasted tracks
simulcastTracks map[livekit.VideoQuality]*LocalSampleTrack
onRttUpdate func(uint32)
opts TrackPublicationOptions
onMuteChanged func(*LocalTrackPublication, bool)
}
Expand Down Expand Up @@ -315,43 +313,27 @@ func (p *LocalTrackPublication) addSimulcastTrack(st *LocalSampleTrack) {
}
}

func (p *LocalTrackPublication) setSender(sender *webrtc.RTPSender) {
func (p *LocalTrackPublication) setSender(sender *webrtc.RTPSender, consumeRTCP bool) {
p.lock.Lock()
p.sender = sender
p.lock.Unlock()

if !consumeRTCP {
return
}

// consume RTCP packets so interceptors can handle them (rtt, nacks...)
go func() {
for {
packets, _, err := sender.ReadRTCP()
_, _, err := sender.ReadRTCP()
if err != nil {
// pipe closed
return
}

rttCaculate:
for _, packet := range packets {
if rr, ok := packet.(*rtcp.ReceiverReport); ok {
for _, r := range rr.Reports {
rr.Reports = append(rr.Reports, r)
rtt, err := mediatransportutil.GetRttMsFromReceiverReportOnly(&r)
if err == nil && rtt != 0 && p.onRttUpdate != nil {
p.onRttUpdate(rtt)
}

break rttCaculate
}
}
}
}
}()
}

func (p *LocalTrackPublication) OnRttUpdate(cb func(uint32)) {
p.lock.Lock()
p.onRttUpdate = cb
p.lock.Unlock()
}

func (p *LocalTrackPublication) CloseTrack() {
for _, st := range p.simulcastTracks {
st.Close()
Expand Down
26 changes: 21 additions & 5 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type PCTransport struct {
nackGenerator *sdkinterceptor.NackGeneratorInterceptorFactory

onRemoteDescriptionSettled func() error
onRTTUpdate func(rtt uint32)

OnOffer func(description webrtc.SessionDescription)
}
Expand All @@ -64,6 +65,7 @@ type PCTransportParams struct {

RetransmitBufferSize uint16
Pacer pacer.Factory
OnRTTUpdate func(rtt uint32)
}

func NewPCTransport(params PCTransportParams) (*PCTransport, error) {
Expand All @@ -86,6 +88,11 @@ func NewPCTransport(params PCTransportParams) (*PCTransport, error) {

i := &interceptor.Registry{}

t := &PCTransport{
debouncedNegotiate: debounce.New(negotiationFrequency),
onRTTUpdate: params.OnRTTUpdate,
}

// nack interceptor
generator := &sdkinterceptor.NackGeneratorInterceptorFactory{}
var generatorOption []nack.ResponderOption
Expand Down Expand Up @@ -119,6 +126,10 @@ func NewPCTransport(params PCTransportParams) (*PCTransport, error) {

i.Add(sdkinterceptor.NewLimitSizeInterceptorFactory())

if params.OnRTTUpdate != nil {
i.Add(sdkinterceptor.NewRTTInterceptorFactory(t.handleRTTUpdate))
}

se := webrtc.SettingEngine{}
se.SetSRTPProtectionProfiles(dtls.SRTP_AEAD_AES_128_GCM, dtls.SRTP_AES128_CM_HMAC_SHA1_80)
se.SetDTLSRetransmissionInterval(dtlsRetransmissionInterval)
Expand All @@ -130,17 +141,22 @@ func NewPCTransport(params PCTransportParams) (*PCTransport, error) {
return nil, err
}

t := &PCTransport{
pc: pc,
debouncedNegotiate: debounce.New(negotiationFrequency),
nackGenerator: generator,
}
t.pc = pc
t.nackGenerator = generator

pc.OnICEGatheringStateChange(t.onICEGatheringStateChange)

return t, nil
}

func (t *PCTransport) handleRTTUpdate(rtt uint32) {
t.SetRTT(rtt)

if t.onRTTUpdate != nil {
t.onRTTUpdate(rtt)
}
}

func (t *PCTransport) onICEGatheringStateChange(state webrtc.ICEGathererState) {
if state != webrtc.ICEGathererStateComplete {
return
Expand Down

0 comments on commit ddded83

Please sign in to comment.