remove context from webrtc.PeerConnection arguments (#4854)

Contexts are redundant, since PeerConnection.Close() already provides a way to interrupt any pending operation.
Alessandro Ros 2025-08-12 15:19:59 +02:00 committed by GitHub
parent 5ae934887d
commit b627128d0f
GPG key ID: B5690EEEBB952194
8 changed files with 157 additions and 100 deletions
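
In practice the API change looks like this at call sites; the sketch below is illustrative (pc is a started PeerConnection, ctx is whatever lifetime the caller already manages):

// Before this commit:
//   answer, err := pc.CreateFullAnswer(ctx, offer)
//   err = pc.WaitUntilConnected(ctx)
//   err = pc.GatherIncomingTracks(ctx)
//   err = client.Wait(ctx)
//
// After this commit the methods take no context; to abort them early,
// close the peer connection from another goroutine instead:
go func() {
	<-ctx.Done()
	pc.Close() // unblocks CreateFullAnswer, WaitUntilConnected, GatherIncomingTracks
}()
// (real code also stops this goroutine on the success path; see the
// terminator pattern added to the WebRTC session further down.)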

View file

@ -148,13 +148,6 @@ type PeerConnection struct {
Log logger.Writer
wr *webrtc.PeerConnection
stateChangeMutex sync.Mutex
newLocalCandidate chan *webrtc.ICECandidateInit
connected chan struct{}
failed chan struct{}
closed chan struct{}
gatheringDone chan struct{}
incomingTrack chan trackRecvPair
ctx context.Context
ctxCancel context.CancelFunc
incomingTracks []*IncomingTrack
@ -163,6 +156,15 @@ type PeerConnection struct {
rtpPacketsSent *uint64
rtpPacketsLost *uint64
statsInterceptor *statsInterceptor
newLocalCandidate chan *webrtc.ICECandidateInit
incomingTrack chan trackRecvPair
connected chan struct{}
failed chan struct{}
closed chan struct{}
gatheringDone chan struct{}
done chan struct{}
chStartReading chan struct{}
}
// Start starts the peer connection.
@ -289,13 +291,6 @@ func (co *PeerConnection) Start() error {
return err
}
co.newLocalCandidate = make(chan *webrtc.ICECandidateInit)
co.connected = make(chan struct{})
co.failed = make(chan struct{})
co.closed = make(chan struct{})
co.gatheringDone = make(chan struct{})
co.incomingTrack = make(chan trackRecvPair)
co.ctx, co.ctxCancel = context.WithCancel(context.Background())
co.startedReading = new(int64)
@ -303,6 +298,15 @@ func (co *PeerConnection) Start() error {
co.rtpPacketsSent = new(uint64)
co.rtpPacketsLost = new(uint64)
co.newLocalCandidate = make(chan *webrtc.ICECandidateInit)
co.connected = make(chan struct{})
co.failed = make(chan struct{})
co.closed = make(chan struct{})
co.gatheringDone = make(chan struct{})
co.incomingTrack = make(chan trackRecvPair)
co.done = make(chan struct{})
co.chStartReading = make(chan struct{})
if co.Publish {
for _, tr := range co.OutgoingTracks {
err = tr.setup(co)
@ -336,9 +340,11 @@ func (co *PeerConnection) Start() error {
})
}
var stateChangeMutex sync.Mutex
co.wr.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
co.stateChangeMutex.Lock()
defer co.stateChangeMutex.Unlock()
stateChangeMutex.Lock()
defer stateChangeMutex.Unlock()
select {
case <-co.closed:
@ -395,26 +401,49 @@ func (co *PeerConnection) Start() error {
}
})
go co.run()
return nil
}
// Close closes the connection.
func (co *PeerConnection) Close() {
for _, track := range co.incomingTracks {
track.close()
}
for _, track := range co.OutgoingTracks {
track.close()
}
co.ctxCancel()
co.wr.GracefulClose() //nolint:errcheck
<-co.done
}
// even though GracefulClose() is supposed to wait for every goroutine to return,
// we still have to wait for OnConnectionStateChange to return,
// since it is executed in an uncontrolled goroutine.
// https://github.com/pion/webrtc/blob/4742d1fd54abbc3f81c3b56013654574ba7254f3/peerconnection.go#L509
<-co.closed
func (co *PeerConnection) run() {
defer close(co.done)
defer func() {
for _, track := range co.incomingTracks {
track.close()
}
for _, track := range co.OutgoingTracks {
track.close()
}
co.wr.GracefulClose() //nolint:errcheck
// even though GracefulClose() is supposed to wait for every goroutine to return,
// we still have to wait for OnConnectionStateChange to return,
// since it is executed in an uncontrolled goroutine.
// https://github.com/pion/webrtc/blob/4742d1fd54abbc3f81c3b56013654574ba7254f3/peerconnection.go#L509
<-co.closed
}()
for {
select {
case <-co.chStartReading:
for _, track := range co.incomingTracks {
track.start()
}
atomic.StoreInt64(co.startedReading, 1)
case <-co.ctx.Done():
return
}
}
}
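
The peer connection now owns a single run() goroutine: Close() cancels the internal context and waits on done, while StartReading() becomes a message on chStartReading that run() serves (or that is dropped once the context is cancelled). A distilled, self-contained illustration of that pattern follows; worker, cmd and Do are invented names standing in for PeerConnection, chStartReading and StartReading, not the real API:

package main

import (
	"context"
	"fmt"
)

type worker struct {
	ctx       context.Context
	ctxCancel context.CancelFunc
	cmd       chan string   // plays the role of chStartReading
	done      chan struct{} // plays the role of co.done
}

func newWorker() *worker {
	w := &worker{
		cmd:  make(chan string),
		done: make(chan struct{}),
	}
	w.ctx, w.ctxCancel = context.WithCancel(context.Background())
	go w.run()
	return w
}

func (w *worker) run() {
	defer close(w.done) // signals Close() that cleanup has finished
	for {
		select {
		case c := <-w.cmd:
			fmt.Println("handling:", c)
		case <-w.ctx.Done():
			fmt.Println("cleaning up")
			return
		}
	}
}

// Do never blocks forever: once Close() cancels the context, the send is skipped.
func (w *worker) Do(c string) {
	select {
	case w.cmd <- c:
	case <-w.ctx.Done():
	}
}

func (w *worker) Close() {
	w.ctxCancel()
	<-w.done
}

func main() {
	w := newWorker()
	w.Do("start reading")
	w.Close()
}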
func (co *PeerConnection) removeUnwantedCandidates(firstMedia *sdp.MediaDescription) error {
@ -563,10 +592,7 @@ func (co *PeerConnection) AddRemoteCandidate(candidate *webrtc.ICECandidateInit)
}
// CreateFullAnswer creates a full answer.
func (co *PeerConnection) CreateFullAnswer(
ctx context.Context,
offer *webrtc.SessionDescription,
) (*webrtc.SessionDescription, error) {
func (co *PeerConnection) CreateFullAnswer(offer *webrtc.SessionDescription) (*webrtc.SessionDescription, error) {
err := co.wr.SetRemoteDescription(*offer)
if err != nil {
return nil, err
@ -586,7 +612,7 @@ func (co *PeerConnection) CreateFullAnswer(
return nil, err
}
err = co.waitGatheringDone(ctx)
err = co.waitGatheringDone()
if err != nil {
return nil, err
}
@ -601,22 +627,20 @@ func (co *PeerConnection) CreateFullAnswer(
return answer, nil
}
func (co *PeerConnection) waitGatheringDone(ctx context.Context) error {
func (co *PeerConnection) waitGatheringDone() error {
for {
select {
case <-co.NewLocalCandidate():
case <-co.GatheringDone():
return nil
case <-ctx.Done():
case <-co.ctx.Done():
return fmt.Errorf("terminated")
}
}
}
// WaitUntilConnected waits until connection is established.
func (co *PeerConnection) WaitUntilConnected(
ctx context.Context,
) error {
func (co *PeerConnection) WaitUntilConnected() error {
t := time.NewTimer(time.Duration(co.HandshakeTimeout))
defer t.Stop()
@ -629,7 +653,7 @@ outer:
case <-co.connected:
break outer
case <-ctx.Done():
case <-co.ctx.Done():
return fmt.Errorf("terminated")
}
}
@ -638,7 +662,7 @@ outer:
}
// GatherIncomingTracks gathers incoming tracks.
func (co *PeerConnection) GatherIncomingTracks(ctx context.Context) error {
func (co *PeerConnection) GatherIncomingTracks() error {
var sdp sdp.SessionDescription
sdp.Unmarshal([]byte(co.wr.RemoteDescription().SDP)) //nolint:errcheck
@ -675,7 +699,7 @@ func (co *PeerConnection) GatherIncomingTracks(ctx context.Context) error {
case <-co.Failed():
return fmt.Errorf("peer connection closed")
case <-ctx.Done():
case <-co.ctx.Done():
return fmt.Errorf("terminated")
}
}
@ -706,12 +730,12 @@ func (co *PeerConnection) IncomingTracks() []*IncomingTrack {
return co.incomingTracks
}
// StartReading starts reading all incoming tracks.
// StartReading starts reading incoming tracks.
func (co *PeerConnection) StartReading() {
for _, track := range co.incomingTracks {
track.start()
select {
case co.chStartReading <- struct{}{}:
case <-co.ctx.Done():
}
atomic.StoreInt64(co.startedReading, 1)
}
// LocalCandidate returns the local candidate.
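
Putting the new signatures in this file together, an answer-side flow now looks roughly like this sketch; pc is assumed to be an already-populated *PeerConnection, offer comes from signaling, and sendAnswer is a hypothetical signaling step:

if err := pc.Start(); err != nil {
	return err
}
defer pc.Close() // also the only way to interrupt the calls below

answer, err := pc.CreateFullAnswer(offer) // sets descriptions and waits for gathering
if err != nil {
	return err
}
sendAnswer(answer)

if err := pc.WaitUntilConnected(); err != nil { // bounded by HandshakeTimeout
	return err
}

// when acting as a reader:
if err := pc.GatherIncomingTracks(); err != nil {
	return err
}
// handlers on pc.IncomingTracks() would be attached here, then:
pc.StartReading()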

View file

@ -1,7 +1,6 @@
package webrtc
import (
"context"
"net"
"regexp"
"sort"
@ -131,7 +130,7 @@ func TestPeerConnectionCandidates(t *testing.T) {
require.NoError(t, err)
defer pc.Close()
answer, err := pc.CreateFullAnswer(context.Background(), &offer)
answer, err := pc.CreateFullAnswer(&offer)
require.NoError(t, err)
n := len(regexp.MustCompile("(?m)^a=candidate:.+? udp .+? typ host").FindAllString(answer.SDP, -1))
@ -250,7 +249,7 @@ func TestPeerConnectionConnectivity(t *testing.T) {
offer, err := clientPC.CreatePartialOffer()
require.NoError(t, err)
answer, err := serverPC.CreateFullAnswer(context.Background(), offer)
answer, err := serverPC.CreateFullAnswer(offer)
require.NoError(t, err)
require.Equal(t, 2, strings.Count(answer.SDP, "a=candidate:"))
@ -271,7 +270,7 @@ func TestPeerConnectionConnectivity(t *testing.T) {
}
}()
err = serverPC.WaitUntilConnected(context.Background())
err = serverPC.WaitUntilConnected()
require.NoError(t, err)
})
}
@ -325,13 +324,13 @@ func TestPeerConnectionRead(t *testing.T) {
err = pub.SetLocalDescription(offer)
require.NoError(t, err)
answer, err := reader.CreateFullAnswer(context.Background(), &offer)
answer, err := reader.CreateFullAnswer(&offer)
require.NoError(t, err)
err = pub.SetRemoteDescription(*answer)
require.NoError(t, err)
err = reader.WaitUntilConnected(context.Background())
err = reader.WaitUntilConnected()
require.NoError(t, err)
go func() {
@ -364,7 +363,7 @@ func TestPeerConnectionRead(t *testing.T) {
require.NoError(t, err2)
}()
err = reader.GatherIncomingTracks(context.Background())
err = reader.GatherIncomingTracks()
require.NoError(t, err)
codecs := gatherCodecs(reader.IncomingTracks())
@ -470,16 +469,16 @@ func TestPeerConnectionPublishRead(t *testing.T) {
offer, err := pc1.CreatePartialOffer()
require.NoError(t, err)
answer, err := pc2.CreateFullAnswer(context.Background(), offer)
answer, err := pc2.CreateFullAnswer(offer)
require.NoError(t, err)
err = pc1.SetAnswer(answer)
require.NoError(t, err)
err = pc1.WaitUntilConnected(context.Background())
err = pc1.WaitUntilConnected()
require.NoError(t, err)
err = pc2.WaitUntilConnected(context.Background())
err = pc2.WaitUntilConnected()
require.NoError(t, err)
for _, track := range pc2.OutgoingTracks {
@ -497,7 +496,7 @@ func TestPeerConnectionPublishRead(t *testing.T) {
require.NoError(t, err)
}
err = pc1.GatherIncomingTracks(context.Background())
err = pc1.GatherIncomingTracks()
require.NoError(t, err)
codecs := gatherCodecs(pc1.IncomingTracks())
@ -564,7 +563,7 @@ func TestPeerConnectionFallbackCodecs(t *testing.T) {
offer, err := pc1.CreatePartialOffer()
require.NoError(t, err)
answer, err := pc2.CreateFullAnswer(context.Background(), offer)
answer, err := pc2.CreateFullAnswer(offer)
require.NoError(t, err)
var s sdp.SessionDescription

View file

@ -1,7 +1,6 @@
package webrtc
import (
"context"
"testing"
"time"
@ -365,7 +364,7 @@ func TestToStream(t *testing.T) {
offer, err := pc1.CreatePartialOffer()
require.NoError(t, err)
answer, err := pc2.CreateFullAnswer(context.Background(), offer)
answer, err := pc2.CreateFullAnswer(offer)
require.NoError(t, err)
err = pc1.SetAnswer(answer)
@ -384,10 +383,10 @@ func TestToStream(t *testing.T) {
}
}()
err = pc1.WaitUntilConnected(context.Background())
err = pc1.WaitUntilConnected()
require.NoError(t, err)
err = pc2.WaitUntilConnected(context.Background())
err = pc2.WaitUntilConnected()
require.NoError(t, err)
err = pc1.OutgoingTracks[0].WriteRTP(&rtp.Packet{
@ -403,7 +402,7 @@ func TestToStream(t *testing.T) {
})
require.NoError(t, err)
err = pc2.GatherIncomingTracks(context.Background())
err = pc2.GatherIncomingTracks()
require.NoError(t, err)
var stream *stream.Stream

View file

@ -55,27 +55,47 @@ func (c *Client) Initialize(ctx context.Context) error {
UseAbsoluteTimestamp: c.UseAbsoluteTimestamp,
Log: c.Log,
}
err = c.pc.Start()
if err != nil {
return err
}
offer, err := c.pc.CreatePartialOffer()
initializeRes := make(chan error)
go func() {
initializeRes <- c.initializeInner(ctx)
}()
select {
case <-ctx.Done():
c.pc.Close()
<-initializeRes
return fmt.Errorf("terminated")
case err = <-initializeRes:
}
if err != nil {
c.pc.Close()
return err
}
return nil
}
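
Initialize() keeps its ctx parameter, but cancellation is now mapped onto pc.Close(): the real work runs in initializeInner(), and the select above either collects its result or closes the connection and drains the channel. The same wrapping works for any of the now context-free calls; a self-contained sketch with invented names (runWithContext, op, closeFn):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// runWithContext makes a context-free blocking call cancellable, the way
// Initialize() wraps initializeInner(): run the work in a goroutine, and on
// ctx cancellation call closeFn (standing in for pc.Close()), then drain the
// result channel so the goroutine is not leaked.
func runWithContext(ctx context.Context, op func() error, closeFn func()) error {
	res := make(chan error, 1)
	go func() { res <- op() }()

	select {
	case <-ctx.Done():
		closeFn() // unblocks op, as pc.Close() unblocks the WebRTC calls
		<-res     // wait for the goroutine to finish
		return errors.New("terminated")
	case err := <-res:
		return err
	}
}

func main() {
	stop := make(chan struct{})
	op := func() error { <-stop; return nil } // stand-in blocking operation
	closeFn := func() { close(stop) }         // stand-in for Close()

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	fmt.Println(runWithContext(ctx, op, closeFn)) // prints "terminated"
}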
func (c *Client) initializeInner(ctx context.Context) error {
offer, err := c.pc.CreatePartialOffer()
if err != nil {
return err
}
res, err := c.postOffer(ctx, offer)
if err != nil {
c.pc.Close()
return err
}
c.URL, err = c.URL.Parse(res.Location)
if err != nil {
c.pc.Close()
return err
}
@ -84,14 +104,12 @@ func (c *Client) Initialize(ctx context.Context) error {
err = sdp.Unmarshal([]byte(res.Answer.SDP))
if err != nil {
c.deleteSession(context.Background()) //nolint:errcheck
c.pc.Close()
return err
}
err = webrtc.TracksAreValid(sdp.MediaDescriptions)
if err != nil {
c.deleteSession(context.Background()) //nolint:errcheck
c.pc.Close()
return err
}
}
@ -99,7 +117,6 @@ func (c *Client) Initialize(ctx context.Context) error {
err = c.pc.SetAnswer(res.Answer)
if err != nil {
c.deleteSession(context.Background()) //nolint:errcheck
c.pc.Close()
return err
}
@ -113,7 +130,6 @@ outer:
err = c.patchCandidate(ctx, offer, res.ETag, ca)
if err != nil {
c.deleteSession(context.Background()) //nolint:errcheck
c.pc.Close()
return err
}
@ -124,16 +140,14 @@ outer:
case <-t.C:
c.deleteSession(context.Background()) //nolint:errcheck
c.pc.Close()
return fmt.Errorf("deadline exceeded while waiting connection")
}
}
if !c.Publish {
err = c.pc.GatherIncomingTracks(ctx)
err = c.pc.GatherIncomingTracks()
if err != nil {
c.deleteSession(context.Background()) //nolint:errcheck
c.pc.Close()
return err
}
}
@ -163,15 +177,10 @@ func (c *Client) Close() error {
return err
}
// Wait waits for client errors.
func (c *Client) Wait(ctx context.Context) error {
select {
case <-c.pc.Failed():
return fmt.Errorf("peer connection closed")
case <-ctx.Done():
return fmt.Errorf("terminated")
}
// Wait waits until a fatal error occurs.
func (c *Client) Wait() error {
<-c.pc.Failed()
return fmt.Errorf("peer connection closed")
}
func (c *Client) optionsICEServers(

View file

@ -105,7 +105,7 @@ func TestClientRead(t *testing.T) {
require.NoError(t, err2)
offer := whipOffer(body)
answer, err2 := pc.CreateFullAnswer(context.Background(), offer)
answer, err2 := pc.CreateFullAnswer(offer)
require.NoError(t, err2)
w.Header().Set("Content-Type", "application/sdp")
@ -116,7 +116,7 @@ func TestClientRead(t *testing.T) {
w.Write([]byte(answer.SDP))
go func() {
err3 := pc.WaitUntilConnected(context.Background())
err3 := pc.WaitUntilConnected()
require.NoError(t, err3)
for _, track := range outgoingTracks {
@ -277,7 +277,7 @@ func TestClientPublish(t *testing.T) {
require.NoError(t, err2)
offer := whipOffer(body)
answer, err2 := pc.CreateFullAnswer(context.Background(), offer)
answer, err2 := pc.CreateFullAnswer(offer)
require.NoError(t, err2)
w.Header().Set("Content-Type", "application/sdp")
@ -288,10 +288,10 @@ func TestClientPublish(t *testing.T) {
w.Write([]byte(answer.SDP))
go func() {
err3 := pc.WaitUntilConnected(context.Background())
err3 := pc.WaitUntilConnected()
require.NoError(t, err3)
err3 = pc.GatherIncomingTracks(context.Background())
err3 = pc.GatherIncomingTracks()
require.NoError(t, err3)
codecs := gatherCodecs(pc.IncomingTracks())

View file

@ -175,7 +175,21 @@ func (s *session) runPublish() (int, error) {
if err != nil {
return http.StatusBadRequest, err
}
defer pc.Close()
terminatorDone := make(chan struct{})
defer func() { <-terminatorDone }()
terminatorRun := make(chan struct{})
defer close(terminatorRun)
go func() {
defer close(terminatorDone)
select {
case <-s.ctx.Done():
case <-terminatorRun:
}
pc.Close()
}()
offer := whipOffer(s.req.offer)
@ -195,7 +209,7 @@ func (s *session) runPublish() (int, error) {
return http.StatusNotAcceptable, err
}
answer, err := pc.CreateFullAnswer(s.ctx, offer)
answer, err := pc.CreateFullAnswer(offer)
if err != nil {
return http.StatusBadRequest, err
}
@ -204,7 +218,7 @@ func (s *session) runPublish() (int, error) {
go s.readRemoteCandidates(pc)
err = pc.WaitUntilConnected(s.ctx)
err = pc.WaitUntilConnected()
if err != nil {
return 0, err
}
@ -213,7 +227,7 @@ func (s *session) runPublish() (int, error) {
s.pc = pc
s.mutex.Unlock()
err = pc.GatherIncomingTracks(s.ctx)
err = pc.GatherIncomingTracks()
if err != nil {
return 0, err
}
@ -312,11 +326,25 @@ func (s *session) runRead() (int, error) {
stream.RemoveReader(s)
return http.StatusBadRequest, err
}
defer pc.Close()
terminatorDone := make(chan struct{})
defer func() { <-terminatorDone }()
terminatorRun := make(chan struct{})
defer close(terminatorRun)
go func() {
defer close(terminatorDone)
select {
case <-s.ctx.Done():
case <-terminatorRun:
}
pc.Close()
}()
offer := whipOffer(s.req.offer)
answer, err := pc.CreateFullAnswer(s.ctx, offer)
answer, err := pc.CreateFullAnswer(offer)
if err != nil {
stream.RemoveReader(s)
return http.StatusBadRequest, err
@ -326,7 +354,7 @@ func (s *session) runRead() (int, error) {
go s.readRemoteCandidates(pc)
err = pc.WaitUntilConnected(s.ctx)
err = pc.WaitUntilConnected()
if err != nil {
stream.RemoveReader(s)
return 0, err
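
The terminator blocks added to runPublish() and runRead() replace the old defer pc.Close(): a helper goroutine closes the peer connection either when the session context is cancelled or when the handler returns. Because deferred calls run last-in-first-out, close(terminatorRun) fires before the outer defer blocks on <-terminatorDone, so the goroutine always exits and Close() runs exactly once. A compact restatement of the same shape, with generic names that are not part of the commit:

// withTerminator runs body on the calling goroutine and guarantees that
// closeFn is called exactly once, either when ctx is cancelled or when
// body returns, before withTerminator itself returns.
func withTerminator(ctx context.Context, closeFn func(), body func() error) error {
	done := make(chan struct{})
	defer func() { <-done }() // executed last: wait until closeFn has run
	run := make(chan struct{})
	defer close(run) // executed first on return: releases the goroutine

	go func() {
		defer close(done)
		select {
		case <-ctx.Done(): // external shutdown
		case <-run:        // normal return from body
		}
		closeFn()
	}()

	return body()
}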

View file

@ -2,7 +2,6 @@
package webrtc
import (
"context"
"fmt"
"net/http"
"net/url"
@ -62,7 +61,6 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
URL: u,
Log: s,
}
err = client.Initialize(params.Context)
if err != nil {
return err
@ -94,7 +92,7 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error {
readErr := make(chan error)
go func() {
readErr <- client.Wait(context.Background())
readErr <- client.Wait()
}()
for {

View file

@ -71,7 +71,7 @@ func TestSource(t *testing.T) {
require.NoError(t, err2)
offer := whipOffer(body)
answer, err2 := pc.CreateFullAnswer(context.Background(), offer)
answer, err2 := pc.CreateFullAnswer(offer)
require.NoError(t, err2)
w.Header().Set("Content-Type", "application/sdp")
@ -82,7 +82,7 @@ func TestSource(t *testing.T) {
w.Write([]byte(answer.SDP))
go func() {
err3 := pc.WaitUntilConnected(context.Background())
err3 := pc.WaitUntilConnected()
require.NoError(t, err3)
err3 = outgoingTracks[0].WriteRTP(&rtp.Packet{