From faf8d24dff08902852af4c3b565b3feb96fe5dce Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sat, 24 Apr 2021 20:47:43 +0200 Subject: [PATCH] RTMP source: apply read and write timeouts to connection initialization --- internal/path/path.go | 1 + internal/rtmp/client.go | 32 +++- internal/sourcertmp/source.go | 296 +++++++++++++++++----------------- internal/sourcertsp/source.go | 2 +- 4 files changed, 181 insertions(+), 150 deletions(-) diff --git a/internal/path/path.go b/internal/path/path.go index cd110740..b3f33ffb 100644 --- a/internal/path/path.go +++ b/internal/path/path.go @@ -420,6 +420,7 @@ func (pa *Path) startExternalSource() { pa.source = sourcertmp.New( pa.conf.Source, pa.readTimeout, + pa.writeTimeout, &pa.sourceWg, pa.stats, pa) diff --git a/internal/rtmp/client.go b/internal/rtmp/client.go index beb945c1..15d83762 100644 --- a/internal/rtmp/client.go +++ b/internal/rtmp/client.go @@ -1,18 +1,44 @@ package rtmp import ( + "bufio" + "context" + "net" + "net/url" + "github.com/notedit/rtmp/format/rtmp" ) -// Dial connects to a server in reading mode. -func Dial(address string) (*Conn, error) { - rconn, nconn, err := rtmp.NewClient().Dial(address, rtmp.PrepareReading) +// DialContext connects to a server in reading mode. +func DialContext(ctx context.Context, address string) (*Conn, error) { + // https://github.com/aler9/rtmp/blob/master/format/rtmp/client.go#L74 + + u, err := url.Parse(address) if err != nil { return nil, err } + host := rtmp.UrlGetHost(u) + + var d net.Dialer + nconn, err := d.DialContext(ctx, "tcp", host) + if err != nil { + return nil, err + } + + rw := &bufio.ReadWriter{ + Reader: bufio.NewReaderSize(nconn, 4096), + Writer: bufio.NewWriterSize(nconn, 4096), + } + rconn := rtmp.NewConn(rw) + rconn.URL = u return &Conn{ rconn: rconn, nconn: nconn, }, nil } + +// ClientHandshake performs the handshake of a client-side connection. +func (c *Conn) ClientHandshake() error { + return c.rconn.Prepare(rtmp.StageGotPublishOrPlayCommand, rtmp.PrepareReading) +} diff --git a/internal/sourcertmp/source.go b/internal/sourcertmp/source.go index f57f42f9..fbae4fa9 100644 --- a/internal/sourcertmp/source.go +++ b/internal/sourcertmp/source.go @@ -1,6 +1,7 @@ package sourcertmp import ( + "context" "fmt" "sync" "sync/atomic" @@ -32,11 +33,12 @@ type Parent interface { // Source is a RTMP external source. type Source struct { - ur string - readTimeout time.Duration - wg *sync.WaitGroup - stats *stats.Stats - parent Parent + ur string + readTimeout time.Duration + writeTimeout time.Duration + wg *sync.WaitGroup + stats *stats.Stats + parent Parent // in terminate chan struct{} @@ -45,16 +47,18 @@ type Source struct { // New allocates a Source. func New(ur string, readTimeout time.Duration, + writeTimeout time.Duration, wg *sync.WaitGroup, stats *stats.Stats, parent Parent) *Source { s := &Source{ - ur: ur, - readTimeout: readTimeout, - wg: wg, - stats: stats, - parent: parent, - terminate: make(chan struct{}), + ur: ur, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + wg: wg, + stats: stats, + parent: parent, + terminate: make(chan struct{}), } atomic.AddInt64(s.stats.CountSourcesRtmp, +1) @@ -106,165 +110,165 @@ func (s *Source) run() { } func (s *Source) runInner() bool { - s.log(logger.Info, "connecting") + ctx, cancel := context.WithCancel(context.Background()) - var conn *rtmp.Conn - var err error - dialDone := make(chan struct{}, 1) + done := make(chan error) go func() { - defer close(dialDone) - conn, err = rtmp.Dial(s.ur) - }() + done <- func() error { + s.log(logger.Debug, "connecting") - select { - case <-dialDone: - case <-s.terminate: - return false - } + ctx2, cancel2 := context.WithTimeout(ctx, s.readTimeout) + defer cancel2() - if err != nil { - s.log(logger.Info, "ERR: %s", err) - return true - } - - var videoTrack *gortsplib.Track - var audioTrack *gortsplib.Track - metadataDone := make(chan struct{}) - go func() { - defer close(metadataDone) - conn.NetConn().SetReadDeadline(time.Now().Add(s.readTimeout)) - videoTrack, audioTrack, err = conn.ReadMetadata() - }() - - select { - case <-metadataDone: - case <-s.terminate: - conn.NetConn().Close() - <-metadataDone - return false - } - - if err != nil { - s.log(logger.Info, "ERR: %s", err) - return true - } - - var tracks gortsplib.Tracks - - var h264Encoder *rtph264.Encoder - if videoTrack != nil { - h264Encoder = rtph264.NewEncoder(96, nil, nil, nil) - tracks = append(tracks, videoTrack) - } - - var aacEncoder *rtpaac.Encoder - if audioTrack != nil { - clockRate, _ := audioTrack.ClockRate() - aacEncoder = rtpaac.NewEncoder(96, clockRate, nil, nil, nil) - tracks = append(tracks, audioTrack) - } - - for i, t := range tracks { - t.ID = i - } - - s.log(logger.Info, "ready") - - cres := make(chan source.ExtSetReadyRes) - s.parent.OnExtSourceSetReady(source.ExtSetReadyReq{ - Tracks: tracks, - Res: cres, - }) - res := <-cres - - defer func() { - res := make(chan struct{}) - s.parent.OnExtSourceSetNotReady(source.ExtSetNotReadyReq{ - Res: res, - }) - <-res - }() - - readerDone := make(chan error) - go func() { - readerDone <- func() error { - rtcpSenders := rtcpsenderset.New(tracks, res.SP.OnFrame) - defer rtcpSenders.Close() - - onFrame := func(trackID int, payload []byte) { - rtcpSenders.OnFrame(trackID, gortsplib.StreamTypeRTP, payload) - res.SP.OnFrame(trackID, gortsplib.StreamTypeRTP, payload) + conn, err := rtmp.DialContext(ctx2, s.ur) + if err != nil { + return err } - for { - conn.NetConn().SetReadDeadline(time.Now().Add(s.readTimeout)) - pkt, err := conn.ReadPacket() - if err != nil { - return err - } - - switch pkt.Type { - case av.H264: - if videoTrack == nil { - return fmt.Errorf("ERR: received an H264 frame, but track is not set up") - } - - nalus, err := h264.DecodeAVCC(pkt.Data) + readDone := make(chan error) + go func() { + readDone <- func() error { + conn.NetConn().SetReadDeadline(time.Now().Add(s.readTimeout)) + conn.NetConn().SetWriteDeadline(time.Now().Add(s.writeTimeout)) + err = conn.ClientHandshake() if err != nil { return err } - var outNALUs [][]byte - for _, nalu := range nalus { - // remove SPS, PPS and AUD, not needed by RTSP / RTMP - typ := h264.NALUType(nalu[0] & 0x1F) - switch typ { - case h264.NALUTypeSPS, h264.NALUTypePPS, h264.NALUTypeAccessUnitDelimiter: - continue + conn.NetConn().SetWriteDeadline(time.Time{}) + + conn.NetConn().SetReadDeadline(time.Now().Add(s.readTimeout)) + videoTrack, audioTrack, err := conn.ReadMetadata() + if err != nil { + return err + } + + var tracks gortsplib.Tracks + + var h264Encoder *rtph264.Encoder + if videoTrack != nil { + h264Encoder = rtph264.NewEncoder(96, nil, nil, nil) + tracks = append(tracks, videoTrack) + } + + var aacEncoder *rtpaac.Encoder + if audioTrack != nil { + clockRate, _ := audioTrack.ClockRate() + aacEncoder = rtpaac.NewEncoder(96, clockRate, nil, nil, nil) + tracks = append(tracks, audioTrack) + } + + for i, t := range tracks { + t.ID = i + } + + s.log(logger.Info, "ready") + + cres := make(chan source.ExtSetReadyRes) + s.parent.OnExtSourceSetReady(source.ExtSetReadyReq{ + Tracks: tracks, + Res: cres, + }) + res := <-cres + + defer func() { + res := make(chan struct{}) + s.parent.OnExtSourceSetNotReady(source.ExtSetNotReadyReq{ + Res: res, + }) + <-res + }() + + rtcpSenders := rtcpsenderset.New(tracks, res.SP.OnFrame) + defer rtcpSenders.Close() + + onFrame := func(trackID int, payload []byte) { + rtcpSenders.OnFrame(trackID, gortsplib.StreamTypeRTP, payload) + res.SP.OnFrame(trackID, gortsplib.StreamTypeRTP, payload) + } + + for { + conn.NetConn().SetReadDeadline(time.Now().Add(s.readTimeout)) + pkt, err := conn.ReadPacket() + if err != nil { + return err } - outNALUs = append(outNALUs, nalu) - } + switch pkt.Type { + case av.H264: + if videoTrack == nil { + return fmt.Errorf("ERR: received an H264 frame, but track is not set up") + } - pkts, err := h264Encoder.Encode(outNALUs, pkt.Time+pkt.CTime) - if err != nil { - return fmt.Errorf("ERR while encoding H264: %v", err) - } + nalus, err := h264.DecodeAVCC(pkt.Data) + if err != nil { + return err + } - for _, pkt := range pkts { - onFrame(videoTrack.ID, pkt) - } + var outNALUs [][]byte + for _, nalu := range nalus { + // remove SPS, PPS and AUD, not needed by RTSP / RTMP + typ := h264.NALUType(nalu[0] & 0x1F) + switch typ { + case h264.NALUTypeSPS, h264.NALUTypePPS, h264.NALUTypeAccessUnitDelimiter: + continue + } - case av.AAC: - if audioTrack == nil { - return fmt.Errorf("ERR: received an AAC frame, but track is not set up") - } + outNALUs = append(outNALUs, nalu) + } - pkts, err := aacEncoder.Encode([][]byte{pkt.Data}, pkt.Time+pkt.CTime) - if err != nil { - return fmt.Errorf("ERR while encoding AAC: %v", err) - } + pkts, err := h264Encoder.Encode(outNALUs, pkt.Time+pkt.CTime) + if err != nil { + return fmt.Errorf("ERR while encoding H264: %v", err) + } - for _, pkt := range pkts { - onFrame(audioTrack.ID, pkt) - } + for _, pkt := range pkts { + onFrame(videoTrack.ID, pkt) + } - default: - return fmt.Errorf("ERR: unexpected packet: %v", pkt.Type) - } + case av.AAC: + if audioTrack == nil { + return fmt.Errorf("ERR: received an AAC frame, but track is not set up") + } + + pkts, err := aacEncoder.Encode([][]byte{pkt.Data}, pkt.Time+pkt.CTime) + if err != nil { + return fmt.Errorf("ERR while encoding AAC: %v", err) + } + + for _, pkt := range pkts { + onFrame(audioTrack.ID, pkt) + } + + default: + return fmt.Errorf("ERR: unexpected packet: %v", pkt.Type) + } + } + }() + }() + + select { + case err := <-readDone: + conn.NetConn().Close() + return err + + case <-ctx.Done(): + conn.NetConn().Close() + <-readDone + return nil } }() }() select { - case <-s.terminate: - conn.NetConn().Close() - <-readerDone - return false - - case err := <-readerDone: + case err := <-done: + cancel() s.log(logger.Info, "ERR: %s", err) - conn.NetConn().Close() return true + + case <-s.terminate: + cancel() + <-done + return false } } diff --git a/internal/sourcertsp/source.go b/internal/sourcertsp/source.go index d27fadc4..742590a6 100644 --- a/internal/sourcertsp/source.go +++ b/internal/sourcertsp/source.go @@ -121,7 +121,7 @@ func (s *Source) run() { } func (s *Source) runInner() bool { - s.log(logger.Info, "connecting") + s.log(logger.Debug, "connecting") var conn *gortsplib.ClientConn var err error