From 9b491499bc247ec7d27267df9f256213596cc48d Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Mon, 24 Jul 2023 20:32:28 +0200 Subject: [PATCH] webrtc: speed up track detection (#2105) --- internal/core/webrtc_manager.go | 2 +- internal/core/webrtc_session.go | 111 +++++++++++++++++++++----------- 2 files changed, 76 insertions(+), 37 deletions(-) diff --git a/internal/core/webrtc_manager.go b/internal/core/webrtc_manager.go index 3d3ec2eb..b7f309ff 100644 --- a/internal/core/webrtc_manager.go +++ b/internal/core/webrtc_manager.go @@ -25,7 +25,7 @@ import ( const ( webrtcPauseAfterAuthError = 2 * time.Second webrtcHandshakeTimeout = 10 * time.Second - webrtcTrackGatherTimeout = 2 * time.Second + webrtcTrackGatherTimeout = 5 * time.Second webrtcPayloadMaxSize = 1188 // 1200 - 12 (RTP header) webrtcStreamID = "mediamtx" webrtcTurnSecretExpiration = 24 * 3600 * time.Second diff --git a/internal/core/webrtc_session.go b/internal/core/webrtc_session.go index 70dd92ba..16f6e4ce 100644 --- a/internal/core/webrtc_session.go +++ b/internal/core/webrtc_session.go @@ -14,6 +14,7 @@ import ( "github.com/bluenviron/gortsplib/v3/pkg/ringbuffer" "github.com/google/uuid" "github.com/pion/ice/v2" + "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" "github.com/bluenviron/mediamtx/internal/logger" @@ -40,6 +41,30 @@ func mediasOfIncomingTracks(tracks []*webRTCIncomingTrack) media.Medias { return ret } +func waitUntilConnected( + ctx context.Context, + pc *peerConnection, +) error { + t := time.NewTimer(webrtcHandshakeTimeout) + defer t.Stop() + +outer: + for { + select { + case <-t.C: + return fmt.Errorf("deadline exceeded while waiting connection") + + case <-pc.connected: + break outer + + case <-ctx.Done(): + return fmt.Errorf("terminated") + } + } + + return nil +} + func gatherOutgoingTracks(medias media.Medias) ([]*webRTCOutgoingTrack, error) { var tracks []*webRTCOutgoingTrack @@ -73,6 +98,7 @@ func gatherIncomingTracks( ctx context.Context, pc *peerConnection, trackRecv chan trackRecvPair, + trackCount int, ) ([]*webRTCIncomingTrack, error) { var tracks []*webRTCIncomingTrack @@ -82,10 +108,7 @@ func gatherIncomingTracks( for { select { case <-t.C: - if len(tracks) == 0 { - return nil, fmt.Errorf("no tracks found") - } - return tracks, nil + return nil, fmt.Errorf("deadline exceeded while waiting tracks") case pair := <-trackRecv: track, err := newWebRTCIncomingTrack(pair.track, pair.receiver, pc.WriteRTCP) @@ -94,7 +117,7 @@ func gatherIncomingTracks( } tracks = append(tracks, track) - if len(tracks) == 2 { + if len(tracks) == trackCount { return tracks, nil } @@ -262,6 +285,39 @@ func (s *webRTCSession) runPublish() (int, error) { } defer pc.close() + offer := s.offer() + + var sdp sdp.SessionDescription + err = sdp.Unmarshal([]byte(offer.SDP)) + if err != nil { + return http.StatusBadRequest, err + } + + videoTrack := false + audioTrack := false + trackCount := 0 + + for _, media := range sdp.MediaDescriptions { + switch media.MediaName.Media { + case "video": + if videoTrack { + return http.StatusBadRequest, fmt.Errorf("only a single video and a single audio track are supported") + } + videoTrack = true + + case "audio": + if audioTrack { + return http.StatusBadRequest, fmt.Errorf("only a single video and a single audio track are supported") + } + audioTrack = true + + default: + return http.StatusBadRequest, fmt.Errorf("unsupported media '%s'", media.MediaName.Media) + } + + trackCount++ + } + _, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RtpTransceiverInit{ Direction: webrtc.RTPTransceiverDirectionRecvonly, }) @@ -285,7 +341,6 @@ func (s *webRTCSession) runPublish() (int, error) { } }) - offer := s.buildOffer() err = pc.SetRemoteDescription(*offer) if err != nil { return http.StatusBadRequest, err @@ -313,12 +368,16 @@ func (s *webRTCSession) runPublish() (int, error) { go s.readRemoteCandidates(pc) - err = s.waitUntilConnected(pc) + err = waitUntilConnected(s.ctx, pc) if err != nil { return 0, err } - tracks, err := gatherIncomingTracks(s.ctx, pc, trackRecv) + s.mutex.Lock() + s.pc = pc + s.mutex.Unlock() + + tracks, err := gatherIncomingTracks(s.ctx, pc, trackRecv, trackCount) if err != nil { return 0, err } @@ -406,7 +465,8 @@ func (s *webRTCSession) runRead() (int, error) { } } - offer := s.buildOffer() + offer := s.offer() + err = pc.SetRemoteDescription(*offer) if err != nil { return http.StatusBadRequest, err @@ -434,11 +494,15 @@ func (s *webRTCSession) runRead() (int, error) { go s.readRemoteCandidates(pc) - err = s.waitUntilConnected(pc) + err = waitUntilConnected(s.ctx, pc) if err != nil { return 0, err } + s.mutex.Lock() + s.pc = pc + s.mutex.Unlock() + ringBuffer, _ := ringbuffer.New(uint64(s.readBufferCount)) defer ringBuffer.Close() @@ -475,7 +539,7 @@ func (s *webRTCSession) runRead() (int, error) { } } -func (s *webRTCSession) buildOffer() *webrtc.SessionDescription { +func (s *webRTCSession) offer() *webrtc.SessionDescription { return &webrtc.SessionDescription{ Type: webrtc.SDPTypeOffer, SDP: string(s.req.offer), @@ -502,31 +566,6 @@ func (s *webRTCSession) writeAnswer(answer *webrtc.SessionDescription) { s.answerSent = true } -func (s *webRTCSession) waitUntilConnected(pc *peerConnection) error { - t := time.NewTimer(webrtcHandshakeTimeout) - defer t.Stop() - -outer: - for { - select { - case <-t.C: - return fmt.Errorf("deadline exceeded") - - case <-pc.connected: - break outer - - case <-s.ctx.Done(): - return fmt.Errorf("terminated") - } - } - - s.mutex.Lock() - s.pc = pc - s.mutex.Unlock() - - return nil -} - func (s *webRTCSession) readRemoteCandidates(pc *peerConnection) { for { select {