From 89cf3bb2fa5f2b1b2a9361d32f808a8e4719cb6c Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Wed, 20 Sep 2023 12:32:01 +0200 Subject: [PATCH] fix crash when processing H265 (#2378) (#2381) --- internal/formatprocessor/h264_test.go | 107 ++++++++++++---------- internal/formatprocessor/h265.go | 2 +- internal/formatprocessor/h265_test.go | 122 +++++++++++++++----------- 3 files changed, 131 insertions(+), 100 deletions(-) diff --git a/internal/formatprocessor/h264_test.go b/internal/formatprocessor/h264_test.go index 9477676b..b0567cb9 100644 --- a/internal/formatprocessor/h264_test.go +++ b/internal/formatprocessor/h264_test.go @@ -14,53 +14,68 @@ import ( ) func TestH264DynamicParams(t *testing.T) { - forma := &format.H264{ - PayloadTyp: 96, - PacketizationMode: 1, + for _, ca := range []string{"standard", "aggregated"} { + t.Run(ca, func(t *testing.T) { + forma := &format.H264{ + PayloadTyp: 96, + PacketizationMode: 1, + } + + p, err := New(1472, forma, false) + require.NoError(t, err) + + enc, err := forma.CreateEncoder() + require.NoError(t, err) + + pkts, err := enc.Encode([][]byte{{byte(h264.NALUTypeIDR)}}) + require.NoError(t, err) + + data, err := p.ProcessRTPPacket(pkts[0], time.Time{}, 0, true) + require.NoError(t, err) + + require.Equal(t, [][]byte{ + {byte(h264.NALUTypeIDR)}, + }, data.(*unit.H264).AU) + + if ca == "standard" { + pkts, err = enc.Encode([][]byte{{7, 4, 5, 6}}) // SPS + require.NoError(t, err) + + _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) + require.NoError(t, err) + + pkts, err = enc.Encode([][]byte{{8, 1}}) // PPS + require.NoError(t, err) + + _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) + require.NoError(t, err) + } else { + pkts, err = enc.Encode([][]byte{ + {7, 4, 5, 6}, // SPS + {8, 1}, // PPS + }) + require.NoError(t, err) + + _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) + require.NoError(t, err) + } + + require.Equal(t, []byte{7, 4, 5, 6}, forma.SPS) + require.Equal(t, []byte{8, 1}, forma.PPS) + + pkts, err = enc.Encode([][]byte{{byte(h264.NALUTypeIDR)}}) + require.NoError(t, err) + + data, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, true) + require.NoError(t, err) + + require.Equal(t, [][]byte{ + {0x07, 4, 5, 6}, + {0x08, 1}, + {byte(h264.NALUTypeIDR)}, + }, data.(*unit.H264).AU) + }) } - - p, err := New(1472, forma, false) - require.NoError(t, err) - - enc, err := forma.CreateEncoder() - require.NoError(t, err) - - pkts, err := enc.Encode([][]byte{{byte(h264.NALUTypeIDR)}}) - require.NoError(t, err) - - data, err := p.ProcessRTPPacket(pkts[0], time.Time{}, 0, true) - require.NoError(t, err) - - require.Equal(t, [][]byte{ - {byte(h264.NALUTypeIDR)}, - }, data.(*unit.H264).AU) - - pkts, err = enc.Encode([][]byte{{7, 4, 5, 6}}) // SPS - require.NoError(t, err) - - _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) - require.NoError(t, err) - - pkts, err = enc.Encode([][]byte{{8, 1}}) // PPS - require.NoError(t, err) - - _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) - require.NoError(t, err) - - require.Equal(t, []byte{7, 4, 5, 6}, forma.SPS) - require.Equal(t, []byte{8, 1}, forma.PPS) - - pkts, err = enc.Encode([][]byte{{byte(h264.NALUTypeIDR)}}) - require.NoError(t, err) - - data, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, true) - require.NoError(t, err) - - require.Equal(t, [][]byte{ - {0x07, 4, 5, 6}, - {0x08, 1}, - {byte(h264.NALUTypeIDR)}, - }, data.(*unit.H264).AU) } func TestH264OversizedPackets(t *testing.T) { diff --git a/internal/formatprocessor/h265.go b/internal/formatprocessor/h265.go index 63bd9a36..57e3767d 100644 --- a/internal/formatprocessor/h265.go +++ b/internal/formatprocessor/h265.go @@ -55,7 +55,7 @@ func rtpH265ExtractVPSSPSPPS(payload []byte) ([]byte, []byte, []byte) { nalu := payload[:size] payload = payload[size:] - typ = h265.NALUType((payload[0] >> 1) & 0b111111) + typ = h265.NALUType((nalu[0] >> 1) & 0b111111) switch typ { case h265.NALUType_VPS_NUT: diff --git a/internal/formatprocessor/h265_test.go b/internal/formatprocessor/h265_test.go index c4b96cc1..6e183bdc 100644 --- a/internal/formatprocessor/h265_test.go +++ b/internal/formatprocessor/h265_test.go @@ -14,60 +14,76 @@ import ( ) func TestH265DynamicParams(t *testing.T) { - forma := &format.H265{ - PayloadTyp: 96, + for _, ca := range []string{"standard", "aggregated"} { + t.Run(ca, func(t *testing.T) { + forma := &format.H265{ + PayloadTyp: 96, + } + + p, err := New(1472, forma, false) + require.NoError(t, err) + + enc, err := forma.CreateEncoder() + require.NoError(t, err) + + pkts, err := enc.Encode([][]byte{{byte(h265.NALUType_CRA_NUT) << 1, 0}}) + require.NoError(t, err) + + data, err := p.ProcessRTPPacket(pkts[0], time.Time{}, 0, true) + require.NoError(t, err) + + require.Equal(t, [][]byte{ + {byte(h265.NALUType_CRA_NUT) << 1, 0}, + }, data.(*unit.H265).AU) + + if ca == "standard" { + pkts, err = enc.Encode([][]byte{{byte(h265.NALUType_VPS_NUT) << 1, 1, 2, 3}}) + require.NoError(t, err) + + _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) + require.NoError(t, err) + + pkts, err = enc.Encode([][]byte{{byte(h265.NALUType_SPS_NUT) << 1, 4, 5, 6}}) + require.NoError(t, err) + + _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) + require.NoError(t, err) + + pkts, err = enc.Encode([][]byte{{byte(h265.NALUType_PPS_NUT) << 1, 7, 8, 9}}) + require.NoError(t, err) + + _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) + require.NoError(t, err) + } else { + pkts, err = enc.Encode([][]byte{ + {byte(h265.NALUType_VPS_NUT) << 1, 1, 2, 3}, + {byte(h265.NALUType_SPS_NUT) << 1, 4, 5, 6}, + {byte(h265.NALUType_PPS_NUT) << 1, 7, 8, 9}, + }) + require.NoError(t, err) + + _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) + require.NoError(t, err) + } + + require.Equal(t, []byte{byte(h265.NALUType_VPS_NUT) << 1, 1, 2, 3}, forma.VPS) + require.Equal(t, []byte{byte(h265.NALUType_SPS_NUT) << 1, 4, 5, 6}, forma.SPS) + require.Equal(t, []byte{byte(h265.NALUType_PPS_NUT) << 1, 7, 8, 9}, forma.PPS) + + pkts, err = enc.Encode([][]byte{{byte(h265.NALUType_CRA_NUT) << 1, 0}}) + require.NoError(t, err) + + data, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, true) + require.NoError(t, err) + + require.Equal(t, [][]byte{ + {byte(h265.NALUType_VPS_NUT) << 1, 1, 2, 3}, + {byte(h265.NALUType_SPS_NUT) << 1, 4, 5, 6}, + {byte(h265.NALUType_PPS_NUT) << 1, 7, 8, 9}, + {byte(h265.NALUType_CRA_NUT) << 1, 0}, + }, data.(*unit.H265).AU) + }) } - - p, err := New(1472, forma, false) - require.NoError(t, err) - - enc, err := forma.CreateEncoder() - require.NoError(t, err) - - pkts, err := enc.Encode([][]byte{{byte(h265.NALUType_CRA_NUT) << 1, 0}}) - require.NoError(t, err) - - data, err := p.ProcessRTPPacket(pkts[0], time.Time{}, 0, true) - require.NoError(t, err) - - require.Equal(t, [][]byte{ - {byte(h265.NALUType_CRA_NUT) << 1, 0}, - }, data.(*unit.H265).AU) - - pkts, err = enc.Encode([][]byte{{byte(h265.NALUType_VPS_NUT) << 1, 1, 2, 3}}) - require.NoError(t, err) - - _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) - require.NoError(t, err) - - pkts, err = enc.Encode([][]byte{{byte(h265.NALUType_SPS_NUT) << 1, 4, 5, 6}}) - require.NoError(t, err) - - _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) - require.NoError(t, err) - - pkts, err = enc.Encode([][]byte{{byte(h265.NALUType_PPS_NUT) << 1, 7, 8, 9}}) - require.NoError(t, err) - - _, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, false) - require.NoError(t, err) - - require.Equal(t, []byte{byte(h265.NALUType_VPS_NUT) << 1, 1, 2, 3}, forma.VPS) - require.Equal(t, []byte{byte(h265.NALUType_SPS_NUT) << 1, 4, 5, 6}, forma.SPS) - require.Equal(t, []byte{byte(h265.NALUType_PPS_NUT) << 1, 7, 8, 9}, forma.PPS) - - pkts, err = enc.Encode([][]byte{{byte(h265.NALUType_CRA_NUT) << 1, 0}}) - require.NoError(t, err) - - data, err = p.ProcessRTPPacket(pkts[0], time.Time{}, 0, true) - require.NoError(t, err) - - require.Equal(t, [][]byte{ - {byte(h265.NALUType_VPS_NUT) << 1, 1, 2, 3}, - {byte(h265.NALUType_SPS_NUT) << 1, 4, 5, 6}, - {byte(h265.NALUType_PPS_NUT) << 1, 7, 8, 9}, - {byte(h265.NALUType_CRA_NUT) << 1, 0}, - }, data.(*unit.H265).AU) } func TestH265OversizedPackets(t *testing.T) {