From 02afa8ff99276cd87a26d13ea321378873432010 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Tue, 7 Sep 2021 09:50:18 +0200 Subject: [PATCH] rtmp, hls: remove initial difference of 2secs between PTS and DTS of H264 --- internal/core/rtmp_conn.go | 49 ++++++++++++++++++++---------- internal/h264/dtsestimator.go | 23 ++++++-------- internal/h264/dtsestimator_test.go | 23 ++++++++------ internal/hls/muxer.go | 16 +++++----- internal/hls/muxer_test.go | 3 +- 5 files changed, 65 insertions(+), 49 deletions(-) diff --git a/internal/core/rtmp_conn.go b/internal/core/rtmp_conn.go index 2d3cc66f..c6878e08 100644 --- a/internal/core/rtmp_conn.go +++ b/internal/core/rtmp_conn.go @@ -27,11 +27,6 @@ import ( const ( rtmpConnPauseAfterAuthError = 2 * time.Second - - // an offset is needed to - // - avoid negative PTS values - // - avoid PTS < DTS during startup - rtmpConnPTSOffset = 2 * time.Second ) func pathNameAndQuery(inURL *url.URL) (string, url.Values) { @@ -285,7 +280,9 @@ func (c *rtmpConn) runRead(ctx context.Context) error { c.conn.NetConn().SetReadDeadline(time.Time{}) var videoBuf [][]byte - videoDTSEst := h264.NewDTSEstimator() + var videoStartPTS time.Duration + var videoDTSEst *h264.DTSEstimator + videoFirstIDRFound := false for { data, ok := c.ringBuffer.Pull() @@ -324,11 +321,6 @@ func (c *rtmpConn) runRead(ctx context.Context) error { // RTP marker means that all the NALUs with the same PTS have been received. // send them together. if pkt.Marker { - data, err := h264.EncodeAVCC(videoBuf) - if err != nil { - return err - } - idrPresent := func() bool { for _, nalu := range nalus { typ := h264.NALUType(nalu[0] & 0x1F) @@ -339,9 +331,25 @@ func (c *rtmpConn) runRead(ctx context.Context) error { return false }() - pts += rtmpConnPTSOffset + // wait until we receive an IDR + if !videoFirstIDRFound { + if !idrPresent { + videoBuf = nil + continue + } - dts := videoDTSEst.Feed(idrPresent, pts) + videoFirstIDRFound = true + videoStartPTS = pts + videoDTSEst = h264.NewDTSEstimator() + } + + data, err := h264.EncodeAVCC(videoBuf) + if err != nil { + return err + } + + pts -= videoStartPTS + dts := videoDTSEst.Feed(pts) c.conn.NetConn().SetWriteDeadline(time.Now().Add(c.writeTimeout)) err = c.conn.WritePacket(av.Packet{ @@ -373,18 +381,27 @@ func (c *rtmpConn) runRead(ctx context.Context) error { continue } - for i, au := range aus { - auPTS := pts + rtmpConnPTSOffset + time.Duration(i)*1000*time.Second/time.Duration(audioClockRate) + if videoTrack != nil && !videoFirstIDRFound { + continue + } + pts -= videoStartPTS + if pts < 0 { + continue + } + + for _, au := range aus { c.conn.NetConn().SetWriteDeadline(time.Now().Add(c.writeTimeout)) err := c.conn.WritePacket(av.Packet{ Type: av.AAC, Data: au, - Time: auPTS, + Time: pts, }) if err != nil { return err } + + pts += 1000 * time.Second / time.Duration(audioClockRate) } } } diff --git a/internal/h264/dtsestimator.go b/internal/h264/dtsestimator.go index fb2022e9..e4f564c5 100644 --- a/internal/h264/dtsestimator.go +++ b/internal/h264/dtsestimator.go @@ -20,23 +20,20 @@ func NewDTSEstimator() *DTSEstimator { } // Feed provides PTS to the estimator, and returns the estimated DTS. -func (d *DTSEstimator) Feed(idrPresent bool, pts time.Duration) time.Duration { - if d.initializing > 0 { +func (d *DTSEstimator) Feed(pts time.Duration) time.Duration { + switch d.initializing { + case 2: + d.initializing-- + return 0 + + case 1: d.initializing-- - dts := d.prevDTS + time.Millisecond - d.prevPrevPTS = d.prevPTS d.prevPTS = pts - d.prevDTS = dts - return dts + d.prevDTS = time.Millisecond + return time.Millisecond } dts := func() time.Duration { - // IDR - if idrPresent { - // DTS is always PTS - return pts - } - // P or I frame if pts > d.prevPTS { // previous frame was B @@ -52,7 +49,7 @@ func (d *DTSEstimator) Feed(idrPresent bool, pts time.Duration) time.Duration { } // B Frame - // do not increase + // increase by a small quantity return d.prevDTS + time.Millisecond }() diff --git a/internal/h264/dtsestimator_test.go b/internal/h264/dtsestimator_test.go index 448ce3a4..1965f328 100644 --- a/internal/h264/dtsestimator_test.go +++ b/internal/h264/dtsestimator_test.go @@ -10,18 +10,23 @@ import ( func TestDTSEstimator(t *testing.T) { est := NewDTSEstimator() - dts := est.Feed(false, 2*time.Second) + // initial state + dts := est.Feed(0) + require.Equal(t, time.Duration(0), dts) + + // b-frame + dts = est.Feed(1*time.Second - 200*time.Millisecond) require.Equal(t, time.Millisecond, dts) - dts = est.Feed(false, 2*time.Second-200*time.Millisecond) + // b-frame + dts = est.Feed(1*time.Second - 400*time.Millisecond) require.Equal(t, 2*time.Millisecond, dts) - dts = est.Feed(false, 2*time.Second-400*time.Millisecond) - require.Equal(t, 3*time.Millisecond, dts) + // p-frame + dts = est.Feed(1 * time.Second) + require.Equal(t, 1*time.Second-400*time.Millisecond, dts) - dts = est.Feed(false, 2*time.Second+200*time.Millisecond) - require.Equal(t, 2*time.Second-400*time.Millisecond, dts) - - dts = est.Feed(true, 2*time.Second+300*time.Millisecond) - require.Equal(t, 2*time.Second+300*time.Millisecond, dts) + // p-frame + dts = est.Feed(1*time.Second + 200*time.Millisecond) + require.Equal(t, 1*time.Second-399*time.Millisecond, dts) } diff --git a/internal/hls/muxer.go b/internal/hls/muxer.go index a95904b0..18714cb8 100644 --- a/internal/hls/muxer.go +++ b/internal/hls/muxer.go @@ -10,10 +10,8 @@ import ( ) const ( - // an offset is needed to - // - avoid negative PTS values - // - avoid PTS < DTS during startup - ptsOffset = 2 * time.Second + // an offset between PCR and PTS/DTS is needed to avoid PCR > PTS + pcrOffset = 500 * time.Millisecond segmentMinAUCount = 100 ) @@ -67,7 +65,6 @@ func NewMuxer( audioTrack: audioTrack, h264Conf: h264Conf, aacConf: aacConf, - videoDTSEst: h264.NewDTSEstimator(), currentSegment: newSegment(videoTrack, audioTrack, h264Conf, aacConf), primaryPlaylist: newPrimaryPlaylist(videoTrack, audioTrack, h264Conf), streamPlaylist: newStreamPlaylist(hlsSegmentCount), @@ -110,13 +107,14 @@ func (m *Muxer) WriteH264(pts time.Duration, nalus [][]byte) error { m.startPCR = time.Now() m.startPTS = pts m.currentSegment.setStartPCR(m.startPCR) + m.videoDTSEst = h264.NewDTSEstimator() } - pts = pts + ptsOffset - m.startPTS + pts -= m.startPTS err := m.currentSegment.writeH264( - m.videoDTSEst.Feed(idrPresent, pts), - pts, + m.videoDTSEst.Feed(pts)+pcrOffset, + pts+pcrOffset, idrPresent, nalus) if err != nil { @@ -150,7 +148,7 @@ func (m *Muxer) WriteAAC(pts time.Duration, aus [][]byte) error { } } - pts = pts + ptsOffset - m.startPTS + pts = pts - m.startPTS + pcrOffset for i, au := range aus { auPTS := pts + time.Duration(i)*1000*time.Second/time.Duration(m.aacConf.SampleRate) diff --git a/internal/hls/muxer_test.go b/internal/hls/muxer_test.go index 94391578..51069e87 100644 --- a/internal/hls/muxer_test.go +++ b/internal/hls/muxer_test.go @@ -93,8 +93,7 @@ func TestMuxer(t *testing.T) { byts = byts[188:] checkTSPacket(t, byts, 256, 3) - alen := int(byts[4]) - byts = byts[4+alen+20:] + byts = byts[4+145+15:] require.Equal(t, []byte{ 0, 0, 0, 1, 9, 240, // AUD