From e189f4570c5687a7030a81a51112f49e09bf4fff Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sat, 21 Aug 2021 21:39:44 +0200 Subject: [PATCH] hls, rtmp: set DTS = PTS when a IDR frame is received --- internal/core/rtmp_conn.go | 17 +++++++++++++++-- internal/h264/dtsestimator.go | 8 +++++++- internal/h264/dtsestimator_test.go | 18 ++++++++++++++---- internal/hls/muxer.go | 2 +- 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/internal/core/rtmp_conn.go b/internal/core/rtmp_conn.go index 4687f308..2d3cc66f 100644 --- a/internal/core/rtmp_conn.go +++ b/internal/core/rtmp_conn.go @@ -329,13 +329,26 @@ func (c *rtmpConn) runRead(ctx context.Context) error { return err } - dts := videoDTSEst.Feed(pts + rtmpConnPTSOffset) + idrPresent := func() bool { + for _, nalu := range nalus { + typ := h264.NALUType(nalu[0] & 0x1F) + if typ == h264.NALUTypeIDR { + return true + } + } + return false + }() + + pts += rtmpConnPTSOffset + + dts := videoDTSEst.Feed(idrPresent, pts) + c.conn.NetConn().SetWriteDeadline(time.Now().Add(c.writeTimeout)) err = c.conn.WritePacket(av.Packet{ Type: av.H264, Data: data, Time: dts, - CTime: pts + rtmpConnPTSOffset - dts, + CTime: pts - dts, }) if err != nil { return err diff --git a/internal/h264/dtsestimator.go b/internal/h264/dtsestimator.go index 9da4bcb3..fb2022e9 100644 --- a/internal/h264/dtsestimator.go +++ b/internal/h264/dtsestimator.go @@ -20,7 +20,7 @@ func NewDTSEstimator() *DTSEstimator { } // Feed provides PTS to the estimator, and returns the estimated DTS. -func (d *DTSEstimator) Feed(pts time.Duration) time.Duration { +func (d *DTSEstimator) Feed(idrPresent bool, pts time.Duration) time.Duration { if d.initializing > 0 { d.initializing-- dts := d.prevDTS + time.Millisecond @@ -31,6 +31,12 @@ func (d *DTSEstimator) Feed(pts time.Duration) time.Duration { } dts := func() time.Duration { + // IDR + if idrPresent { + // DTS is always PTS + return pts + } + // P or I frame if pts > d.prevPTS { // previous frame was B diff --git a/internal/h264/dtsestimator_test.go b/internal/h264/dtsestimator_test.go index a5c8733d..448ce3a4 100644 --- a/internal/h264/dtsestimator_test.go +++ b/internal/h264/dtsestimator_test.go @@ -9,9 +9,19 @@ import ( func TestDTSEstimator(t *testing.T) { est := NewDTSEstimator() - est.Feed(2 * time.Second) - est.Feed(2*time.Second - 200*time.Millisecond) - est.Feed(2*time.Second - 400*time.Millisecond) - dts := est.Feed(2*time.Second + 200*time.Millisecond) + + dts := est.Feed(false, 2*time.Second) + require.Equal(t, time.Millisecond, dts) + + dts = est.Feed(false, 2*time.Second-200*time.Millisecond) + require.Equal(t, 2*time.Millisecond, dts) + + dts = est.Feed(false, 2*time.Second-400*time.Millisecond) + require.Equal(t, 3*time.Millisecond, dts) + + dts = est.Feed(false, 2*time.Second+200*time.Millisecond) require.Equal(t, 2*time.Second-400*time.Millisecond, dts) + + dts = est.Feed(true, 2*time.Second+300*time.Millisecond) + require.Equal(t, 2*time.Second+300*time.Millisecond, dts) } diff --git a/internal/hls/muxer.go b/internal/hls/muxer.go index d0fa3eb8..a95904b0 100644 --- a/internal/hls/muxer.go +++ b/internal/hls/muxer.go @@ -115,7 +115,7 @@ func (m *Muxer) WriteH264(pts time.Duration, nalus [][]byte) error { pts = pts + ptsOffset - m.startPTS err := m.currentSegment.writeH264( - m.videoDTSEst.Feed(pts), + m.videoDTSEst.Feed(idrPresent, pts), pts, idrPresent, nalus)