RTMP source: apply read and write timeouts to connection initialization

This commit is contained in:
aler9 2021-04-24 20:47:43 +02:00
parent a1a56ff203
commit faf8d24dff
4 changed files with 181 additions and 150 deletions

View file

@ -420,6 +420,7 @@ func (pa *Path) startExternalSource() {
pa.source = sourcertmp.New(
pa.conf.Source,
pa.readTimeout,
pa.writeTimeout,
&pa.sourceWg,
pa.stats,
pa)

View file

@ -1,18 +1,44 @@
package rtmp
import (
"bufio"
"context"
"net"
"net/url"
"github.com/notedit/rtmp/format/rtmp"
)
// Dial connects to a server in reading mode.
func Dial(address string) (*Conn, error) {
rconn, nconn, err := rtmp.NewClient().Dial(address, rtmp.PrepareReading)
// DialContext connects to a server in reading mode.
func DialContext(ctx context.Context, address string) (*Conn, error) {
// https://github.com/aler9/rtmp/blob/master/format/rtmp/client.go#L74
u, err := url.Parse(address)
if err != nil {
return nil, err
}
host := rtmp.UrlGetHost(u)
var d net.Dialer
nconn, err := d.DialContext(ctx, "tcp", host)
if err != nil {
return nil, err
}
rw := &bufio.ReadWriter{
Reader: bufio.NewReaderSize(nconn, 4096),
Writer: bufio.NewWriterSize(nconn, 4096),
}
rconn := rtmp.NewConn(rw)
rconn.URL = u
return &Conn{
rconn: rconn,
nconn: nconn,
}, nil
}
// ClientHandshake performs the handshake of a client-side connection.
func (c *Conn) ClientHandshake() error {
return c.rconn.Prepare(rtmp.StageGotPublishOrPlayCommand, rtmp.PrepareReading)
}

View file

@ -1,6 +1,7 @@
package sourcertmp
import (
"context"
"fmt"
"sync"
"sync/atomic"
@ -32,11 +33,12 @@ type Parent interface {
// Source is a RTMP external source.
type Source struct {
ur string
readTimeout time.Duration
wg *sync.WaitGroup
stats *stats.Stats
parent Parent
ur string
readTimeout time.Duration
writeTimeout time.Duration
wg *sync.WaitGroup
stats *stats.Stats
parent Parent
// in
terminate chan struct{}
@ -45,16 +47,18 @@ type Source struct {
// New allocates a Source.
func New(ur string,
readTimeout time.Duration,
writeTimeout time.Duration,
wg *sync.WaitGroup,
stats *stats.Stats,
parent Parent) *Source {
s := &Source{
ur: ur,
readTimeout: readTimeout,
wg: wg,
stats: stats,
parent: parent,
terminate: make(chan struct{}),
ur: ur,
readTimeout: readTimeout,
writeTimeout: writeTimeout,
wg: wg,
stats: stats,
parent: parent,
terminate: make(chan struct{}),
}
atomic.AddInt64(s.stats.CountSourcesRtmp, +1)
@ -106,165 +110,165 @@ func (s *Source) run() {
}
func (s *Source) runInner() bool {
s.log(logger.Info, "connecting")
ctx, cancel := context.WithCancel(context.Background())
var conn *rtmp.Conn
var err error
dialDone := make(chan struct{}, 1)
done := make(chan error)
go func() {
defer close(dialDone)
conn, err = rtmp.Dial(s.ur)
}()
done <- func() error {
s.log(logger.Debug, "connecting")
select {
case <-dialDone:
case <-s.terminate:
return false
}
ctx2, cancel2 := context.WithTimeout(ctx, s.readTimeout)
defer cancel2()
if err != nil {
s.log(logger.Info, "ERR: %s", err)
return true
}
var videoTrack *gortsplib.Track
var audioTrack *gortsplib.Track
metadataDone := make(chan struct{})
go func() {
defer close(metadataDone)
conn.NetConn().SetReadDeadline(time.Now().Add(s.readTimeout))
videoTrack, audioTrack, err = conn.ReadMetadata()
}()
select {
case <-metadataDone:
case <-s.terminate:
conn.NetConn().Close()
<-metadataDone
return false
}
if err != nil {
s.log(logger.Info, "ERR: %s", err)
return true
}
var tracks gortsplib.Tracks
var h264Encoder *rtph264.Encoder
if videoTrack != nil {
h264Encoder = rtph264.NewEncoder(96, nil, nil, nil)
tracks = append(tracks, videoTrack)
}
var aacEncoder *rtpaac.Encoder
if audioTrack != nil {
clockRate, _ := audioTrack.ClockRate()
aacEncoder = rtpaac.NewEncoder(96, clockRate, nil, nil, nil)
tracks = append(tracks, audioTrack)
}
for i, t := range tracks {
t.ID = i
}
s.log(logger.Info, "ready")
cres := make(chan source.ExtSetReadyRes)
s.parent.OnExtSourceSetReady(source.ExtSetReadyReq{
Tracks: tracks,
Res: cres,
})
res := <-cres
defer func() {
res := make(chan struct{})
s.parent.OnExtSourceSetNotReady(source.ExtSetNotReadyReq{
Res: res,
})
<-res
}()
readerDone := make(chan error)
go func() {
readerDone <- func() error {
rtcpSenders := rtcpsenderset.New(tracks, res.SP.OnFrame)
defer rtcpSenders.Close()
onFrame := func(trackID int, payload []byte) {
rtcpSenders.OnFrame(trackID, gortsplib.StreamTypeRTP, payload)
res.SP.OnFrame(trackID, gortsplib.StreamTypeRTP, payload)
conn, err := rtmp.DialContext(ctx2, s.ur)
if err != nil {
return err
}
for {
conn.NetConn().SetReadDeadline(time.Now().Add(s.readTimeout))
pkt, err := conn.ReadPacket()
if err != nil {
return err
}
switch pkt.Type {
case av.H264:
if videoTrack == nil {
return fmt.Errorf("ERR: received an H264 frame, but track is not set up")
}
nalus, err := h264.DecodeAVCC(pkt.Data)
readDone := make(chan error)
go func() {
readDone <- func() error {
conn.NetConn().SetReadDeadline(time.Now().Add(s.readTimeout))
conn.NetConn().SetWriteDeadline(time.Now().Add(s.writeTimeout))
err = conn.ClientHandshake()
if err != nil {
return err
}
var outNALUs [][]byte
for _, nalu := range nalus {
// remove SPS, PPS and AUD, not needed by RTSP / RTMP
typ := h264.NALUType(nalu[0] & 0x1F)
switch typ {
case h264.NALUTypeSPS, h264.NALUTypePPS, h264.NALUTypeAccessUnitDelimiter:
continue
conn.NetConn().SetWriteDeadline(time.Time{})
conn.NetConn().SetReadDeadline(time.Now().Add(s.readTimeout))
videoTrack, audioTrack, err := conn.ReadMetadata()
if err != nil {
return err
}
var tracks gortsplib.Tracks
var h264Encoder *rtph264.Encoder
if videoTrack != nil {
h264Encoder = rtph264.NewEncoder(96, nil, nil, nil)
tracks = append(tracks, videoTrack)
}
var aacEncoder *rtpaac.Encoder
if audioTrack != nil {
clockRate, _ := audioTrack.ClockRate()
aacEncoder = rtpaac.NewEncoder(96, clockRate, nil, nil, nil)
tracks = append(tracks, audioTrack)
}
for i, t := range tracks {
t.ID = i
}
s.log(logger.Info, "ready")
cres := make(chan source.ExtSetReadyRes)
s.parent.OnExtSourceSetReady(source.ExtSetReadyReq{
Tracks: tracks,
Res: cres,
})
res := <-cres
defer func() {
res := make(chan struct{})
s.parent.OnExtSourceSetNotReady(source.ExtSetNotReadyReq{
Res: res,
})
<-res
}()
rtcpSenders := rtcpsenderset.New(tracks, res.SP.OnFrame)
defer rtcpSenders.Close()
onFrame := func(trackID int, payload []byte) {
rtcpSenders.OnFrame(trackID, gortsplib.StreamTypeRTP, payload)
res.SP.OnFrame(trackID, gortsplib.StreamTypeRTP, payload)
}
for {
conn.NetConn().SetReadDeadline(time.Now().Add(s.readTimeout))
pkt, err := conn.ReadPacket()
if err != nil {
return err
}
outNALUs = append(outNALUs, nalu)
}
switch pkt.Type {
case av.H264:
if videoTrack == nil {
return fmt.Errorf("ERR: received an H264 frame, but track is not set up")
}
pkts, err := h264Encoder.Encode(outNALUs, pkt.Time+pkt.CTime)
if err != nil {
return fmt.Errorf("ERR while encoding H264: %v", err)
}
nalus, err := h264.DecodeAVCC(pkt.Data)
if err != nil {
return err
}
for _, pkt := range pkts {
onFrame(videoTrack.ID, pkt)
}
var outNALUs [][]byte
for _, nalu := range nalus {
// remove SPS, PPS and AUD, not needed by RTSP / RTMP
typ := h264.NALUType(nalu[0] & 0x1F)
switch typ {
case h264.NALUTypeSPS, h264.NALUTypePPS, h264.NALUTypeAccessUnitDelimiter:
continue
}
case av.AAC:
if audioTrack == nil {
return fmt.Errorf("ERR: received an AAC frame, but track is not set up")
}
outNALUs = append(outNALUs, nalu)
}
pkts, err := aacEncoder.Encode([][]byte{pkt.Data}, pkt.Time+pkt.CTime)
if err != nil {
return fmt.Errorf("ERR while encoding AAC: %v", err)
}
pkts, err := h264Encoder.Encode(outNALUs, pkt.Time+pkt.CTime)
if err != nil {
return fmt.Errorf("ERR while encoding H264: %v", err)
}
for _, pkt := range pkts {
onFrame(audioTrack.ID, pkt)
}
for _, pkt := range pkts {
onFrame(videoTrack.ID, pkt)
}
default:
return fmt.Errorf("ERR: unexpected packet: %v", pkt.Type)
}
case av.AAC:
if audioTrack == nil {
return fmt.Errorf("ERR: received an AAC frame, but track is not set up")
}
pkts, err := aacEncoder.Encode([][]byte{pkt.Data}, pkt.Time+pkt.CTime)
if err != nil {
return fmt.Errorf("ERR while encoding AAC: %v", err)
}
for _, pkt := range pkts {
onFrame(audioTrack.ID, pkt)
}
default:
return fmt.Errorf("ERR: unexpected packet: %v", pkt.Type)
}
}
}()
}()
select {
case err := <-readDone:
conn.NetConn().Close()
return err
case <-ctx.Done():
conn.NetConn().Close()
<-readDone
return nil
}
}()
}()
select {
case <-s.terminate:
conn.NetConn().Close()
<-readerDone
return false
case err := <-readerDone:
case err := <-done:
cancel()
s.log(logger.Info, "ERR: %s", err)
conn.NetConn().Close()
return true
case <-s.terminate:
cancel()
<-done
return false
}
}

View file

@ -121,7 +121,7 @@ func (s *Source) run() {
}
func (s *Source) runInner() bool {
s.log(logger.Info, "connecting")
s.log(logger.Debug, "connecting")
var conn *gortsplib.ClientConn
var err error