From 4d6f8b9b9bb426e5c717d45f26d597f1917c3f49 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 6 Mar 2022 15:50:03 +0100 Subject: [PATCH] RTSP client/source: support dynamic H264 SPS/PPS --- internal/core/rtsp_source_test.go | 105 ++++++++++++++++-------------- internal/core/stream.go | 25 +++++++ 2 files changed, 80 insertions(+), 50 deletions(-) diff --git a/internal/core/rtsp_source_test.go b/internal/core/rtsp_source_test.go index 12a64c5f..11c148a2 100644 --- a/internal/core/rtsp_source_test.go +++ b/internal/core/rtsp_source_test.go @@ -216,7 +216,7 @@ func TestRTSPSourceNoPassword(t *testing.T) { <-done } -func TestRTSPSourceMissingH264Params(t *testing.T) { +func TestRTSPSourceDynamicH264Params(t *testing.T) { track, err := gortsplib.NewTrackH264(96, nil, nil, nil) require.NoError(t, err) @@ -235,35 +235,6 @@ func TestRTSPSourceMissingH264Params(t *testing.T) { }, stream, nil }, onPlay: func(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Response, error) { - go func() { - time.Sleep(500 * time.Millisecond) - - enc := &rtph264.Encoder{PayloadType: 96} - enc.Init() - - pkts, err := enc.Encode([][]byte{{5}}, 0) // IDR - require.NoError(t, err) - stream.WritePacketRTP(0, pkts[0], true) - - pkts, err = enc.Encode([][]byte{{7, 1, 2, 3}}, 0) // SPS - require.NoError(t, err) - stream.WritePacketRTP(0, pkts[0], true) - - pkts, err = enc.Encode([][]byte{{8}}, 0) // PPS - require.NoError(t, err) - stream.WritePacketRTP(0, pkts[0], true) - - pkts, err = enc.Encode([][]byte{{5, 1}}, 0) // IDR - require.NoError(t, err) - stream.WritePacketRTP(0, pkts[0], true) - - time.Sleep(500 * time.Millisecond) - - pkts, err = enc.Encode([][]byte{{5, 2}}, 0) // IDR - require.NoError(t, err) - stream.WritePacketRTP(0, pkts[0], true) - }() - return &base.Response{ StatusCode: base.StatusOK, }, nil @@ -280,32 +251,66 @@ func TestRTSPSourceMissingH264Params(t *testing.T) { "hlsDisable: yes\n" + "paths:\n" + " proxied:\n" + - " source: rtsp://127.0.0.1:8555/teststream\n" + - " sourceOnDemand: yes\n") + " source: rtsp://127.0.0.1:8555/teststream\n") require.Equal(t, true, ok) defer p.close() - received := make(chan struct{}) + time.Sleep(1 * time.Second) - c := gortsplib.Client{ - OnPacketRTP: func(ctx *gortsplib.ClientOnPacketRTPCtx) { - if ctx.H264NALUs == nil { - return - } + enc := &rtph264.Encoder{PayloadType: 96} + enc.Init() - require.Equal(t, [][]byte{{0x05, 0x02}}, ctx.H264NALUs) - close(received) - }, - } - - err = c.StartReading("rtsp://127.0.0.1:8554/proxied") + pkts, err := enc.Encode([][]byte{{7, 1, 2, 3}}, 0) // SPS require.NoError(t, err) - defer c.Close() + stream.WritePacketRTP(0, pkts[0], true) - h264Track, ok := c.Tracks()[0].(*gortsplib.TrackH264) - require.Equal(t, true, ok) - require.Equal(t, []byte{7, 1, 2, 3}, h264Track.SPS()) - require.Equal(t, []byte{8}, h264Track.PPS()) + pkts, err = enc.Encode([][]byte{{8}}, 0) // PPS + require.NoError(t, err) + stream.WritePacketRTP(0, pkts[0], true) - <-received + func() { + c := gortsplib.Client{} + + u, err := base.ParseURL("rtsp://127.0.0.1:8554/proxied") + require.NoError(t, err) + + err = c.Start(u.Scheme, u.Host) + require.NoError(t, err) + defer c.Close() + + tracks, _, _, err := c.Describe(u) + require.NoError(t, err) + + h264Track, ok := tracks[0].(*gortsplib.TrackH264) + require.Equal(t, true, ok) + require.Equal(t, []byte{7, 1, 2, 3}, h264Track.SPS()) + require.Equal(t, []byte{8}, h264Track.PPS()) + }() + + pkts, err = enc.Encode([][]byte{{7, 4, 5, 6}}, 0) // SPS + require.NoError(t, err) + stream.WritePacketRTP(0, pkts[0], true) + + pkts, err = enc.Encode([][]byte{{8, 1}}, 0) // PPS + require.NoError(t, err) + stream.WritePacketRTP(0, pkts[0], true) + + func() { + c := gortsplib.Client{} + + u, err := base.ParseURL("rtsp://127.0.0.1:8554/proxied") + require.NoError(t, err) + + err = c.Start(u.Scheme, u.Host) + require.NoError(t, err) + defer c.Close() + + tracks, _, _, err := c.Describe(u) + require.NoError(t, err) + + h264Track, ok := tracks[0].(*gortsplib.TrackH264) + require.Equal(t, true, ok) + require.Equal(t, []byte{7, 4, 5, 6}, h264Track.SPS()) + require.Equal(t, []byte{8, 1}, h264Track.PPS()) + }() } diff --git a/internal/core/stream.go b/internal/core/stream.go index f5ea6066..4308acd4 100644 --- a/internal/core/stream.go +++ b/internal/core/stream.go @@ -1,9 +1,11 @@ package core import ( + "bytes" "sync" "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/pkg/h264" ) type streamNonRTSPReadersMap struct { @@ -78,7 +80,30 @@ func (s *stream) readerRemove(r reader) { } } +func (s *stream) updateH264TrackParameters(h264track *gortsplib.TrackH264, nalus [][]byte) { + for _, nalu := range nalus { + typ := h264.NALUType(nalu[0] & 0x1F) + + switch typ { + case h264.NALUTypeSPS: + if !bytes.Equal(nalu, h264track.SPS()) { + h264track.SetSPS(append([]byte(nil), nalu...)) + } + + case h264.NALUTypePPS: + if !bytes.Equal(nalu, h264track.PPS()) { + h264track.SetPPS(append([]byte(nil), nalu...)) + } + } + } +} + func (s *stream) writeData(trackID int, data *data) { + track := s.rtspStream.Tracks()[trackID] + if h264track, ok := track.(*gortsplib.TrackH264); ok { + s.updateH264TrackParameters(h264track, data.h264NALUs) + } + // forward to RTSP readers s.rtspStream.WritePacketRTP(trackID, data.rtp, data.ptsEqualsDTS)