From 7e180ceea2b37e0fe0be95969b36bee316485922 Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Thu, 10 Aug 2023 21:06:51 +0200 Subject: [PATCH] rtmp: support ingesting RTMPE streams (#2189) --- internal/rtmp/conn.go | 21 +- internal/rtmp/conn_test.go | 8 +- internal/rtmp/handshake/c0s0.go | 20 +- internal/rtmp/handshake/c0s0_test.go | 6 +- internal/rtmp/handshake/c1s1.go | 252 ++++++++++-------- internal/rtmp/handshake/c1s1_test.go | 136 +--------- internal/rtmp/handshake/c2s2.go | 101 +++++-- internal/rtmp/handshake/c2s2_test.go | 67 +---- internal/rtmp/handshake/dh.go | 118 +++++++++ internal/rtmp/handshake/handshake.go | 309 ++++++++++++++++++---- internal/rtmp/handshake/handshake_test.go | 115 +++----- internal/rtmp/message/reader.go | 9 +- internal/rtmp/message/reader_test.go | 3 +- internal/rtmp/message/readwriter.go | 12 +- internal/rtmp/message/readwriter_test.go | 20 +- internal/rtmp/message/writer.go | 10 +- internal/rtmp/message/writer_test.go | 3 +- internal/rtmp/rawmessage/reader.go | 13 +- internal/rtmp/rawmessage/reader_test.go | 5 +- internal/rtmp/rawmessage/writer.go | 13 +- internal/rtmp/rawmessage/writer_test.go | 9 +- internal/rtmp/rc4_readwriter.go | 49 ++++ internal/rtmp/reader_test.go | 2 +- internal/rtmp/writer_test.go | 2 +- 24 files changed, 806 insertions(+), 497 deletions(-) create mode 100644 internal/rtmp/handshake/dh.go create mode 100644 internal/rtmp/rc4_readwriter.go diff --git a/internal/rtmp/conn.go b/internal/rtmp/conn.go index 3bcc736a..431ecc0d 100644 --- a/internal/rtmp/conn.go +++ b/internal/rtmp/conn.go @@ -156,12 +156,12 @@ func NewClientConn(rw io.ReadWriter, u *url.URL, publish bool) (*Conn, error) { func (c *Conn) initializeClient(u *url.URL, publish bool) error { connectpath, actionpath := splitPath(u) - err := handshake.DoClient(c.bc, false) + _, _, err := handshake.DoClient(c.bc, false, false) if err != nil { return err } - c.mrw = message.NewReadWriter(c.bc, false) + c.mrw = message.NewReadWriter(c.bc, c.bc, false) err = c.mrw.Write(&message.SetWindowAckSize{ Value: 2500000, @@ -329,12 +329,23 @@ func NewServerConn(rw io.ReadWriter) (*Conn, *url.URL, bool, error) { } func (c *Conn) initializeServer() (*url.URL, bool, error) { - err := handshake.DoServer(c.bc, false) + keyIn, keyOut, err := handshake.DoServer(c.bc, false) if err != nil { return nil, false, err } - c.mrw = message.NewReadWriter(c.bc, false) + var rw io.ReadWriter + if keyIn != nil { + var err error + rw, err = newRC4ReadWriter(c.bc, keyIn, keyOut) + if err != nil { + return nil, false, err + } + } else { + rw = c.bc + } + + c.mrw = message.NewReadWriter(rw, c.bc, false) cmd, err := readCommand(c.mrw) if err != nil { @@ -581,7 +592,7 @@ func newNoHandshakeConn(rw io.ReadWriter) *Conn { bc: bytecounter.NewReadWriter(rw), } - c.mrw = message.NewReadWriter(c.bc, false) + c.mrw = message.NewReadWriter(c.bc, c.bc, false) return c } diff --git a/internal/rtmp/conn_test.go b/internal/rtmp/conn_test.go index acd63d6a..3cd98fd7 100644 --- a/internal/rtmp/conn_test.go +++ b/internal/rtmp/conn_test.go @@ -29,10 +29,10 @@ func TestNewClientConn(t *testing.T) { defer conn.Close() bc := bytecounter.NewReadWriter(conn) - err = handshake.DoServer(bc, true) + _, _, err = handshake.DoServer(bc, false) require.NoError(t, err) - mrw := message.NewReadWriter(bc, true) + mrw := message.NewReadWriter(bc, bc, true) msg, err := mrw.Read() require.NoError(t, err) @@ -289,10 +289,10 @@ func TestNewServerConn(t *testing.T) { defer conn.Close() bc := bytecounter.NewReadWriter(conn) - err = handshake.DoClient(bc, true) + _, _, err = handshake.DoClient(bc, false, false) require.NoError(t, err) - mrw := message.NewReadWriter(bc, true) + mrw := message.NewReadWriter(bc, bc, true) tcURL := "rtmp://127.0.0.1:9121/stream" if ca == "publish neko" { diff --git a/internal/rtmp/handshake/c0s0.go b/internal/rtmp/handshake/c0s0.go index 0e650c38..eee56575 100644 --- a/internal/rtmp/handshake/c0s0.go +++ b/internal/rtmp/handshake/c0s0.go @@ -5,30 +5,30 @@ import ( "io" ) -const ( - rtmpVersion = 0x03 -) - // C0S0 is a C0 or S0 packet. -type C0S0 struct{} +type C0S0 struct { + Version byte +} // Read reads a C0S0. -func (C0S0) Read(r io.Reader) error { +func (c *C0S0) Read(r io.Reader) error { buf := make([]byte, 1) _, err := io.ReadFull(r, buf) if err != nil { return err } - if buf[0] != rtmpVersion { - return fmt.Errorf("invalid rtmp version (%d)", buf[0]) + c.Version = buf[0] + + if c.Version != 3 && c.Version != 6 { + return fmt.Errorf("invalid rtmp version (%d)", c.Version) } return nil } // Write writes a C0S0. -func (C0S0) Write(w io.Writer) error { - _, err := w.Write([]byte{rtmpVersion}) +func (c C0S0) Write(w io.Writer) error { + _, err := w.Write([]byte{c.Version}) return err } diff --git a/internal/rtmp/handshake/c0s0_test.go b/internal/rtmp/handshake/c0s0_test.go index 4f2dc703..98503349 100644 --- a/internal/rtmp/handshake/c0s0_test.go +++ b/internal/rtmp/handshake/c0s0_test.go @@ -7,9 +7,11 @@ import ( "github.com/stretchr/testify/require" ) -var c0s0enc = []byte{0x03} +var c0s0enc = []byte{3} -var c0s0dec = C0S0{} +var c0s0dec = C0S0{ + Version: 3, +} func TestC0S0Read(t *testing.T) { var c0s0 C0S0 diff --git a/internal/rtmp/handshake/c1s1.go b/internal/rtmp/handshake/c1s1.go index 4bfe7585..379ad0ff 100644 --- a/internal/rtmp/handshake/c1s1.go +++ b/internal/rtmp/handshake/c1s1.go @@ -9,140 +9,180 @@ import ( "io" ) -var ( - hsClientFullKey = []byte{ - 'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ', - 'F', 'l', 'a', 's', 'h', ' ', 'P', 'l', 'a', 'y', 'e', 'r', ' ', - '0', '0', '1', - 0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1, - 0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB, - 0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE, - } - hsServerFullKey = []byte{ - 'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ', - 'F', 'l', 'a', 's', 'h', ' ', 'M', 'e', 'd', 'i', 'a', ' ', - 'S', 'e', 'r', 'v', 'e', 'r', ' ', - '0', '0', '1', - 0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1, - 0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB, - 0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE, - } - hsClientPartialKey = hsClientFullKey[:30] - hsServerPartialKey = hsServerFullKey[:36] +const ( + c1s1Size = 1536 + digestPointerPos1 = 0 + digestPointerPos2 = 772 - 8 + digestChunkPos1 = digestPointerPos1 + 4 + digestChunkPos2 = digestPointerPos2 + 4 + digestChunkLength = 728 + digestLength = 32 + publicKeyPointerPos1 = 1532 - 8 + publicKeyPointerPos2 = 768 - 8 + publicKeyChunkPos1 = publicKeyPointerPos1 - 760 + publicKeyChunkPos2 = publicKeyPointerPos2 - 760 + publicKeyChunkLength = 632 ) -func hsCalcDigestPos(p []byte, base int) int { - pos := 0 - for i := 0; i < 4; i++ { - pos += int(p[base+i]) - } - return (pos % 728) + base + 4 -} +var ( + clientKeyC1 = []byte("Genuine Adobe Flash Player 001") + serverKeyS1 = []byte("Genuine Adobe Flash Media Server 001") +) -func hsMakeDigest(key []byte, src []byte, gap int) []byte { +func hmacSha256(key []byte, buf []byte) []byte { h := hmac.New(sha256.New, key) - if gap <= 0 { - h.Write(src) - } else { - h.Write(src[:gap]) - h.Write(src[gap+32:]) - } + h.Write(buf) return h.Sum(nil) } -func hsFindDigest(p []byte, key []byte, base int) int { - gap := hsCalcDigestPos(p, base) - digest := hsMakeDigest(key, p, gap) - if !bytes.Equal(p[gap:gap+32], digest) { - return -1 - } - return gap -} - -func hsParse1(p []byte, peerkey []byte, key []byte) (bool, []byte) { - var pos int - if pos = hsFindDigest(p, peerkey, 772); pos == -1 { - if pos = hsFindDigest(p, peerkey, 8); pos == -1 { - return false, nil - } - } - return true, hsMakeDigest(key, p[pos:pos+32], -1) -} - // C1S1 is a C1 or S1 packet. type C1S1 struct { - Time uint32 - Random []byte - Digest []byte + Time uint32 + Version uint32 + Data []byte } // Read reads a C1S1. -func (c *C1S1) Read(r io.Reader, isC1 bool, validateSignature bool) error { - buf := make([]byte, 1536) +func (c *C1S1) Read(r io.Reader) error { + buf := make([]byte, c1s1Size) _, err := io.ReadFull(r, buf) if err != nil { return err } - var peerKey []byte - var key []byte - if isC1 { - peerKey = hsClientPartialKey - key = hsServerFullKey - } else { - peerKey = hsServerPartialKey - key = hsClientFullKey - } - ok, digest := hsParse1(buf, peerKey, key) - if !ok { - if validateSignature { - return fmt.Errorf("unable to validate C1/S1 signature") - } - } else { - c.Digest = digest - } - c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3]) - c.Random = buf[8:] + c.Version = uint32(buf[4])<<24 | uint32(buf[5])<<16 | uint32(buf[6])<<8 | uint32(buf[7]) + c.Data = buf[8:] return nil } +func (c C1S1) readPointer(p int) int { + return int(c.Data[p]) + int(c.Data[p+1]) + int(c.Data[p+2]) + int(c.Data[p+3]) +} + +func (c C1S1) publicKeyPos1() int { + return publicKeyChunkPos1 + (c.readPointer(publicKeyPointerPos1) % publicKeyChunkLength) +} + +func (c C1S1) publicKeyPos2() int { + return publicKeyChunkPos2 + (c.readPointer(publicKeyPointerPos2) % publicKeyChunkLength) +} + +func (c C1S1) digestPos1() int { + return digestChunkPos1 + (c.readPointer(digestPointerPos1) % digestChunkLength) +} + +func (c C1S1) digestPos2() int { + return digestChunkPos2 + (c.readPointer(digestPointerPos2) % digestChunkLength) +} + +func (c C1S1) computeDigest(digestPos int, isS1 bool) []byte { + // hash entire message except digest + msg := make([]byte, c1s1Size-digestLength) + msg[0] = byte(c.Time >> 24) + msg[1] = byte(c.Time >> 16) + msg[2] = byte(c.Time >> 8) + msg[3] = byte(c.Time) + msg[4] = byte(c.Version >> 24) + msg[5] = byte(c.Version >> 16) + msg[6] = byte(c.Version >> 8) + msg[7] = byte(c.Version) + copy(msg[8:], c.Data[:digestPos]) + copy(msg[8+digestPos:], c.Data[digestPos+digestLength:]) + + if isS1 { + return hmacSha256(serverKeyS1, msg) + } + return hmacSha256(clientKeyC1, msg) +} + +func (c C1S1) validateDigest(isS1 bool) ([]byte, []byte, error) { + digestPos := c.digestPos1() + d1 := c.Data[digestPos : digestPos+digestLength] + d2 := c.computeDigest(digestPos, isS1) + + if bytes.Equal(d1, d2) { + publicKeyPos := c.publicKeyPos1() + publicKey := c.Data[publicKeyPos : publicKeyPos+dhKeyLength] + return d1, publicKey, nil + } + + digestPos = c.digestPos2() + d1 = c.Data[digestPos : digestPos+digestLength] + d2 = c.computeDigest(digestPos, isS1) + + if bytes.Equal(d1, d2) { + publicKeyPos := c.publicKeyPos2() + publicKey := c.Data[publicKeyPos : publicKeyPos+dhKeyLength] + return d1, publicKey, nil + } + + return nil, nil, fmt.Errorf("unable to validate C1/S1 digest") +} + +func (c C1S1) validate(isS1 bool) ([]byte, []byte, error) { + digest, publicKey, err := c.validateDigest(isS1) + if err != nil { + return nil, nil, err + } + + err = dhValidatePublicKey(publicKey) + if err != nil { + return nil, nil, err + } + + return digest, publicKey, nil +} + +func (c *C1S1) fillPlain() error { + c.Data = make([]byte, c1s1Size-8) + _, err := rand.Read(c.Data) + return err +} + +func (c *C1S1) fill(isS1 bool, publicKey []byte) ([]byte, error) { + err := c.fillPlain() + if err != nil { + return nil, err + } + + var r [1]byte + _, err = rand.Read(r[:]) + if err != nil { + return nil, err + } + + var digestPos int + var publicKeyPos int + + if r[0] == 0 { + digestPos = c.digestPos1() + publicKeyPos = c.publicKeyPos1() + } else { + digestPos = c.digestPos2() + publicKeyPos = c.publicKeyPos2() + } + + copy(c.Data[publicKeyPos:], publicKey) + digest := c.computeDigest(digestPos, isS1) + copy(c.Data[digestPos:], digest) + return digest, nil +} + // Write writes a C1S1. -func (c *C1S1) Write(w io.Writer, isC1 bool) error { - buf := make([]byte, 1536) +func (c C1S1) Write(w io.Writer) error { + buf := make([]byte, c1s1Size) buf[0] = byte(c.Time >> 24) buf[1] = byte(c.Time >> 16) buf[2] = byte(c.Time >> 8) buf[3] = byte(c.Time) - copy(buf[4:], []byte{0, 0, 0, 0}) - - if c.Random == nil { - _, err := rand.Read(buf[8:]) - if err != nil { - return err - } - c.Random = buf[8:] - } else { - copy(buf[8:], c.Random) - } - - // signature - gap := hsCalcDigestPos(buf, 8) - var peerKey []byte - var key []byte - if isC1 { - peerKey = hsServerFullKey - key = hsClientPartialKey - } else { - peerKey = hsClientFullKey - key = hsServerPartialKey - } - digest := hsMakeDigest(key, buf, gap) - copy(buf[gap:], digest) - pos := hsFindDigest(buf, key, 8) - c.Digest = hsMakeDigest(peerKey, buf[pos:pos+32], -1) + buf[4] = byte(c.Version >> 24) + buf[5] = byte(c.Version >> 16) + buf[6] = byte(c.Version >> 8) + buf[7] = byte(c.Version) + copy(buf[8:], c.Data) _, err := w.Write(buf) return err diff --git a/internal/rtmp/handshake/c1s1_test.go b/internal/rtmp/handshake/c1s1_test.go index 6dabf67b..e8a00fa0 100644 --- a/internal/rtmp/handshake/c1s1_test.go +++ b/internal/rtmp/handshake/c1s1_test.go @@ -7,132 +7,24 @@ import ( "github.com/stretchr/testify/require" ) -var c1enc = append( - []byte{ - 0x19, 0xf1, 0x27, 0xa3, 0x00, 0x00, 0x00, 0x00, - 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, - 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x2d, 0x0a, - 0x37, 0x6f, 0x63, 0x2e, 0xa0, 0x21, 0xa0, 0xa4, - 0x81, 0xb1, 0x50, 0x21, 0x5a, 0x6d, 0x81, 0xad, - 0xf8, 0x44, 0x69, 0x13, 0xcc, 0x02, 0x8c, 0xd4, - 0x64, 0x43, 0xc9, 0x9f, 0xcf, 0xc6, 0x03, 0x04, - }, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)..., -) +var c1s1enc = bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4) -var s1enc = append( - []byte{ - 0x19, 0xf1, 0x27, 0xa3, 0x00, 0x00, 0x00, 0x00, - 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, - 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x95, 0xc1, - 0xb6, 0x2c, 0x99, 0xbe, 0xa0, 0x0c, 0x07, 0x98, - 0xb0, 0xf1, 0xbe, 0x54, 0x50, 0x63, 0xa1, 0x25, - 0x1c, 0x9a, 0xcd, 0x12, 0x10, 0x98, 0x74, 0x8b, - 0x18, 0x66, 0x8d, 0xef, 0xcf, 0x22, 0x03, 0x04, - }, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)..., -) +var c1s1dec = C1S1{ + Time: 16909060, + Version: 16909060, + Data: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4-2), +} func TestC1S1Read(t *testing.T) { - for _, ca := range []struct { - isC1 bool - enc []byte - dec C1S1 - }{ - { - true, - c1enc, - C1S1{ - Time: 435234723, - Random: append( - []byte{ - 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, - 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x2d, 0x0a, - 0x37, 0x6f, 0x63, 0x2e, 0xa0, 0x21, 0xa0, 0xa4, - 0x81, 0xb1, 0x50, 0x21, 0x5a, 0x6d, 0x81, 0xad, - 0xf8, 0x44, 0x69, 0x13, 0xcc, 0x02, 0x8c, 0xd4, - 0x64, 0x43, 0xc9, 0x9f, 0xcf, 0xc6, 0x03, 0x04, - }, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)..., - ), - Digest: []byte{ - 0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3, - 0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a, - 0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d, - 0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99, - }, - }, - }, - { - false, - s1enc, - C1S1{ - Time: 435234723, - Random: append( - []byte{ - 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, - 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x95, 0xc1, - 0xb6, 0x2c, 0x99, 0xbe, 0xa0, 0x0c, 0x07, 0x98, - 0xb0, 0xf1, 0xbe, 0x54, 0x50, 0x63, 0xa1, 0x25, - 0x1c, 0x9a, 0xcd, 0x12, 0x10, 0x98, 0x74, 0x8b, - 0x18, 0x66, 0x8d, 0xef, 0xcf, 0x22, 0x03, 0x04, - }, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)..., - ), - Digest: []byte{ - 0x0e, 0x8f, 0x96, 0x19, 0x19, 0xe6, 0xb7, 0xf2, - 0xac, 0x9a, 0xc8, 0x7e, 0x6e, 0xe9, 0xd4, 0x72, - 0xed, 0x82, 0x87, 0xf1, 0xfa, 0xbd, 0x93, 0xb8, - 0x7c, 0x48, 0x85, 0x03, 0x01, 0x7b, 0x54, 0xbe, - }, - }, - }, - } { - var c1s1 C1S1 - err := c1s1.Read((bytes.NewReader(ca.enc)), ca.isC1, true) - require.NoError(t, err) - require.Equal(t, ca.dec, c1s1) - } + var c1s1 C1S1 + err := c1s1.Read((bytes.NewReader(c1s1enc))) + require.NoError(t, err) + require.Equal(t, c1s1dec, c1s1) } func TestC1S1Write(t *testing.T) { - for _, ca := range []struct { - isC1 bool - enc []byte - dec C1S1 - }{ - { - true, - c1enc, - C1S1{ - Time: 435234723, - Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382), - Digest: []byte{ - 0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3, - 0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a, - 0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d, - 0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99, - }, - }, - }, - { - false, - s1enc, - C1S1{ - Time: 435234723, - Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382), - Digest: []byte{ - 0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3, - 0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a, - 0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d, - 0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99, - }, - }, - }, - } { - var buf bytes.Buffer - err := ca.dec.Write(&buf, ca.isC1) - require.NoError(t, err) - require.Equal(t, ca.enc, buf.Bytes()) - } + var buf bytes.Buffer + err := c1s1dec.Write(&buf) + require.NoError(t, err) + require.Equal(t, c1s1enc, buf.Bytes()) } diff --git a/internal/rtmp/handshake/c2s2.go b/internal/rtmp/handshake/c2s2.go index 31f0c24c..d5fa8590 100644 --- a/internal/rtmp/handshake/c2s2.go +++ b/internal/rtmp/handshake/c2s2.go @@ -2,44 +2,103 @@ package handshake import ( "bytes" + "crypto/rand" "fmt" "io" ) +const ( + c2s2Size = c1s1Size + c2s2DigestPos = c2s2Size - 8 - digestLength +) + +var ( + randomCrud = []byte{ + 0xf0, 0xee, 0xc2, 0x4a, 0x80, 0x68, 0xbe, 0xe8, + 0x2e, 0x00, 0xd0, 0xd1, 0x02, 0x9e, 0x7e, 0x57, + 0x6e, 0xec, 0x5d, 0x2d, 0x29, 0x80, 0x6f, 0xab, + 0x93, 0xb8, 0xe6, 0x36, 0xcf, 0xeb, 0x31, 0xae, + } + clientKeyC2 = append([]byte(nil), append(clientKeyC1, randomCrud...)...) + serverKeyS2 = append([]byte(nil), append(serverKeyS1, randomCrud...)...) +) + // C2S2 is a C2 or S2 packet. type C2S2 struct { - Time uint32 - Time2 uint32 - Random []byte - Digest []byte + Time uint32 + Time2 uint32 + Data []byte } // Read reads a C2S2. -func (c *C2S2) Read(r io.Reader, validateSignature bool) error { - buf := make([]byte, 1536) +func (c *C2S2) Read(r io.Reader) error { + buf := make([]byte, c2s2Size) _, err := io.ReadFull(r, buf) if err != nil { return err } - if validateSignature { - gap := len(buf) - 32 - digest := hsMakeDigest(c.Digest, buf, gap) - if !bytes.Equal(buf[gap:gap+32], digest) { - return fmt.Errorf("unable to validate C2/S2 signature") - } - } - c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3]) c.Time2 = uint32(buf[4])<<24 | uint32(buf[5])<<16 | uint32(buf[6])<<8 | uint32(buf[7]) - c.Random = buf[8:] + c.Data = buf[8:] return nil } +func (c C2S2) computeDigest(isS2 bool, prevDigest []byte) []byte { + // hash entire message except digest + msg := make([]byte, c2s2Size-digestLength) + msg[0] = byte(c.Time >> 24) + msg[1] = byte(c.Time >> 16) + msg[2] = byte(c.Time >> 8) + msg[3] = byte(c.Time) + msg[4] = byte(c.Time2 >> 24) + msg[5] = byte(c.Time2 >> 16) + msg[6] = byte(c.Time2 >> 8) + msg[7] = byte(c.Time2) + copy(msg[8:], c.Data[:c2s2DigestPos]) + + var key []byte + if isS2 { + key = hmacSha256(serverKeyS2, prevDigest) + } else { + key = hmacSha256(clientKeyC2, prevDigest) + } + + return hmacSha256(key, msg) +} + +func (c C2S2) validate(isS2 bool, prevDigest []byte) error { + d1 := c.Data[c2s2DigestPos : c2s2DigestPos+digestLength] + d2 := c.computeDigest(isS2, prevDigest) + + if !bytes.Equal(d1, d2) { + return fmt.Errorf("unable to validate C2/S2 digest") + } + + return nil +} + +func (c *C2S2) fillPlain() error { + c.Data = make([]byte, c2s2Size-8) + _, err := rand.Read(c.Data) + return err +} + +func (c *C2S2) fill(isS2 bool, prevDigest []byte) error { + err := c.fillPlain() + if err != nil { + return err + } + + digest := c.computeDigest(isS2, prevDigest) + copy(c.Data[c2s2DigestPos:], digest) + return nil +} + // Write writes a C2S2. func (c C2S2) Write(w io.Writer) error { - buf := make([]byte, 1536) + buf := make([]byte, c2s2Size) buf[0] = byte(c.Time >> 24) buf[1] = byte(c.Time >> 16) @@ -49,15 +108,7 @@ func (c C2S2) Write(w io.Writer) error { buf[5] = byte(c.Time2 >> 16) buf[6] = byte(c.Time2 >> 8) buf[7] = byte(c.Time2) - - copy(buf[8:], c.Random) - - // signature - if c.Digest != nil { - gap := len(buf) - 32 - digest := hsMakeDigest(c.Digest, buf, gap) - copy(buf[gap:], digest) - } + copy(buf[8:], c.Data) _, err := w.Write(buf) return err diff --git a/internal/rtmp/handshake/c2s2_test.go b/internal/rtmp/handshake/c2s2_test.go index db532ea8..8f063e5c 100644 --- a/internal/rtmp/handshake/c2s2_test.go +++ b/internal/rtmp/handshake/c2s2_test.go @@ -7,71 +7,22 @@ import ( "github.com/stretchr/testify/require" ) +var c2s2enc = bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4) + +var c2s2dec = C2S2{ + Time: 16909060, + Time2: 16909060, + Data: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4-2), +} + func TestC2S2Read(t *testing.T) { - c2s2dec := C2S2{ - Time: 435234723, - Time2: 7893542, - Random: append( - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 372), - []byte{ - 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, - 0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2, - 0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14, - 0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe, - 0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec, - }...), - Digest: []byte{ - 0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3, - 0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a, - 0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d, - 0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99, - }, - } - - c2s2enc := append(append( - []byte{ - 0x19, 0xf1, 0x27, 0xa3, 0x00, 0x78, 0x72, 0x26, - }, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 374)..., - ), []byte{ - 0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2, - 0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14, - 0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe, - 0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec, - }...) - var c2s2 C2S2 - c2s2.Digest = c2s2dec.Digest - err := c2s2.Read((bytes.NewReader(c2s2enc)), true) + err := c2s2.Read((bytes.NewReader(c2s2enc))) require.NoError(t, err) require.Equal(t, c2s2dec, c2s2) } func TestC2S2Write(t *testing.T) { - c2s2dec := C2S2{ - Time: 435234723, - Time2: 7893542, - Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382), - Digest: []byte{ - 0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3, - 0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a, - 0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d, - 0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99, - }, - } - - c2s2enc := append(append( - []byte{ - 0x19, 0xf1, 0x27, 0xa3, 0x00, 0x78, 0x72, 0x26, - }, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 374)..., - ), []byte{ - 0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2, - 0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14, - 0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe, - 0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec, - }...) - var buf bytes.Buffer err := c2s2dec.Write(&buf) require.NoError(t, err) diff --git a/internal/rtmp/handshake/dh.go b/internal/rtmp/handshake/dh.go new file mode 100644 index 00000000..aff74c99 --- /dev/null +++ b/internal/rtmp/handshake/dh.go @@ -0,0 +1,118 @@ +package handshake + +import ( + "crypto/rand" + "fmt" + "math/big" +) + +const ( + dhKeyLength = 128 +) + +var ( + p1024 = []byte{ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xc9, 0x0f, 0xda, 0xa2, 0x21, 0x68, 0xc2, 0x34, + 0xc4, 0xc6, 0x62, 0x8b, 0x80, 0xdc, 0x1c, 0xd1, + 0x29, 0x02, 0x4e, 0x08, 0x8a, 0x67, 0xcc, 0x74, + 0x02, 0x0b, 0xbe, 0xa6, 0x3b, 0x13, 0x9b, 0x22, + 0x51, 0x4a, 0x08, 0x79, 0x8e, 0x34, 0x04, 0xdd, + 0xef, 0x95, 0x19, 0xb3, 0xcd, 0x3a, 0x43, 0x1b, + 0x30, 0x2b, 0x0a, 0x6d, 0xf2, 0x5f, 0x14, 0x37, + 0x4f, 0xe1, 0x35, 0x6d, 0x6d, 0x51, 0xc2, 0x45, + 0xe4, 0x85, 0xb5, 0x76, 0x62, 0x5e, 0x7e, 0xc6, + 0xf4, 0x4c, 0x42, 0xe9, 0xa6, 0x37, 0xed, 0x6b, + 0x0b, 0xff, 0x5c, 0xb6, 0xf4, 0x06, 0xb7, 0xed, + 0xee, 0x38, 0x6b, 0xfb, 0x5a, 0x89, 0x9f, 0xa5, + 0xae, 0x9f, 0x24, 0x11, 0x7c, 0x4b, 0x1f, 0xe6, + 0x49, 0x28, 0x66, 0x51, 0xec, 0xe6, 0x53, 0x81, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + } + q1024 = []byte{ + 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xe4, 0x87, 0xed, 0x51, 0x10, 0xb4, 0x61, 0x1a, + 0x62, 0x63, 0x31, 0x45, 0xc0, 0x6e, 0x0e, 0x68, + 0x94, 0x81, 0x27, 0x04, 0x45, 0x33, 0xe6, 0x3a, + 0x01, 0x05, 0xdf, 0x53, 0x1d, 0x89, 0xcd, 0x91, + 0x28, 0xa5, 0x04, 0x3c, 0xc7, 0x1a, 0x02, 0x6e, + 0xf7, 0xca, 0x8c, 0xd9, 0xe6, 0x9d, 0x21, 0x8d, + 0x98, 0x15, 0x85, 0x36, 0xf9, 0x2f, 0x8a, 0x1b, + 0xa7, 0xf0, 0x9a, 0xb6, 0xb6, 0xa8, 0xe1, 0x22, + 0xf2, 0x42, 0xda, 0xbb, 0x31, 0x2f, 0x3f, 0x63, + 0x7a, 0x26, 0x21, 0x74, 0xd3, 0x1b, 0xf6, 0xb5, + 0x85, 0xff, 0xae, 0x5b, 0x7a, 0x03, 0x5b, 0xf6, + 0xf7, 0x1c, 0x35, 0xfd, 0xad, 0x44, 0xcf, 0xd2, + 0xd7, 0x4f, 0x92, 0x08, 0xbe, 0x25, 0x8f, 0xf3, + 0x24, 0x94, 0x33, 0x28, 0xf6, 0x73, 0x29, 0xc0, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + } +) + +// https://datatracker.ietf.org/doc/html/rfc2631#section-2.1.5 +func dhValidatePublicKey(key []byte) error { + // 1. y >= 2 && y < p + var y big.Int + y.SetBytes(key) + var two big.Int + two.SetUint64(2) + r := y.Cmp(&two) + if r < 0 { + return fmt.Errorf("key is < 2") + } + var p big.Int + p.SetBytes(p1024) + r = y.Cmp(&p) + if r >= 0 { + return fmt.Errorf("key is >= p") + } + + // 2. (y^q mod p) == 1 + var q big.Int + q.SetBytes(q1024) + var z big.Int + z.Exp(&y, &q, &p) + var one big.Int + one.SetUint64(1) + r = z.Cmp(&one) + if r != 0 { + return fmt.Errorf("y^q mod p is != 1") + } + + return nil +} + +func dhGenerateKeyPair() ([]byte, []byte, error) { + priv := make([]byte, dhKeyLength) + _, err := rand.Read(priv) + if err != nil { + return nil, nil, err + } + + // y = g ^ x mod p + var g big.Int + g.SetUint64(2) + var x big.Int + x.SetBytes(priv) + var p big.Int + p.SetBytes(p1024) + var y big.Int + y.Exp(&g, &x, &p) + pub := y.Bytes() + + return priv, pub, nil +} + +// https://datatracker.ietf.org/doc/html/rfc2631#section-2.1.1 +func dhComputeSharedSecret(priv []byte, pub []byte) []byte { + // ZZ = (ya ^ xb) mod p + var y big.Int + y.SetBytes(pub) + var x big.Int + x.SetBytes(priv) + var p big.Int + p.SetBytes(p1024) + var z big.Int + z.Exp(&y, &x, &p) + return z.Bytes() +} diff --git a/internal/rtmp/handshake/handshake.go b/internal/rtmp/handshake/handshake.go index 0718972e..9bd2bd57 100644 --- a/internal/rtmp/handshake/handshake.go +++ b/internal/rtmp/handshake/handshake.go @@ -2,96 +2,297 @@ package handshake import ( + "bytes" + "fmt" "io" ) -// DoClient performs a client-side handshake. -func DoClient(rw io.ReadWriter, validateSignature bool) error { - c0 := C0S0{} +const ( + encryptedVersion = 3<<24 | 5<<16 | 1<<8 | 1 +) + +func doClientEncrypted(rw io.ReadWriter) ([]byte, []byte, error) { + var c0 C0S0 + + c0.Version = 6 + + err := c0.Write(rw) + if err != nil { + return nil, nil, err + } + + localPrivateKey, localPublicKey, err := dhGenerateKeyPair() + if err != nil { + return nil, nil, err + } + + var c1 C1S1 + + c1Digest, err := c1.fill(false, localPublicKey) + if err != nil { + return nil, nil, err + } + + err = c1.Write(rw) + if err != nil { + return nil, nil, err + } + + var s0 C0S0 + + err = s0.Read(rw) + if err != nil { + return nil, nil, err + } + + if s0.Version != 6 { + return nil, nil, fmt.Errorf("server replied with unexpected version %d", s0.Version) + } + + var s1 C1S1 + + err = s1.Read(rw) + if err != nil { + return nil, nil, err + } + + s1Digest, remotePublicKey, err := s1.validate(true) + if err != nil { + return nil, nil, err + } + + var s2 C2S2 + + err = s2.Read(rw) + if err != nil { + return nil, nil, err + } + + err = s2.validate(true, c1Digest) + if err != nil { + return nil, nil, err + } + + var c2 C2S2 + + err = c2.fill(false, s1Digest) + if err != nil { + return nil, nil, err + } + + err = c2.Write(rw) + if err != nil { + return nil, nil, err + } + + sharedSecret := dhComputeSharedSecret(localPrivateKey, remotePublicKey) + keyIn := hmacSha256(sharedSecret, localPublicKey)[:16] + keyOut := hmacSha256(sharedSecret, remotePublicKey)[:16] + return keyIn, keyOut, nil +} + +func doClientPlain(rw io.ReadWriter, strict bool) error { + var c0 C0S0 + + c0.Version = 3 + err := c0.Write(rw) if err != nil { return err } - c1 := C1S1{} - err = c1.Write(rw, true) + var c1 C1S1 + + err = c1.fillPlain() if err != nil { return err } - s0 := C0S0{} + err = c1.Write(rw) + if err != nil { + return err + } + + var s0 C0S0 + err = s0.Read(rw) if err != nil { return err } - s1 := C1S1{} - err = s1.Read(rw, false, validateSignature) + if s0.Version != 3 { + return fmt.Errorf("server replied with unexpected version %d", s0.Version) + } + + var s1 C1S1 + + err = s1.Read(rw) if err != nil { return err } - s2 := C2S2{ - Digest: c1.Digest, - } - err = s2.Read(rw, validateSignature) + var s2 C2S2 + + err = s2.Read(rw) if err != nil { return err } - c2 := C2S2{ - Time: s1.Time, - Random: s1.Random, - Digest: s1.Digest, + if strict && !bytes.Equal(s2.Data, c1.Data) { + return fmt.Errorf("data in S2 does not correspond") } - err = c2.Write(rw) + + var c2 C2S2 + + c2.Data = s1.Data + + return c2.Write(rw) +} + +// DoClient performs a client-side handshake. +func DoClient(rw io.ReadWriter, encrypted bool, strict bool) ([]byte, []byte, error) { + if encrypted { + return doClientEncrypted(rw) + } + return nil, nil, doClientPlain(rw, strict) +} + +func doServerEncrypted(rw io.ReadWriter) ([]byte, []byte, error) { + var c1 C1S1 + + err := c1.Read(rw) + if err != nil { + return nil, nil, err + } + + c1Digest, remotePublicKey, err := c1.validate(false) + if err != nil { + return nil, nil, err + } + + localPrivateKey, localPublicKey, err := dhGenerateKeyPair() + if err != nil { + return nil, nil, err + } + + var s0 C0S0 + + s0.Version = 6 + + err = s0.Write(rw) + if err != nil { + return nil, nil, err + } + + var s1 C1S1 + + s1.Version = encryptedVersion + + s1Digest, err := s1.fill(true, localPublicKey) + if err != nil { + return nil, nil, err + } + + err = s1.Write(rw) + if err != nil { + return nil, nil, err + } + + var s2 C2S2 + + s2.Time2 = encryptedVersion + + err = s2.fill(true, c1Digest) + if err != nil { + return nil, nil, err + } + + err = s2.Write(rw) + if err != nil { + return nil, nil, err + } + + var c2 C2S2 + + err = c2.Read(rw) + if err != nil { + return nil, nil, err + } + + err = c2.validate(false, s1Digest) + if err != nil { + return nil, nil, err + } + + sharedSecret := dhComputeSharedSecret(localPrivateKey, remotePublicKey) + keyIn := hmacSha256(sharedSecret, localPublicKey)[:16] + keyOut := hmacSha256(sharedSecret, remotePublicKey)[:16] + return keyIn, keyOut, nil +} + +func doServerPlain(rw io.ReadWriter, strict bool) error { + var c1 C1S1 + + err := c1.Read(rw) if err != nil { return err } + var s0 C0S0 + + s0.Version = 3 + + err = s0.Write(rw) + if err != nil { + return err + } + + var s1 C1S1 + + err = s1.fillPlain() + if err != nil { + return err + } + + err = s1.Write(rw) + if err != nil { + return err + } + + var s2 C2S2 + + s2.Data = c1.Data + + err = s2.Write(rw) + if err != nil { + return err + } + + var c2 C2S2 + + err = c2.Read(rw) + if err != nil { + return err + } + + if strict && !bytes.Equal(c2.Data, s1.Data) { + return fmt.Errorf("data in C2 does not correspond") + } + return nil } // DoServer performs a server-side handshake. -func DoServer(rw io.ReadWriter, validateSignature bool) error { - err := C0S0{}.Read(rw) +func DoServer(rw io.ReadWriter, strict bool) ([]byte, []byte, error) { + var c0 C0S0 + + err := c0.Read(rw) if err != nil { - return err + return nil, nil, err } - c1 := C1S1{} - err = c1.Read(rw, true, validateSignature) - if err != nil { - return err + if c0.Version == 6 { + return doServerEncrypted(rw) } - - s0 := C0S0{} - err = s0.Write(rw) - if err != nil { - return err - } - - s1 := C1S1{} - err = s1.Write(rw, false) - if err != nil { - return err - } - - s2 := C2S2{ - Time: c1.Time, - Random: c1.Random, - Digest: c1.Digest, - } - err = s2.Write(rw) - if err != nil { - return err - } - - c2 := C2S2{Digest: s1.Digest} - err = c2.Read(rw, validateSignature) - if err != nil { - return err - } - - return nil + return nil, nil, doServerPlain(rw, strict) } diff --git a/internal/rtmp/handshake/handshake_test.go b/internal/rtmp/handshake/handshake_test.go index 744854d7..d9ee4136 100644 --- a/internal/rtmp/handshake/handshake_test.go +++ b/internal/rtmp/handshake/handshake_test.go @@ -1,91 +1,50 @@ package handshake import ( - "crypto/rand" - "net" "testing" "github.com/stretchr/testify/require" ) +type testReadWriter struct { + ch chan []byte +} + +func (rw *testReadWriter) Read(p []byte) (int, error) { + in := <-rw.ch + n := copy(p, in) + return n, nil +} + +func (rw *testReadWriter) Write(p []byte) (int, error) { + rw.ch <- p + return len(p), nil +} + func TestHandshake(t *testing.T) { - ln, err := net.Listen("tcp", "127.0.0.1:9122") - require.NoError(t, err) - defer ln.Close() + for _, ca := range []string{"plain", "encrypted"} { + t.Run(ca, func(t *testing.T) { + rw := &testReadWriter{ch: make(chan []byte)} + var serverInKey []byte + var serverOutKey []byte + done := make(chan struct{}) - done := make(chan struct{}) + go func() { + var err error + serverInKey, serverOutKey, err = DoServer(rw, true) + require.NoError(t, err) + close(done) + }() - go func() { - conn, err := ln.Accept() - require.NoError(t, err) - defer conn.Close() + clientInKey, clientOutKey, err := DoClient(rw, ca == "encrypted", true) + require.NoError(t, err) + <-done - err = DoServer(conn, true) - require.NoError(t, err) - - close(done) - }() - - conn, err := net.Dial("tcp", "127.0.0.1:9122") - require.NoError(t, err) - defer conn.Close() - - err = DoClient(conn, true) - require.NoError(t, err) - - <-done -} - -// when C1 signature is invalid, S2 must be equal to C1. -func TestHandshakeFallback(t *testing.T) { - ln, err := net.Listen("tcp", "127.0.0.1:9122") - require.NoError(t, err) - defer ln.Close() - - done := make(chan struct{}) - - go func() { - conn, err := ln.Accept() - require.NoError(t, err) - defer conn.Close() - - err = DoServer(conn, false) - require.NoError(t, err) - - close(done) - }() - - conn, err := net.Dial("tcp", "127.0.0.1:9122") - require.NoError(t, err) - defer conn.Close() - - err = C0S0{}.Write(conn) - require.NoError(t, err) - - c1 := make([]byte, 1536) - _, err = rand.Read(c1[8:]) - require.NoError(t, err) - _, err = conn.Write(c1) - require.NoError(t, err) - - err = C0S0{}.Read(conn) - require.NoError(t, err) - - s1 := C1S1{} - err = s1.Read(conn, false, false) - require.NoError(t, err) - - s2 := C2S2{} - err = s2.Read(conn, false) - require.NoError(t, err) - require.Equal(t, c1[8:], s2.Random) - - err = C2S2{ - Time: s1.Time, - Random: s1.Random, - Digest: s1.Digest, - }.Write(conn) - require.NoError(t, err) - - <-done + if ca == "encrypted" { + require.NotNil(t, serverInKey) + require.Equal(t, serverInKey, clientOutKey) + require.Equal(t, serverOutKey, clientInKey) + } + }) + } } diff --git a/internal/rtmp/message/reader.go b/internal/rtmp/message/reader.go index f7a458a0..0aad13c0 100644 --- a/internal/rtmp/message/reader.go +++ b/internal/rtmp/message/reader.go @@ -2,6 +2,7 @@ package message import ( "fmt" + "io" "github.com/bluenviron/mediamtx/internal/rtmp/bytecounter" "github.com/bluenviron/mediamtx/internal/rtmp/rawmessage" @@ -116,9 +117,13 @@ type Reader struct { } // NewReader allocates a Reader. -func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader { +func NewReader( + r io.Reader, + bcr *bytecounter.Reader, + onAckNeeded func(uint32) error, +) *Reader { return &Reader{ - r: rawmessage.NewReader(r, onAckNeeded), + r: rawmessage.NewReader(r, bcr, onAckNeeded), } } diff --git a/internal/rtmp/message/reader_test.go b/internal/rtmp/message/reader_test.go index b0e5aaa6..dd18d379 100644 --- a/internal/rtmp/message/reader_test.go +++ b/internal/rtmp/message/reader_test.go @@ -268,7 +268,8 @@ var readWriterCases = []struct { func TestReader(t *testing.T) { for _, ca := range readWriterCases { t.Run(ca.name, func(t *testing.T) { - r := NewReader(bytecounter.NewReader(bytes.NewReader(ca.enc)), nil) + bc := bytecounter.NewReader(bytes.NewReader(ca.enc)) + r := NewReader(bc, bc, nil) dec, err := r.Read() require.NoError(t, err) require.Equal(t, ca.dec, dec) diff --git a/internal/rtmp/message/readwriter.go b/internal/rtmp/message/readwriter.go index fe4b3ea1..cb03ec22 100644 --- a/internal/rtmp/message/readwriter.go +++ b/internal/rtmp/message/readwriter.go @@ -1,6 +1,8 @@ package message import ( + "io" + "github.com/bluenviron/mediamtx/internal/rtmp/bytecounter" ) @@ -11,10 +13,14 @@ type ReadWriter struct { } // NewReadWriter allocates a ReadWriter. -func NewReadWriter(bc *bytecounter.ReadWriter, checkAcknowledge bool) *ReadWriter { - w := NewWriter(bc.Writer, checkAcknowledge) +func NewReadWriter( + rw io.ReadWriter, + bcrw *bytecounter.ReadWriter, + checkAcknowledge bool, +) *ReadWriter { + w := NewWriter(rw, bcrw.Writer, checkAcknowledge) - r := NewReader(bc.Reader, func(count uint32) error { + r := NewReader(rw, bcrw.Reader, func(count uint32) error { return w.Write(&Acknowledge{ Value: count, }) diff --git a/internal/rtmp/message/readwriter_test.go b/internal/rtmp/message/readwriter_test.go index af00c9d6..54327c8c 100644 --- a/internal/rtmp/message/readwriter_test.go +++ b/internal/rtmp/message/readwriter_test.go @@ -27,19 +27,21 @@ func TestReadWriterAcknowledge(t *testing.T) { var buf1 bytes.Buffer var buf2 bytes.Buffer - rw1 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{ + bc1 := bytecounter.NewReadWriter(&duplexRW{ Reader: &buf2, Writer: &buf1, - }), true) + }) + rw1 := NewReadWriter(bc1, bc1, true) err := rw1.Write(&Acknowledge{ Value: 7863534, }) require.NoError(t, err) - rw2 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{ + bc2 := bytecounter.NewReadWriter(&duplexRW{ Reader: &buf1, Writer: &buf2, - }), true) + }) + rw2 := NewReadWriter(bc2, bc2, true) _, err = rw2.Read() require.NoError(t, err) } @@ -48,19 +50,21 @@ func TestReadWriterPing(t *testing.T) { var buf1 bytes.Buffer var buf2 bytes.Buffer - rw1 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{ + bc1 := bytecounter.NewReadWriter(&duplexRW{ Reader: &buf2, Writer: &buf1, - }), true) + }) + rw1 := NewReadWriter(bc1, bc1, true) err := rw1.Write(&UserControlPingRequest{ ServerTime: 143424312, }) require.NoError(t, err) - rw2 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{ + bc2 := bytecounter.NewReadWriter(&duplexRW{ Reader: &buf1, Writer: &buf2, - }), true) + }) + rw2 := NewReadWriter(bc2, bc2, true) _, err = rw2.Read() require.NoError(t, err) diff --git a/internal/rtmp/message/writer.go b/internal/rtmp/message/writer.go index 6f67f025..a0ad2fdf 100644 --- a/internal/rtmp/message/writer.go +++ b/internal/rtmp/message/writer.go @@ -1,6 +1,8 @@ package message import ( + "io" + "github.com/bluenviron/mediamtx/internal/rtmp/bytecounter" "github.com/bluenviron/mediamtx/internal/rtmp/rawmessage" ) @@ -11,9 +13,13 @@ type Writer struct { } // NewWriter allocates a Writer. -func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer { +func NewWriter( + w io.Writer, + bcw *bytecounter.Writer, + checkAcknowledge bool, +) *Writer { return &Writer{ - w: rawmessage.NewWriter(w, checkAcknowledge), + w: rawmessage.NewWriter(w, bcw, checkAcknowledge), } } diff --git a/internal/rtmp/message/writer_test.go b/internal/rtmp/message/writer_test.go index c26b8744..178a5b64 100644 --- a/internal/rtmp/message/writer_test.go +++ b/internal/rtmp/message/writer_test.go @@ -13,7 +13,8 @@ func TestWriter(t *testing.T) { for _, ca := range readWriterCases { t.Run(ca.name, func(t *testing.T) { var buf bytes.Buffer - r := NewWriter(bytecounter.NewWriter(&buf), true) + bc := bytecounter.NewWriter(&buf) + r := NewWriter(bc, bc, true) err := r.Write(ca.dec) require.NoError(t, err) require.Equal(t, ca.enc, buf.Bytes()) diff --git a/internal/rtmp/rawmessage/reader.go b/internal/rtmp/rawmessage/reader.go index bd097ad2..970f2267 100644 --- a/internal/rtmp/rawmessage/reader.go +++ b/internal/rtmp/rawmessage/reader.go @@ -4,6 +4,7 @@ import ( "bufio" "errors" "fmt" + "io" "time" "github.com/bluenviron/mediamtx/internal/rtmp/bytecounter" @@ -30,7 +31,7 @@ func (rc *readerChunkStream) readChunk(c chunk.Chunk, chunkBodySize uint32) erro // check if an ack is needed if rc.mr.ackWindowSize != 0 { - count := uint32(rc.mr.r.Count()) + count := uint32(rc.mr.bcr.Count()) diff := count - rc.mr.lastAckCount if diff > (rc.mr.ackWindowSize) { @@ -208,7 +209,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { // Reader is a raw message reader. type Reader struct { - r *bytecounter.Reader + bcr *bytecounter.Reader onAckNeeded func(uint32) error br *bufio.Reader @@ -224,9 +225,13 @@ type Reader struct { } // NewReader allocates a Reader. -func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader { +func NewReader( + r io.Reader, + bcr *bytecounter.Reader, + onAckNeeded func(uint32) error, +) *Reader { return &Reader{ - r: r, + bcr: bcr, br: bufio.NewReader(r), onAckNeeded: onAckNeeded, chunkSize: 128, diff --git a/internal/rtmp/rawmessage/reader_test.go b/internal/rtmp/rawmessage/reader_test.go index 51890b6e..3ea39a82 100644 --- a/internal/rtmp/rawmessage/reader_test.go +++ b/internal/rtmp/rawmessage/reader_test.go @@ -198,7 +198,8 @@ func TestReader(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { var buf bytes.Buffer - r := NewReader(bytecounter.NewReader(&buf), func(count uint32) error { + br := bytecounter.NewReader(&buf) + r := NewReader(br, br, func(count uint32) error { return nil }) @@ -224,7 +225,7 @@ func TestReaderAcknowledge(t *testing.T) { var buf bytes.Buffer bc := bytecounter.NewReader(&buf) - r := NewReader(bc, func(count uint32) error { + r := NewReader(bc, bc, func(count uint32) error { close(onAckCalled) return nil }) diff --git a/internal/rtmp/rawmessage/writer.go b/internal/rtmp/rawmessage/writer.go index 9a7fe0b5..790b3c72 100644 --- a/internal/rtmp/rawmessage/writer.go +++ b/internal/rtmp/rawmessage/writer.go @@ -3,6 +3,7 @@ package rawmessage import ( "bufio" "fmt" + "io" "time" "github.com/bluenviron/mediamtx/internal/rtmp/bytecounter" @@ -21,7 +22,7 @@ type writerChunkStream struct { func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error { // check if we received an acknowledge if wc.mw.checkAcknowledge && wc.mw.ackWindowSize != 0 { - diff := uint32(wc.mw.w.Count()) - wc.mw.ackValue + diff := uint32(wc.mw.bcw.Count()) - wc.mw.ackValue if diff > (wc.mw.ackWindowSize * 3 / 2) { return fmt.Errorf("no acknowledge received within window") @@ -148,7 +149,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error { // Writer is a raw message writer. type Writer struct { - w *bytecounter.Writer + bcw *bytecounter.Writer bw *bufio.Writer checkAcknowledge bool chunkSize uint32 @@ -158,9 +159,13 @@ type Writer struct { } // NewWriter allocates a Writer. -func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer { +func NewWriter( + w io.Writer, + bcw *bytecounter.Writer, + checkAcknowledge bool, +) *Writer { return &Writer{ - w: w, + bcw: bcw, bw: bufio.NewWriter(w), checkAcknowledge: checkAcknowledge, chunkSize: 128, diff --git a/internal/rtmp/rawmessage/writer_test.go b/internal/rtmp/rawmessage/writer_test.go index 7827747c..7987ff99 100644 --- a/internal/rtmp/rawmessage/writer_test.go +++ b/internal/rtmp/rawmessage/writer_test.go @@ -15,7 +15,8 @@ func TestWriter(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { var buf bytes.Buffer - w := NewWriter(bytecounter.NewWriter(&buf), true) + bc := bytecounter.NewWriter(&buf) + w := NewWriter(bc, bc, true) for _, msg := range ca.messages { err := w.Write(msg) @@ -36,11 +37,11 @@ func TestWriterAcknowledge(t *testing.T) { for _, ca := range []string{"standard", "overflow"} { t.Run(ca, func(t *testing.T) { var buf bytes.Buffer - bcw := bytecounter.NewWriter(&buf) - w := NewWriter(bcw, true) + bc := bytecounter.NewWriter(&buf) + w := NewWriter(bc, bc, true) if ca == "overflow" { - bcw.SetCount(4294967096) + bc.SetCount(4294967096) w.ackValue = 4294967096 } diff --git a/internal/rtmp/rc4_readwriter.go b/internal/rtmp/rc4_readwriter.go new file mode 100644 index 00000000..bee263d0 --- /dev/null +++ b/internal/rtmp/rc4_readwriter.go @@ -0,0 +1,49 @@ +package rtmp + +import ( + "crypto/rc4" + "io" +) + +type rc4ReadWriter struct { + rw io.ReadWriter + in *rc4.Cipher + out *rc4.Cipher +} + +func newRC4ReadWriter(rw io.ReadWriter, keyIn []byte, keyOut []byte) (*rc4ReadWriter, error) { + in, err := rc4.NewCipher(keyIn) + if err != nil { + return nil, err + } + + out, err := rc4.NewCipher(keyOut) + if err != nil { + return nil, err + } + + p := make([]byte, 1536) + in.XORKeyStream(p, p) + out.XORKeyStream(p, p) + + return &rc4ReadWriter{ + rw: rw, + in: in, + out: out, + }, nil +} + +func (r *rc4ReadWriter) Read(p []byte) (int, error) { + n, err := r.rw.Read(p) + if n == 0 { + return 0, err + } + + r.in.XORKeyStream(p[:n], p[:n]) + return n, err +} + +func (r *rc4ReadWriter) Write(p []byte) (int, error) { + r.out.XORKeyStream(p, p) + return r.rw.Write(p) +} diff --git a/internal/rtmp/reader_test.go b/internal/rtmp/reader_test.go index 01e7c9a7..4775f7d1 100644 --- a/internal/rtmp/reader_test.go +++ b/internal/rtmp/reader_test.go @@ -166,7 +166,7 @@ func TestReadTracks(t *testing.T) { t.Run(ca.name, func(t *testing.T) { var buf bytes.Buffer bc := bytecounter.NewReadWriter(&buf) - mrw := message.NewReadWriter(bc, true) + mrw := message.NewReadWriter(bc, bc, true) switch ca.name { case "video+audio": diff --git a/internal/rtmp/writer_test.go b/internal/rtmp/writer_test.go index c04778c9..46cfbc97 100644 --- a/internal/rtmp/writer_test.go +++ b/internal/rtmp/writer_test.go @@ -46,7 +46,7 @@ func TestWriteTracks(t *testing.T) { require.NoError(t, err) bc := bytecounter.NewReadWriter(&buf) - mrw := message.NewReadWriter(bc, true) + mrw := message.NewReadWriter(bc, bc, true) msg, err := mrw.Read() require.NoError(t, err)