1
0
Fork 0
forked from External/mediamtx

webrtc: speed up gathering of incoming tracks (#3441)

This commit is contained in:
Alessandro Ros 2024-06-09 22:58:40 +02:00 committed by GitHub
parent eaf47e6598
commit d7bc304e52
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 45 additions and 49 deletions

View file

@ -10,6 +10,7 @@ import (
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
"github.com/pion/interceptor" "github.com/pion/interceptor"
"github.com/pion/sdp/v3"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/conf"
@ -29,6 +30,37 @@ func stringInSlice(a string, list []string) bool {
return false return false
} }
// TracksAreValid checks whether tracks in the SDP are valid
func TracksAreValid(medias []*sdp.MediaDescription) error {
videoTrack := false
audioTrack := false
for _, media := range medias {
switch media.MediaName.Media {
case "video":
if videoTrack {
return fmt.Errorf("only a single video and a single audio track are supported")
}
videoTrack = true
case "audio":
if audioTrack {
return fmt.Errorf("only a single video and a single audio track are supported")
}
audioTrack = true
default:
return fmt.Errorf("unsupported media '%s'", media.MediaName.Media)
}
}
if !videoTrack && !audioTrack {
return fmt.Errorf("no valid tracks count")
}
return nil
}
type trackRecvPair struct { type trackRecvPair struct {
track *webrtc.TrackRemote track *webrtc.TrackRemote
receiver *webrtc.RTPReceiver receiver *webrtc.RTPReceiver
@ -334,10 +366,12 @@ outer:
} }
// GatherIncomingTracks gathers incoming tracks. // GatherIncomingTracks gathers incoming tracks.
func (co *PeerConnection) GatherIncomingTracks( func (co *PeerConnection) GatherIncomingTracks(ctx context.Context) ([]*IncomingTrack, error) {
ctx context.Context, var sdp sdp.SessionDescription
maxCount int, sdp.Unmarshal([]byte(co.wr.RemoteDescription().SDP)) //nolint:errcheck
) ([]*IncomingTrack, error) {
maxTrackCount := len(sdp.MediaDescriptions)
var tracks []*IncomingTrack var tracks []*IncomingTrack
t := time.NewTimer(time.Duration(co.TrackGatherTimeout)) t := time.NewTimer(time.Duration(co.TrackGatherTimeout))
@ -346,7 +380,7 @@ func (co *PeerConnection) GatherIncomingTracks(
for { for {
select { select {
case <-t.C: case <-t.C:
if maxCount == 0 && len(tracks) != 0 { if len(tracks) != 0 {
return tracks, nil return tracks, nil
} }
return nil, fmt.Errorf("deadline exceeded while waiting tracks") return nil, fmt.Errorf("deadline exceeded while waiting tracks")
@ -358,7 +392,7 @@ func (co *PeerConnection) GatherIncomingTracks(
} }
tracks = append(tracks, track) tracks = append(tracks, track)
if len(tracks) == maxCount || len(tracks) >= 2 { if len(tracks) >= maxTrackCount {
return tracks, nil return tracks, nil
} }

View file

@ -284,7 +284,7 @@ func TestPeerConnectionPublishRead(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
inc, err := pc2.GatherIncomingTracks(context.Background(), 1) inc, err := pc2.GatherIncomingTracks(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, ca.out, inc[0].Format()) require.Equal(t, ca.out, inc[0].Format())

View file

@ -1,37 +0,0 @@
package webrtc
import (
"fmt"
"github.com/pion/sdp/v3"
)
// TrackCount returns the track count.
func TrackCount(medias []*sdp.MediaDescription) (int, error) {
videoTrack := false
audioTrack := false
trackCount := 0
for _, media := range medias {
switch media.MediaName.Media {
case "video":
if videoTrack {
return 0, fmt.Errorf("only a single video and a single audio track are supported")
}
videoTrack = true
case "audio":
if audioTrack {
return 0, fmt.Errorf("only a single video and a single audio track are supported")
}
audioTrack = true
default:
return 0, fmt.Errorf("unsupported media '%s'", media.MediaName.Media)
}
trackCount++
}
return trackCount, nil
}

View file

@ -169,8 +169,7 @@ func (c *WHIPClient) Read(ctx context.Context) ([]*IncomingTrack, error) {
return nil, err return nil, err
} }
// check that there are at most two tracks err = TracksAreValid(sdp.MediaDescriptions)
_, err = TrackCount(sdp.MediaDescriptions)
if err != nil { if err != nil {
c.deleteSession(context.Background()) //nolint:errcheck c.deleteSession(context.Background()) //nolint:errcheck
c.pc.Close() c.pc.Close()
@ -210,7 +209,7 @@ outer:
} }
} }
tracks, err := c.pc.GatherIncomingTracks(ctx, 0) tracks, err := c.pc.GatherIncomingTracks(ctx)
if err != nil { if err != nil {
c.deleteSession(context.Background()) //nolint:errcheck c.deleteSession(context.Background()) //nolint:errcheck
c.pc.Close() c.pc.Close()

View file

@ -461,7 +461,7 @@ func (s *session) runPublish() (int, error) {
return http.StatusBadRequest, err return http.StatusBadRequest, err
} }
trackCount, err := webrtc.TrackCount(sdp.MediaDescriptions) err = webrtc.TracksAreValid(sdp.MediaDescriptions)
if err != nil { if err != nil {
// RFC draft-ietf-wish-whip // RFC draft-ietf-wish-whip
// if the number of audio and or video // if the number of audio and or video
@ -489,7 +489,7 @@ func (s *session) runPublish() (int, error) {
s.pc = pc s.pc = pc
s.mutex.Unlock() s.mutex.Unlock()
tracks, err := pc.GatherIncomingTracks(s.ctx, trackCount) tracks, err := pc.GatherIncomingTracks(s.ctx)
if err != nil { if err != nil {
return 0, err return 0, err
} }