rtmp: support ingesting RTMPE streams (#2189)

This commit is contained in:
Alessandro Ros 2023-08-10 21:06:51 +02:00 committed by GitHub
parent 57a436b0d5
commit 7e180ceea2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 806 additions and 497 deletions

View file

@ -156,12 +156,12 @@ func NewClientConn(rw io.ReadWriter, u *url.URL, publish bool) (*Conn, error) {
func (c *Conn) initializeClient(u *url.URL, publish bool) error {
connectpath, actionpath := splitPath(u)
err := handshake.DoClient(c.bc, false)
_, _, err := handshake.DoClient(c.bc, false, false)
if err != nil {
return err
}
c.mrw = message.NewReadWriter(c.bc, false)
c.mrw = message.NewReadWriter(c.bc, c.bc, false)
err = c.mrw.Write(&message.SetWindowAckSize{
Value: 2500000,
@ -329,12 +329,23 @@ func NewServerConn(rw io.ReadWriter) (*Conn, *url.URL, bool, error) {
}
func (c *Conn) initializeServer() (*url.URL, bool, error) {
err := handshake.DoServer(c.bc, false)
keyIn, keyOut, err := handshake.DoServer(c.bc, false)
if err != nil {
return nil, false, err
}
c.mrw = message.NewReadWriter(c.bc, false)
var rw io.ReadWriter
if keyIn != nil {
var err error
rw, err = newRC4ReadWriter(c.bc, keyIn, keyOut)
if err != nil {
return nil, false, err
}
} else {
rw = c.bc
}
c.mrw = message.NewReadWriter(rw, c.bc, false)
cmd, err := readCommand(c.mrw)
if err != nil {
@ -581,7 +592,7 @@ func newNoHandshakeConn(rw io.ReadWriter) *Conn {
bc: bytecounter.NewReadWriter(rw),
}
c.mrw = message.NewReadWriter(c.bc, false)
c.mrw = message.NewReadWriter(c.bc, c.bc, false)
return c
}

View file

@ -29,10 +29,10 @@ func TestNewClientConn(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
err = handshake.DoServer(bc, true)
_, _, err = handshake.DoServer(bc, false)
require.NoError(t, err)
mrw := message.NewReadWriter(bc, true)
mrw := message.NewReadWriter(bc, bc, true)
msg, err := mrw.Read()
require.NoError(t, err)
@ -289,10 +289,10 @@ func TestNewServerConn(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
err = handshake.DoClient(bc, true)
_, _, err = handshake.DoClient(bc, false, false)
require.NoError(t, err)
mrw := message.NewReadWriter(bc, true)
mrw := message.NewReadWriter(bc, bc, true)
tcURL := "rtmp://127.0.0.1:9121/stream"
if ca == "publish neko" {

View file

@ -5,30 +5,30 @@ import (
"io"
)
const (
rtmpVersion = 0x03
)
// C0S0 is a C0 or S0 packet.
type C0S0 struct{}
type C0S0 struct {
Version byte
}
// Read reads a C0S0.
func (C0S0) Read(r io.Reader) error {
func (c *C0S0) Read(r io.Reader) error {
buf := make([]byte, 1)
_, err := io.ReadFull(r, buf)
if err != nil {
return err
}
if buf[0] != rtmpVersion {
return fmt.Errorf("invalid rtmp version (%d)", buf[0])
c.Version = buf[0]
if c.Version != 3 && c.Version != 6 {
return fmt.Errorf("invalid rtmp version (%d)", c.Version)
}
return nil
}
// Write writes a C0S0.
func (C0S0) Write(w io.Writer) error {
_, err := w.Write([]byte{rtmpVersion})
func (c C0S0) Write(w io.Writer) error {
_, err := w.Write([]byte{c.Version})
return err
}

View file

@ -7,9 +7,11 @@ import (
"github.com/stretchr/testify/require"
)
var c0s0enc = []byte{0x03}
var c0s0enc = []byte{3}
var c0s0dec = C0S0{}
var c0s0dec = C0S0{
Version: 3,
}
func TestC0S0Read(t *testing.T) {
var c0s0 C0S0

View file

@ -9,140 +9,180 @@ import (
"io"
)
var (
hsClientFullKey = []byte{
'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ',
'F', 'l', 'a', 's', 'h', ' ', 'P', 'l', 'a', 'y', 'e', 'r', ' ',
'0', '0', '1',
0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1,
0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB,
0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE,
}
hsServerFullKey = []byte{
'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ',
'F', 'l', 'a', 's', 'h', ' ', 'M', 'e', 'd', 'i', 'a', ' ',
'S', 'e', 'r', 'v', 'e', 'r', ' ',
'0', '0', '1',
0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1,
0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB,
0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE,
}
hsClientPartialKey = hsClientFullKey[:30]
hsServerPartialKey = hsServerFullKey[:36]
const (
c1s1Size = 1536
digestPointerPos1 = 0
digestPointerPos2 = 772 - 8
digestChunkPos1 = digestPointerPos1 + 4
digestChunkPos2 = digestPointerPos2 + 4
digestChunkLength = 728
digestLength = 32
publicKeyPointerPos1 = 1532 - 8
publicKeyPointerPos2 = 768 - 8
publicKeyChunkPos1 = publicKeyPointerPos1 - 760
publicKeyChunkPos2 = publicKeyPointerPos2 - 760
publicKeyChunkLength = 632
)
func hsCalcDigestPos(p []byte, base int) int {
pos := 0
for i := 0; i < 4; i++ {
pos += int(p[base+i])
}
return (pos % 728) + base + 4
}
var (
clientKeyC1 = []byte("Genuine Adobe Flash Player 001")
serverKeyS1 = []byte("Genuine Adobe Flash Media Server 001")
)
func hsMakeDigest(key []byte, src []byte, gap int) []byte {
func hmacSha256(key []byte, buf []byte) []byte {
h := hmac.New(sha256.New, key)
if gap <= 0 {
h.Write(src)
} else {
h.Write(src[:gap])
h.Write(src[gap+32:])
}
h.Write(buf)
return h.Sum(nil)
}
func hsFindDigest(p []byte, key []byte, base int) int {
gap := hsCalcDigestPos(p, base)
digest := hsMakeDigest(key, p, gap)
if !bytes.Equal(p[gap:gap+32], digest) {
return -1
}
return gap
}
func hsParse1(p []byte, peerkey []byte, key []byte) (bool, []byte) {
var pos int
if pos = hsFindDigest(p, peerkey, 772); pos == -1 {
if pos = hsFindDigest(p, peerkey, 8); pos == -1 {
return false, nil
}
}
return true, hsMakeDigest(key, p[pos:pos+32], -1)
}
// C1S1 is a C1 or S1 packet.
type C1S1 struct {
Time uint32
Random []byte
Digest []byte
Time uint32
Version uint32
Data []byte
}
// Read reads a C1S1.
func (c *C1S1) Read(r io.Reader, isC1 bool, validateSignature bool) error {
buf := make([]byte, 1536)
func (c *C1S1) Read(r io.Reader) error {
buf := make([]byte, c1s1Size)
_, err := io.ReadFull(r, buf)
if err != nil {
return err
}
var peerKey []byte
var key []byte
if isC1 {
peerKey = hsClientPartialKey
key = hsServerFullKey
} else {
peerKey = hsServerPartialKey
key = hsClientFullKey
}
ok, digest := hsParse1(buf, peerKey, key)
if !ok {
if validateSignature {
return fmt.Errorf("unable to validate C1/S1 signature")
}
} else {
c.Digest = digest
}
c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3])
c.Random = buf[8:]
c.Version = uint32(buf[4])<<24 | uint32(buf[5])<<16 | uint32(buf[6])<<8 | uint32(buf[7])
c.Data = buf[8:]
return nil
}
func (c C1S1) readPointer(p int) int {
return int(c.Data[p]) + int(c.Data[p+1]) + int(c.Data[p+2]) + int(c.Data[p+3])
}
func (c C1S1) publicKeyPos1() int {
return publicKeyChunkPos1 + (c.readPointer(publicKeyPointerPos1) % publicKeyChunkLength)
}
func (c C1S1) publicKeyPos2() int {
return publicKeyChunkPos2 + (c.readPointer(publicKeyPointerPos2) % publicKeyChunkLength)
}
func (c C1S1) digestPos1() int {
return digestChunkPos1 + (c.readPointer(digestPointerPos1) % digestChunkLength)
}
func (c C1S1) digestPos2() int {
return digestChunkPos2 + (c.readPointer(digestPointerPos2) % digestChunkLength)
}
func (c C1S1) computeDigest(digestPos int, isS1 bool) []byte {
// hash entire message except digest
msg := make([]byte, c1s1Size-digestLength)
msg[0] = byte(c.Time >> 24)
msg[1] = byte(c.Time >> 16)
msg[2] = byte(c.Time >> 8)
msg[3] = byte(c.Time)
msg[4] = byte(c.Version >> 24)
msg[5] = byte(c.Version >> 16)
msg[6] = byte(c.Version >> 8)
msg[7] = byte(c.Version)
copy(msg[8:], c.Data[:digestPos])
copy(msg[8+digestPos:], c.Data[digestPos+digestLength:])
if isS1 {
return hmacSha256(serverKeyS1, msg)
}
return hmacSha256(clientKeyC1, msg)
}
func (c C1S1) validateDigest(isS1 bool) ([]byte, []byte, error) {
digestPos := c.digestPos1()
d1 := c.Data[digestPos : digestPos+digestLength]
d2 := c.computeDigest(digestPos, isS1)
if bytes.Equal(d1, d2) {
publicKeyPos := c.publicKeyPos1()
publicKey := c.Data[publicKeyPos : publicKeyPos+dhKeyLength]
return d1, publicKey, nil
}
digestPos = c.digestPos2()
d1 = c.Data[digestPos : digestPos+digestLength]
d2 = c.computeDigest(digestPos, isS1)
if bytes.Equal(d1, d2) {
publicKeyPos := c.publicKeyPos2()
publicKey := c.Data[publicKeyPos : publicKeyPos+dhKeyLength]
return d1, publicKey, nil
}
return nil, nil, fmt.Errorf("unable to validate C1/S1 digest")
}
func (c C1S1) validate(isS1 bool) ([]byte, []byte, error) {
digest, publicKey, err := c.validateDigest(isS1)
if err != nil {
return nil, nil, err
}
err = dhValidatePublicKey(publicKey)
if err != nil {
return nil, nil, err
}
return digest, publicKey, nil
}
func (c *C1S1) fillPlain() error {
c.Data = make([]byte, c1s1Size-8)
_, err := rand.Read(c.Data)
return err
}
func (c *C1S1) fill(isS1 bool, publicKey []byte) ([]byte, error) {
err := c.fillPlain()
if err != nil {
return nil, err
}
var r [1]byte
_, err = rand.Read(r[:])
if err != nil {
return nil, err
}
var digestPos int
var publicKeyPos int
if r[0] == 0 {
digestPos = c.digestPos1()
publicKeyPos = c.publicKeyPos1()
} else {
digestPos = c.digestPos2()
publicKeyPos = c.publicKeyPos2()
}
copy(c.Data[publicKeyPos:], publicKey)
digest := c.computeDigest(digestPos, isS1)
copy(c.Data[digestPos:], digest)
return digest, nil
}
// Write writes a C1S1.
func (c *C1S1) Write(w io.Writer, isC1 bool) error {
buf := make([]byte, 1536)
func (c C1S1) Write(w io.Writer) error {
buf := make([]byte, c1s1Size)
buf[0] = byte(c.Time >> 24)
buf[1] = byte(c.Time >> 16)
buf[2] = byte(c.Time >> 8)
buf[3] = byte(c.Time)
copy(buf[4:], []byte{0, 0, 0, 0})
if c.Random == nil {
_, err := rand.Read(buf[8:])
if err != nil {
return err
}
c.Random = buf[8:]
} else {
copy(buf[8:], c.Random)
}
// signature
gap := hsCalcDigestPos(buf, 8)
var peerKey []byte
var key []byte
if isC1 {
peerKey = hsServerFullKey
key = hsClientPartialKey
} else {
peerKey = hsClientFullKey
key = hsServerPartialKey
}
digest := hsMakeDigest(key, buf, gap)
copy(buf[gap:], digest)
pos := hsFindDigest(buf, key, 8)
c.Digest = hsMakeDigest(peerKey, buf[pos:pos+32], -1)
buf[4] = byte(c.Version >> 24)
buf[5] = byte(c.Version >> 16)
buf[6] = byte(c.Version >> 8)
buf[7] = byte(c.Version)
copy(buf[8:], c.Data)
_, err := w.Write(buf)
return err

View file

@ -7,132 +7,24 @@ import (
"github.com/stretchr/testify/require"
)
var c1enc = append(
[]byte{
0x19, 0xf1, 0x27, 0xa3, 0x00, 0x00, 0x00, 0x00,
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04,
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x2d, 0x0a,
0x37, 0x6f, 0x63, 0x2e, 0xa0, 0x21, 0xa0, 0xa4,
0x81, 0xb1, 0x50, 0x21, 0x5a, 0x6d, 0x81, 0xad,
0xf8, 0x44, 0x69, 0x13, 0xcc, 0x02, 0x8c, 0xd4,
0x64, 0x43, 0xc9, 0x9f, 0xcf, 0xc6, 0x03, 0x04,
},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)...,
)
var c1s1enc = bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4)
var s1enc = append(
[]byte{
0x19, 0xf1, 0x27, 0xa3, 0x00, 0x00, 0x00, 0x00,
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04,
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x95, 0xc1,
0xb6, 0x2c, 0x99, 0xbe, 0xa0, 0x0c, 0x07, 0x98,
0xb0, 0xf1, 0xbe, 0x54, 0x50, 0x63, 0xa1, 0x25,
0x1c, 0x9a, 0xcd, 0x12, 0x10, 0x98, 0x74, 0x8b,
0x18, 0x66, 0x8d, 0xef, 0xcf, 0x22, 0x03, 0x04,
},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)...,
)
var c1s1dec = C1S1{
Time: 16909060,
Version: 16909060,
Data: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4-2),
}
func TestC1S1Read(t *testing.T) {
for _, ca := range []struct {
isC1 bool
enc []byte
dec C1S1
}{
{
true,
c1enc,
C1S1{
Time: 435234723,
Random: append(
[]byte{
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04,
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x2d, 0x0a,
0x37, 0x6f, 0x63, 0x2e, 0xa0, 0x21, 0xa0, 0xa4,
0x81, 0xb1, 0x50, 0x21, 0x5a, 0x6d, 0x81, 0xad,
0xf8, 0x44, 0x69, 0x13, 0xcc, 0x02, 0x8c, 0xd4,
0x64, 0x43, 0xc9, 0x9f, 0xcf, 0xc6, 0x03, 0x04,
},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)...,
),
Digest: []byte{
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
},
},
},
{
false,
s1enc,
C1S1{
Time: 435234723,
Random: append(
[]byte{
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04,
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x95, 0xc1,
0xb6, 0x2c, 0x99, 0xbe, 0xa0, 0x0c, 0x07, 0x98,
0xb0, 0xf1, 0xbe, 0x54, 0x50, 0x63, 0xa1, 0x25,
0x1c, 0x9a, 0xcd, 0x12, 0x10, 0x98, 0x74, 0x8b,
0x18, 0x66, 0x8d, 0xef, 0xcf, 0x22, 0x03, 0x04,
},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)...,
),
Digest: []byte{
0x0e, 0x8f, 0x96, 0x19, 0x19, 0xe6, 0xb7, 0xf2,
0xac, 0x9a, 0xc8, 0x7e, 0x6e, 0xe9, 0xd4, 0x72,
0xed, 0x82, 0x87, 0xf1, 0xfa, 0xbd, 0x93, 0xb8,
0x7c, 0x48, 0x85, 0x03, 0x01, 0x7b, 0x54, 0xbe,
},
},
},
} {
var c1s1 C1S1
err := c1s1.Read((bytes.NewReader(ca.enc)), ca.isC1, true)
require.NoError(t, err)
require.Equal(t, ca.dec, c1s1)
}
var c1s1 C1S1
err := c1s1.Read((bytes.NewReader(c1s1enc)))
require.NoError(t, err)
require.Equal(t, c1s1dec, c1s1)
}
func TestC1S1Write(t *testing.T) {
for _, ca := range []struct {
isC1 bool
enc []byte
dec C1S1
}{
{
true,
c1enc,
C1S1{
Time: 435234723,
Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382),
Digest: []byte{
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
},
},
},
{
false,
s1enc,
C1S1{
Time: 435234723,
Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382),
Digest: []byte{
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
},
},
},
} {
var buf bytes.Buffer
err := ca.dec.Write(&buf, ca.isC1)
require.NoError(t, err)
require.Equal(t, ca.enc, buf.Bytes())
}
var buf bytes.Buffer
err := c1s1dec.Write(&buf)
require.NoError(t, err)
require.Equal(t, c1s1enc, buf.Bytes())
}

View file

@ -2,44 +2,103 @@ package handshake
import (
"bytes"
"crypto/rand"
"fmt"
"io"
)
const (
c2s2Size = c1s1Size
c2s2DigestPos = c2s2Size - 8 - digestLength
)
var (
randomCrud = []byte{
0xf0, 0xee, 0xc2, 0x4a, 0x80, 0x68, 0xbe, 0xe8,
0x2e, 0x00, 0xd0, 0xd1, 0x02, 0x9e, 0x7e, 0x57,
0x6e, 0xec, 0x5d, 0x2d, 0x29, 0x80, 0x6f, 0xab,
0x93, 0xb8, 0xe6, 0x36, 0xcf, 0xeb, 0x31, 0xae,
}
clientKeyC2 = append([]byte(nil), append(clientKeyC1, randomCrud...)...)
serverKeyS2 = append([]byte(nil), append(serverKeyS1, randomCrud...)...)
)
// C2S2 is a C2 or S2 packet.
type C2S2 struct {
Time uint32
Time2 uint32
Random []byte
Digest []byte
Time uint32
Time2 uint32
Data []byte
}
// Read reads a C2S2.
func (c *C2S2) Read(r io.Reader, validateSignature bool) error {
buf := make([]byte, 1536)
func (c *C2S2) Read(r io.Reader) error {
buf := make([]byte, c2s2Size)
_, err := io.ReadFull(r, buf)
if err != nil {
return err
}
if validateSignature {
gap := len(buf) - 32
digest := hsMakeDigest(c.Digest, buf, gap)
if !bytes.Equal(buf[gap:gap+32], digest) {
return fmt.Errorf("unable to validate C2/S2 signature")
}
}
c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3])
c.Time2 = uint32(buf[4])<<24 | uint32(buf[5])<<16 | uint32(buf[6])<<8 | uint32(buf[7])
c.Random = buf[8:]
c.Data = buf[8:]
return nil
}
func (c C2S2) computeDigest(isS2 bool, prevDigest []byte) []byte {
// hash entire message except digest
msg := make([]byte, c2s2Size-digestLength)
msg[0] = byte(c.Time >> 24)
msg[1] = byte(c.Time >> 16)
msg[2] = byte(c.Time >> 8)
msg[3] = byte(c.Time)
msg[4] = byte(c.Time2 >> 24)
msg[5] = byte(c.Time2 >> 16)
msg[6] = byte(c.Time2 >> 8)
msg[7] = byte(c.Time2)
copy(msg[8:], c.Data[:c2s2DigestPos])
var key []byte
if isS2 {
key = hmacSha256(serverKeyS2, prevDigest)
} else {
key = hmacSha256(clientKeyC2, prevDigest)
}
return hmacSha256(key, msg)
}
func (c C2S2) validate(isS2 bool, prevDigest []byte) error {
d1 := c.Data[c2s2DigestPos : c2s2DigestPos+digestLength]
d2 := c.computeDigest(isS2, prevDigest)
if !bytes.Equal(d1, d2) {
return fmt.Errorf("unable to validate C2/S2 digest")
}
return nil
}
func (c *C2S2) fillPlain() error {
c.Data = make([]byte, c2s2Size-8)
_, err := rand.Read(c.Data)
return err
}
func (c *C2S2) fill(isS2 bool, prevDigest []byte) error {
err := c.fillPlain()
if err != nil {
return err
}
digest := c.computeDigest(isS2, prevDigest)
copy(c.Data[c2s2DigestPos:], digest)
return nil
}
// Write writes a C2S2.
func (c C2S2) Write(w io.Writer) error {
buf := make([]byte, 1536)
buf := make([]byte, c2s2Size)
buf[0] = byte(c.Time >> 24)
buf[1] = byte(c.Time >> 16)
@ -49,15 +108,7 @@ func (c C2S2) Write(w io.Writer) error {
buf[5] = byte(c.Time2 >> 16)
buf[6] = byte(c.Time2 >> 8)
buf[7] = byte(c.Time2)
copy(buf[8:], c.Random)
// signature
if c.Digest != nil {
gap := len(buf) - 32
digest := hsMakeDigest(c.Digest, buf, gap)
copy(buf[gap:], digest)
}
copy(buf[8:], c.Data)
_, err := w.Write(buf)
return err

View file

@ -7,71 +7,22 @@ import (
"github.com/stretchr/testify/require"
)
var c2s2enc = bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4)
var c2s2dec = C2S2{
Time: 16909060,
Time2: 16909060,
Data: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4-2),
}
func TestC2S2Read(t *testing.T) {
c2s2dec := C2S2{
Time: 435234723,
Time2: 7893542,
Random: append(
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 372),
[]byte{
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04,
0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2,
0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14,
0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe,
0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec,
}...),
Digest: []byte{
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
},
}
c2s2enc := append(append(
[]byte{
0x19, 0xf1, 0x27, 0xa3, 0x00, 0x78, 0x72, 0x26,
},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 374)...,
), []byte{
0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2,
0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14,
0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe,
0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec,
}...)
var c2s2 C2S2
c2s2.Digest = c2s2dec.Digest
err := c2s2.Read((bytes.NewReader(c2s2enc)), true)
err := c2s2.Read((bytes.NewReader(c2s2enc)))
require.NoError(t, err)
require.Equal(t, c2s2dec, c2s2)
}
func TestC2S2Write(t *testing.T) {
c2s2dec := C2S2{
Time: 435234723,
Time2: 7893542,
Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382),
Digest: []byte{
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
},
}
c2s2enc := append(append(
[]byte{
0x19, 0xf1, 0x27, 0xa3, 0x00, 0x78, 0x72, 0x26,
},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 374)...,
), []byte{
0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2,
0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14,
0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe,
0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec,
}...)
var buf bytes.Buffer
err := c2s2dec.Write(&buf)
require.NoError(t, err)

View file

@ -0,0 +1,118 @@
package handshake
import (
"crypto/rand"
"fmt"
"math/big"
)
const (
dhKeyLength = 128
)
var (
p1024 = []byte{
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xc9, 0x0f, 0xda, 0xa2, 0x21, 0x68, 0xc2, 0x34,
0xc4, 0xc6, 0x62, 0x8b, 0x80, 0xdc, 0x1c, 0xd1,
0x29, 0x02, 0x4e, 0x08, 0x8a, 0x67, 0xcc, 0x74,
0x02, 0x0b, 0xbe, 0xa6, 0x3b, 0x13, 0x9b, 0x22,
0x51, 0x4a, 0x08, 0x79, 0x8e, 0x34, 0x04, 0xdd,
0xef, 0x95, 0x19, 0xb3, 0xcd, 0x3a, 0x43, 0x1b,
0x30, 0x2b, 0x0a, 0x6d, 0xf2, 0x5f, 0x14, 0x37,
0x4f, 0xe1, 0x35, 0x6d, 0x6d, 0x51, 0xc2, 0x45,
0xe4, 0x85, 0xb5, 0x76, 0x62, 0x5e, 0x7e, 0xc6,
0xf4, 0x4c, 0x42, 0xe9, 0xa6, 0x37, 0xed, 0x6b,
0x0b, 0xff, 0x5c, 0xb6, 0xf4, 0x06, 0xb7, 0xed,
0xee, 0x38, 0x6b, 0xfb, 0x5a, 0x89, 0x9f, 0xa5,
0xae, 0x9f, 0x24, 0x11, 0x7c, 0x4b, 0x1f, 0xe6,
0x49, 0x28, 0x66, 0x51, 0xec, 0xe6, 0x53, 0x81,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
}
q1024 = []byte{
0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xe4, 0x87, 0xed, 0x51, 0x10, 0xb4, 0x61, 0x1a,
0x62, 0x63, 0x31, 0x45, 0xc0, 0x6e, 0x0e, 0x68,
0x94, 0x81, 0x27, 0x04, 0x45, 0x33, 0xe6, 0x3a,
0x01, 0x05, 0xdf, 0x53, 0x1d, 0x89, 0xcd, 0x91,
0x28, 0xa5, 0x04, 0x3c, 0xc7, 0x1a, 0x02, 0x6e,
0xf7, 0xca, 0x8c, 0xd9, 0xe6, 0x9d, 0x21, 0x8d,
0x98, 0x15, 0x85, 0x36, 0xf9, 0x2f, 0x8a, 0x1b,
0xa7, 0xf0, 0x9a, 0xb6, 0xb6, 0xa8, 0xe1, 0x22,
0xf2, 0x42, 0xda, 0xbb, 0x31, 0x2f, 0x3f, 0x63,
0x7a, 0x26, 0x21, 0x74, 0xd3, 0x1b, 0xf6, 0xb5,
0x85, 0xff, 0xae, 0x5b, 0x7a, 0x03, 0x5b, 0xf6,
0xf7, 0x1c, 0x35, 0xfd, 0xad, 0x44, 0xcf, 0xd2,
0xd7, 0x4f, 0x92, 0x08, 0xbe, 0x25, 0x8f, 0xf3,
0x24, 0x94, 0x33, 0x28, 0xf6, 0x73, 0x29, 0xc0,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
}
)
// https://datatracker.ietf.org/doc/html/rfc2631#section-2.1.5
func dhValidatePublicKey(key []byte) error {
// 1. y >= 2 && y < p
var y big.Int
y.SetBytes(key)
var two big.Int
two.SetUint64(2)
r := y.Cmp(&two)
if r < 0 {
return fmt.Errorf("key is < 2")
}
var p big.Int
p.SetBytes(p1024)
r = y.Cmp(&p)
if r >= 0 {
return fmt.Errorf("key is >= p")
}
// 2. (y^q mod p) == 1
var q big.Int
q.SetBytes(q1024)
var z big.Int
z.Exp(&y, &q, &p)
var one big.Int
one.SetUint64(1)
r = z.Cmp(&one)
if r != 0 {
return fmt.Errorf("y^q mod p is != 1")
}
return nil
}
func dhGenerateKeyPair() ([]byte, []byte, error) {
priv := make([]byte, dhKeyLength)
_, err := rand.Read(priv)
if err != nil {
return nil, nil, err
}
// y = g ^ x mod p
var g big.Int
g.SetUint64(2)
var x big.Int
x.SetBytes(priv)
var p big.Int
p.SetBytes(p1024)
var y big.Int
y.Exp(&g, &x, &p)
pub := y.Bytes()
return priv, pub, nil
}
// https://datatracker.ietf.org/doc/html/rfc2631#section-2.1.1
func dhComputeSharedSecret(priv []byte, pub []byte) []byte {
// ZZ = (ya ^ xb) mod p
var y big.Int
y.SetBytes(pub)
var x big.Int
x.SetBytes(priv)
var p big.Int
p.SetBytes(p1024)
var z big.Int
z.Exp(&y, &x, &p)
return z.Bytes()
}

View file

@ -2,96 +2,297 @@
package handshake
import (
"bytes"
"fmt"
"io"
)
// DoClient performs a client-side handshake.
func DoClient(rw io.ReadWriter, validateSignature bool) error {
c0 := C0S0{}
const (
encryptedVersion = 3<<24 | 5<<16 | 1<<8 | 1
)
func doClientEncrypted(rw io.ReadWriter) ([]byte, []byte, error) {
var c0 C0S0
c0.Version = 6
err := c0.Write(rw)
if err != nil {
return nil, nil, err
}
localPrivateKey, localPublicKey, err := dhGenerateKeyPair()
if err != nil {
return nil, nil, err
}
var c1 C1S1
c1Digest, err := c1.fill(false, localPublicKey)
if err != nil {
return nil, nil, err
}
err = c1.Write(rw)
if err != nil {
return nil, nil, err
}
var s0 C0S0
err = s0.Read(rw)
if err != nil {
return nil, nil, err
}
if s0.Version != 6 {
return nil, nil, fmt.Errorf("server replied with unexpected version %d", s0.Version)
}
var s1 C1S1
err = s1.Read(rw)
if err != nil {
return nil, nil, err
}
s1Digest, remotePublicKey, err := s1.validate(true)
if err != nil {
return nil, nil, err
}
var s2 C2S2
err = s2.Read(rw)
if err != nil {
return nil, nil, err
}
err = s2.validate(true, c1Digest)
if err != nil {
return nil, nil, err
}
var c2 C2S2
err = c2.fill(false, s1Digest)
if err != nil {
return nil, nil, err
}
err = c2.Write(rw)
if err != nil {
return nil, nil, err
}
sharedSecret := dhComputeSharedSecret(localPrivateKey, remotePublicKey)
keyIn := hmacSha256(sharedSecret, localPublicKey)[:16]
keyOut := hmacSha256(sharedSecret, remotePublicKey)[:16]
return keyIn, keyOut, nil
}
func doClientPlain(rw io.ReadWriter, strict bool) error {
var c0 C0S0
c0.Version = 3
err := c0.Write(rw)
if err != nil {
return err
}
c1 := C1S1{}
err = c1.Write(rw, true)
var c1 C1S1
err = c1.fillPlain()
if err != nil {
return err
}
s0 := C0S0{}
err = c1.Write(rw)
if err != nil {
return err
}
var s0 C0S0
err = s0.Read(rw)
if err != nil {
return err
}
s1 := C1S1{}
err = s1.Read(rw, false, validateSignature)
if s0.Version != 3 {
return fmt.Errorf("server replied with unexpected version %d", s0.Version)
}
var s1 C1S1
err = s1.Read(rw)
if err != nil {
return err
}
s2 := C2S2{
Digest: c1.Digest,
}
err = s2.Read(rw, validateSignature)
var s2 C2S2
err = s2.Read(rw)
if err != nil {
return err
}
c2 := C2S2{
Time: s1.Time,
Random: s1.Random,
Digest: s1.Digest,
if strict && !bytes.Equal(s2.Data, c1.Data) {
return fmt.Errorf("data in S2 does not correspond")
}
err = c2.Write(rw)
var c2 C2S2
c2.Data = s1.Data
return c2.Write(rw)
}
// DoClient performs a client-side handshake.
func DoClient(rw io.ReadWriter, encrypted bool, strict bool) ([]byte, []byte, error) {
if encrypted {
return doClientEncrypted(rw)
}
return nil, nil, doClientPlain(rw, strict)
}
func doServerEncrypted(rw io.ReadWriter) ([]byte, []byte, error) {
var c1 C1S1
err := c1.Read(rw)
if err != nil {
return nil, nil, err
}
c1Digest, remotePublicKey, err := c1.validate(false)
if err != nil {
return nil, nil, err
}
localPrivateKey, localPublicKey, err := dhGenerateKeyPair()
if err != nil {
return nil, nil, err
}
var s0 C0S0
s0.Version = 6
err = s0.Write(rw)
if err != nil {
return nil, nil, err
}
var s1 C1S1
s1.Version = encryptedVersion
s1Digest, err := s1.fill(true, localPublicKey)
if err != nil {
return nil, nil, err
}
err = s1.Write(rw)
if err != nil {
return nil, nil, err
}
var s2 C2S2
s2.Time2 = encryptedVersion
err = s2.fill(true, c1Digest)
if err != nil {
return nil, nil, err
}
err = s2.Write(rw)
if err != nil {
return nil, nil, err
}
var c2 C2S2
err = c2.Read(rw)
if err != nil {
return nil, nil, err
}
err = c2.validate(false, s1Digest)
if err != nil {
return nil, nil, err
}
sharedSecret := dhComputeSharedSecret(localPrivateKey, remotePublicKey)
keyIn := hmacSha256(sharedSecret, localPublicKey)[:16]
keyOut := hmacSha256(sharedSecret, remotePublicKey)[:16]
return keyIn, keyOut, nil
}
func doServerPlain(rw io.ReadWriter, strict bool) error {
var c1 C1S1
err := c1.Read(rw)
if err != nil {
return err
}
var s0 C0S0
s0.Version = 3
err = s0.Write(rw)
if err != nil {
return err
}
var s1 C1S1
err = s1.fillPlain()
if err != nil {
return err
}
err = s1.Write(rw)
if err != nil {
return err
}
var s2 C2S2
s2.Data = c1.Data
err = s2.Write(rw)
if err != nil {
return err
}
var c2 C2S2
err = c2.Read(rw)
if err != nil {
return err
}
if strict && !bytes.Equal(c2.Data, s1.Data) {
return fmt.Errorf("data in C2 does not correspond")
}
return nil
}
// DoServer performs a server-side handshake.
func DoServer(rw io.ReadWriter, validateSignature bool) error {
err := C0S0{}.Read(rw)
func DoServer(rw io.ReadWriter, strict bool) ([]byte, []byte, error) {
var c0 C0S0
err := c0.Read(rw)
if err != nil {
return err
return nil, nil, err
}
c1 := C1S1{}
err = c1.Read(rw, true, validateSignature)
if err != nil {
return err
if c0.Version == 6 {
return doServerEncrypted(rw)
}
s0 := C0S0{}
err = s0.Write(rw)
if err != nil {
return err
}
s1 := C1S1{}
err = s1.Write(rw, false)
if err != nil {
return err
}
s2 := C2S2{
Time: c1.Time,
Random: c1.Random,
Digest: c1.Digest,
}
err = s2.Write(rw)
if err != nil {
return err
}
c2 := C2S2{Digest: s1.Digest}
err = c2.Read(rw, validateSignature)
if err != nil {
return err
}
return nil
return nil, nil, doServerPlain(rw, strict)
}

View file

@ -1,91 +1,50 @@
package handshake
import (
"crypto/rand"
"net"
"testing"
"github.com/stretchr/testify/require"
)
type testReadWriter struct {
ch chan []byte
}
func (rw *testReadWriter) Read(p []byte) (int, error) {
in := <-rw.ch
n := copy(p, in)
return n, nil
}
func (rw *testReadWriter) Write(p []byte) (int, error) {
rw.ch <- p
return len(p), nil
}
func TestHandshake(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:9122")
require.NoError(t, err)
defer ln.Close()
for _, ca := range []string{"plain", "encrypted"} {
t.Run(ca, func(t *testing.T) {
rw := &testReadWriter{ch: make(chan []byte)}
var serverInKey []byte
var serverOutKey []byte
done := make(chan struct{})
done := make(chan struct{})
go func() {
var err error
serverInKey, serverOutKey, err = DoServer(rw, true)
require.NoError(t, err)
close(done)
}()
go func() {
conn, err := ln.Accept()
require.NoError(t, err)
defer conn.Close()
clientInKey, clientOutKey, err := DoClient(rw, ca == "encrypted", true)
require.NoError(t, err)
<-done
err = DoServer(conn, true)
require.NoError(t, err)
close(done)
}()
conn, err := net.Dial("tcp", "127.0.0.1:9122")
require.NoError(t, err)
defer conn.Close()
err = DoClient(conn, true)
require.NoError(t, err)
<-done
}
// when C1 signature is invalid, S2 must be equal to C1.
func TestHandshakeFallback(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:9122")
require.NoError(t, err)
defer ln.Close()
done := make(chan struct{})
go func() {
conn, err := ln.Accept()
require.NoError(t, err)
defer conn.Close()
err = DoServer(conn, false)
require.NoError(t, err)
close(done)
}()
conn, err := net.Dial("tcp", "127.0.0.1:9122")
require.NoError(t, err)
defer conn.Close()
err = C0S0{}.Write(conn)
require.NoError(t, err)
c1 := make([]byte, 1536)
_, err = rand.Read(c1[8:])
require.NoError(t, err)
_, err = conn.Write(c1)
require.NoError(t, err)
err = C0S0{}.Read(conn)
require.NoError(t, err)
s1 := C1S1{}
err = s1.Read(conn, false, false)
require.NoError(t, err)
s2 := C2S2{}
err = s2.Read(conn, false)
require.NoError(t, err)
require.Equal(t, c1[8:], s2.Random)
err = C2S2{
Time: s1.Time,
Random: s1.Random,
Digest: s1.Digest,
}.Write(conn)
require.NoError(t, err)
<-done
if ca == "encrypted" {
require.NotNil(t, serverInKey)
require.Equal(t, serverInKey, clientOutKey)
require.Equal(t, serverOutKey, clientInKey)
}
})
}
}

View file

@ -2,6 +2,7 @@ package message
import (
"fmt"
"io"
"github.com/bluenviron/mediamtx/internal/rtmp/bytecounter"
"github.com/bluenviron/mediamtx/internal/rtmp/rawmessage"
@ -116,9 +117,13 @@ type Reader struct {
}
// NewReader allocates a Reader.
func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader {
func NewReader(
r io.Reader,
bcr *bytecounter.Reader,
onAckNeeded func(uint32) error,
) *Reader {
return &Reader{
r: rawmessage.NewReader(r, onAckNeeded),
r: rawmessage.NewReader(r, bcr, onAckNeeded),
}
}

View file

@ -268,7 +268,8 @@ var readWriterCases = []struct {
func TestReader(t *testing.T) {
for _, ca := range readWriterCases {
t.Run(ca.name, func(t *testing.T) {
r := NewReader(bytecounter.NewReader(bytes.NewReader(ca.enc)), nil)
bc := bytecounter.NewReader(bytes.NewReader(ca.enc))
r := NewReader(bc, bc, nil)
dec, err := r.Read()
require.NoError(t, err)
require.Equal(t, ca.dec, dec)

View file

@ -1,6 +1,8 @@
package message
import (
"io"
"github.com/bluenviron/mediamtx/internal/rtmp/bytecounter"
)
@ -11,10 +13,14 @@ type ReadWriter struct {
}
// NewReadWriter allocates a ReadWriter.
func NewReadWriter(bc *bytecounter.ReadWriter, checkAcknowledge bool) *ReadWriter {
w := NewWriter(bc.Writer, checkAcknowledge)
func NewReadWriter(
rw io.ReadWriter,
bcrw *bytecounter.ReadWriter,
checkAcknowledge bool,
) *ReadWriter {
w := NewWriter(rw, bcrw.Writer, checkAcknowledge)
r := NewReader(bc.Reader, func(count uint32) error {
r := NewReader(rw, bcrw.Reader, func(count uint32) error {
return w.Write(&Acknowledge{
Value: count,
})

View file

@ -27,19 +27,21 @@ func TestReadWriterAcknowledge(t *testing.T) {
var buf1 bytes.Buffer
var buf2 bytes.Buffer
rw1 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{
bc1 := bytecounter.NewReadWriter(&duplexRW{
Reader: &buf2,
Writer: &buf1,
}), true)
})
rw1 := NewReadWriter(bc1, bc1, true)
err := rw1.Write(&Acknowledge{
Value: 7863534,
})
require.NoError(t, err)
rw2 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{
bc2 := bytecounter.NewReadWriter(&duplexRW{
Reader: &buf1,
Writer: &buf2,
}), true)
})
rw2 := NewReadWriter(bc2, bc2, true)
_, err = rw2.Read()
require.NoError(t, err)
}
@ -48,19 +50,21 @@ func TestReadWriterPing(t *testing.T) {
var buf1 bytes.Buffer
var buf2 bytes.Buffer
rw1 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{
bc1 := bytecounter.NewReadWriter(&duplexRW{
Reader: &buf2,
Writer: &buf1,
}), true)
})
rw1 := NewReadWriter(bc1, bc1, true)
err := rw1.Write(&UserControlPingRequest{
ServerTime: 143424312,
})
require.NoError(t, err)
rw2 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{
bc2 := bytecounter.NewReadWriter(&duplexRW{
Reader: &buf1,
Writer: &buf2,
}), true)
})
rw2 := NewReadWriter(bc2, bc2, true)
_, err = rw2.Read()
require.NoError(t, err)

View file

@ -1,6 +1,8 @@
package message
import (
"io"
"github.com/bluenviron/mediamtx/internal/rtmp/bytecounter"
"github.com/bluenviron/mediamtx/internal/rtmp/rawmessage"
)
@ -11,9 +13,13 @@ type Writer struct {
}
// NewWriter allocates a Writer.
func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer {
func NewWriter(
w io.Writer,
bcw *bytecounter.Writer,
checkAcknowledge bool,
) *Writer {
return &Writer{
w: rawmessage.NewWriter(w, checkAcknowledge),
w: rawmessage.NewWriter(w, bcw, checkAcknowledge),
}
}

View file

@ -13,7 +13,8 @@ func TestWriter(t *testing.T) {
for _, ca := range readWriterCases {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
r := NewWriter(bytecounter.NewWriter(&buf), true)
bc := bytecounter.NewWriter(&buf)
r := NewWriter(bc, bc, true)
err := r.Write(ca.dec)
require.NoError(t, err)
require.Equal(t, ca.enc, buf.Bytes())

View file

@ -4,6 +4,7 @@ import (
"bufio"
"errors"
"fmt"
"io"
"time"
"github.com/bluenviron/mediamtx/internal/rtmp/bytecounter"
@ -30,7 +31,7 @@ func (rc *readerChunkStream) readChunk(c chunk.Chunk, chunkBodySize uint32) erro
// check if an ack is needed
if rc.mr.ackWindowSize != 0 {
count := uint32(rc.mr.r.Count())
count := uint32(rc.mr.bcr.Count())
diff := count - rc.mr.lastAckCount
if diff > (rc.mr.ackWindowSize) {
@ -208,7 +209,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
// Reader is a raw message reader.
type Reader struct {
r *bytecounter.Reader
bcr *bytecounter.Reader
onAckNeeded func(uint32) error
br *bufio.Reader
@ -224,9 +225,13 @@ type Reader struct {
}
// NewReader allocates a Reader.
func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader {
func NewReader(
r io.Reader,
bcr *bytecounter.Reader,
onAckNeeded func(uint32) error,
) *Reader {
return &Reader{
r: r,
bcr: bcr,
br: bufio.NewReader(r),
onAckNeeded: onAckNeeded,
chunkSize: 128,

View file

@ -198,7 +198,8 @@ func TestReader(t *testing.T) {
for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
r := NewReader(bytecounter.NewReader(&buf), func(count uint32) error {
br := bytecounter.NewReader(&buf)
r := NewReader(br, br, func(count uint32) error {
return nil
})
@ -224,7 +225,7 @@ func TestReaderAcknowledge(t *testing.T) {
var buf bytes.Buffer
bc := bytecounter.NewReader(&buf)
r := NewReader(bc, func(count uint32) error {
r := NewReader(bc, bc, func(count uint32) error {
close(onAckCalled)
return nil
})

View file

@ -3,6 +3,7 @@ package rawmessage
import (
"bufio"
"fmt"
"io"
"time"
"github.com/bluenviron/mediamtx/internal/rtmp/bytecounter"
@ -21,7 +22,7 @@ type writerChunkStream struct {
func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error {
// check if we received an acknowledge
if wc.mw.checkAcknowledge && wc.mw.ackWindowSize != 0 {
diff := uint32(wc.mw.w.Count()) - wc.mw.ackValue
diff := uint32(wc.mw.bcw.Count()) - wc.mw.ackValue
if diff > (wc.mw.ackWindowSize * 3 / 2) {
return fmt.Errorf("no acknowledge received within window")
@ -148,7 +149,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
// Writer is a raw message writer.
type Writer struct {
w *bytecounter.Writer
bcw *bytecounter.Writer
bw *bufio.Writer
checkAcknowledge bool
chunkSize uint32
@ -158,9 +159,13 @@ type Writer struct {
}
// NewWriter allocates a Writer.
func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer {
func NewWriter(
w io.Writer,
bcw *bytecounter.Writer,
checkAcknowledge bool,
) *Writer {
return &Writer{
w: w,
bcw: bcw,
bw: bufio.NewWriter(w),
checkAcknowledge: checkAcknowledge,
chunkSize: 128,

View file

@ -15,7 +15,8 @@ func TestWriter(t *testing.T) {
for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf), true)
bc := bytecounter.NewWriter(&buf)
w := NewWriter(bc, bc, true)
for _, msg := range ca.messages {
err := w.Write(msg)
@ -36,11 +37,11 @@ func TestWriterAcknowledge(t *testing.T) {
for _, ca := range []string{"standard", "overflow"} {
t.Run(ca, func(t *testing.T) {
var buf bytes.Buffer
bcw := bytecounter.NewWriter(&buf)
w := NewWriter(bcw, true)
bc := bytecounter.NewWriter(&buf)
w := NewWriter(bc, bc, true)
if ca == "overflow" {
bcw.SetCount(4294967096)
bc.SetCount(4294967096)
w.ackValue = 4294967096
}

View file

@ -0,0 +1,49 @@
package rtmp
import (
"crypto/rc4"
"io"
)
type rc4ReadWriter struct {
rw io.ReadWriter
in *rc4.Cipher
out *rc4.Cipher
}
func newRC4ReadWriter(rw io.ReadWriter, keyIn []byte, keyOut []byte) (*rc4ReadWriter, error) {
in, err := rc4.NewCipher(keyIn)
if err != nil {
return nil, err
}
out, err := rc4.NewCipher(keyOut)
if err != nil {
return nil, err
}
p := make([]byte, 1536)
in.XORKeyStream(p, p)
out.XORKeyStream(p, p)
return &rc4ReadWriter{
rw: rw,
in: in,
out: out,
}, nil
}
func (r *rc4ReadWriter) Read(p []byte) (int, error) {
n, err := r.rw.Read(p)
if n == 0 {
return 0, err
}
r.in.XORKeyStream(p[:n], p[:n])
return n, err
}
func (r *rc4ReadWriter) Write(p []byte) (int, error) {
r.out.XORKeyStream(p, p)
return r.rw.Write(p)
}

View file

@ -166,7 +166,7 @@ func TestReadTracks(t *testing.T) {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
bc := bytecounter.NewReadWriter(&buf)
mrw := message.NewReadWriter(bc, true)
mrw := message.NewReadWriter(bc, bc, true)
switch ca.name {
case "video+audio":

View file

@ -46,7 +46,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
bc := bytecounter.NewReadWriter(&buf)
mrw := message.NewReadWriter(bc, true)
mrw := message.NewReadWriter(bc, bc, true)
msg, err := mrw.Read()
require.NoError(t, err)