mirror of
https://github.com/bluenviron/mediamtx.git
synced 2025-12-28 05:51:59 -08:00
rtmp: support ingesting RTMPE streams (#2189)
This commit is contained in:
parent
57a436b0d5
commit
7e180ceea2
24 changed files with 806 additions and 497 deletions
|
|
@ -156,12 +156,12 @@ func NewClientConn(rw io.ReadWriter, u *url.URL, publish bool) (*Conn, error) {
|
|||
func (c *Conn) initializeClient(u *url.URL, publish bool) error {
|
||||
connectpath, actionpath := splitPath(u)
|
||||
|
||||
err := handshake.DoClient(c.bc, false)
|
||||
_, _, err := handshake.DoClient(c.bc, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mrw = message.NewReadWriter(c.bc, false)
|
||||
c.mrw = message.NewReadWriter(c.bc, c.bc, false)
|
||||
|
||||
err = c.mrw.Write(&message.SetWindowAckSize{
|
||||
Value: 2500000,
|
||||
|
|
@ -329,12 +329,23 @@ func NewServerConn(rw io.ReadWriter) (*Conn, *url.URL, bool, error) {
|
|||
}
|
||||
|
||||
func (c *Conn) initializeServer() (*url.URL, bool, error) {
|
||||
err := handshake.DoServer(c.bc, false)
|
||||
keyIn, keyOut, err := handshake.DoServer(c.bc, false)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
c.mrw = message.NewReadWriter(c.bc, false)
|
||||
var rw io.ReadWriter
|
||||
if keyIn != nil {
|
||||
var err error
|
||||
rw, err = newRC4ReadWriter(c.bc, keyIn, keyOut)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
} else {
|
||||
rw = c.bc
|
||||
}
|
||||
|
||||
c.mrw = message.NewReadWriter(rw, c.bc, false)
|
||||
|
||||
cmd, err := readCommand(c.mrw)
|
||||
if err != nil {
|
||||
|
|
@ -581,7 +592,7 @@ func newNoHandshakeConn(rw io.ReadWriter) *Conn {
|
|||
bc: bytecounter.NewReadWriter(rw),
|
||||
}
|
||||
|
||||
c.mrw = message.NewReadWriter(c.bc, false)
|
||||
c.mrw = message.NewReadWriter(c.bc, c.bc, false)
|
||||
|
||||
return c
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,10 +29,10 @@ func TestNewClientConn(t *testing.T) {
|
|||
defer conn.Close()
|
||||
bc := bytecounter.NewReadWriter(conn)
|
||||
|
||||
err = handshake.DoServer(bc, true)
|
||||
_, _, err = handshake.DoServer(bc, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
mrw := message.NewReadWriter(bc, true)
|
||||
mrw := message.NewReadWriter(bc, bc, true)
|
||||
|
||||
msg, err := mrw.Read()
|
||||
require.NoError(t, err)
|
||||
|
|
@ -289,10 +289,10 @@ func TestNewServerConn(t *testing.T) {
|
|||
defer conn.Close()
|
||||
bc := bytecounter.NewReadWriter(conn)
|
||||
|
||||
err = handshake.DoClient(bc, true)
|
||||
_, _, err = handshake.DoClient(bc, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
mrw := message.NewReadWriter(bc, true)
|
||||
mrw := message.NewReadWriter(bc, bc, true)
|
||||
|
||||
tcURL := "rtmp://127.0.0.1:9121/stream"
|
||||
if ca == "publish neko" {
|
||||
|
|
|
|||
|
|
@ -5,30 +5,30 @@ import (
|
|||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
rtmpVersion = 0x03
|
||||
)
|
||||
|
||||
// C0S0 is a C0 or S0 packet.
|
||||
type C0S0 struct{}
|
||||
type C0S0 struct {
|
||||
Version byte
|
||||
}
|
||||
|
||||
// Read reads a C0S0.
|
||||
func (C0S0) Read(r io.Reader) error {
|
||||
func (c *C0S0) Read(r io.Reader) error {
|
||||
buf := make([]byte, 1)
|
||||
_, err := io.ReadFull(r, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if buf[0] != rtmpVersion {
|
||||
return fmt.Errorf("invalid rtmp version (%d)", buf[0])
|
||||
c.Version = buf[0]
|
||||
|
||||
if c.Version != 3 && c.Version != 6 {
|
||||
return fmt.Errorf("invalid rtmp version (%d)", c.Version)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes a C0S0.
|
||||
func (C0S0) Write(w io.Writer) error {
|
||||
_, err := w.Write([]byte{rtmpVersion})
|
||||
func (c C0S0) Write(w io.Writer) error {
|
||||
_, err := w.Write([]byte{c.Version})
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,9 +7,11 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var c0s0enc = []byte{0x03}
|
||||
var c0s0enc = []byte{3}
|
||||
|
||||
var c0s0dec = C0S0{}
|
||||
var c0s0dec = C0S0{
|
||||
Version: 3,
|
||||
}
|
||||
|
||||
func TestC0S0Read(t *testing.T) {
|
||||
var c0s0 C0S0
|
||||
|
|
|
|||
|
|
@ -9,140 +9,180 @@ import (
|
|||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
hsClientFullKey = []byte{
|
||||
'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ',
|
||||
'F', 'l', 'a', 's', 'h', ' ', 'P', 'l', 'a', 'y', 'e', 'r', ' ',
|
||||
'0', '0', '1',
|
||||
0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1,
|
||||
0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB,
|
||||
0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE,
|
||||
}
|
||||
hsServerFullKey = []byte{
|
||||
'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ',
|
||||
'F', 'l', 'a', 's', 'h', ' ', 'M', 'e', 'd', 'i', 'a', ' ',
|
||||
'S', 'e', 'r', 'v', 'e', 'r', ' ',
|
||||
'0', '0', '1',
|
||||
0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1,
|
||||
0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB,
|
||||
0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE,
|
||||
}
|
||||
hsClientPartialKey = hsClientFullKey[:30]
|
||||
hsServerPartialKey = hsServerFullKey[:36]
|
||||
const (
|
||||
c1s1Size = 1536
|
||||
digestPointerPos1 = 0
|
||||
digestPointerPos2 = 772 - 8
|
||||
digestChunkPos1 = digestPointerPos1 + 4
|
||||
digestChunkPos2 = digestPointerPos2 + 4
|
||||
digestChunkLength = 728
|
||||
digestLength = 32
|
||||
publicKeyPointerPos1 = 1532 - 8
|
||||
publicKeyPointerPos2 = 768 - 8
|
||||
publicKeyChunkPos1 = publicKeyPointerPos1 - 760
|
||||
publicKeyChunkPos2 = publicKeyPointerPos2 - 760
|
||||
publicKeyChunkLength = 632
|
||||
)
|
||||
|
||||
func hsCalcDigestPos(p []byte, base int) int {
|
||||
pos := 0
|
||||
for i := 0; i < 4; i++ {
|
||||
pos += int(p[base+i])
|
||||
}
|
||||
return (pos % 728) + base + 4
|
||||
}
|
||||
var (
|
||||
clientKeyC1 = []byte("Genuine Adobe Flash Player 001")
|
||||
serverKeyS1 = []byte("Genuine Adobe Flash Media Server 001")
|
||||
)
|
||||
|
||||
func hsMakeDigest(key []byte, src []byte, gap int) []byte {
|
||||
func hmacSha256(key []byte, buf []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
if gap <= 0 {
|
||||
h.Write(src)
|
||||
} else {
|
||||
h.Write(src[:gap])
|
||||
h.Write(src[gap+32:])
|
||||
}
|
||||
h.Write(buf)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func hsFindDigest(p []byte, key []byte, base int) int {
|
||||
gap := hsCalcDigestPos(p, base)
|
||||
digest := hsMakeDigest(key, p, gap)
|
||||
if !bytes.Equal(p[gap:gap+32], digest) {
|
||||
return -1
|
||||
}
|
||||
return gap
|
||||
}
|
||||
|
||||
func hsParse1(p []byte, peerkey []byte, key []byte) (bool, []byte) {
|
||||
var pos int
|
||||
if pos = hsFindDigest(p, peerkey, 772); pos == -1 {
|
||||
if pos = hsFindDigest(p, peerkey, 8); pos == -1 {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, hsMakeDigest(key, p[pos:pos+32], -1)
|
||||
}
|
||||
|
||||
// C1S1 is a C1 or S1 packet.
|
||||
type C1S1 struct {
|
||||
Time uint32
|
||||
Random []byte
|
||||
Digest []byte
|
||||
Time uint32
|
||||
Version uint32
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Read reads a C1S1.
|
||||
func (c *C1S1) Read(r io.Reader, isC1 bool, validateSignature bool) error {
|
||||
buf := make([]byte, 1536)
|
||||
func (c *C1S1) Read(r io.Reader) error {
|
||||
buf := make([]byte, c1s1Size)
|
||||
_, err := io.ReadFull(r, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var peerKey []byte
|
||||
var key []byte
|
||||
if isC1 {
|
||||
peerKey = hsClientPartialKey
|
||||
key = hsServerFullKey
|
||||
} else {
|
||||
peerKey = hsServerPartialKey
|
||||
key = hsClientFullKey
|
||||
}
|
||||
ok, digest := hsParse1(buf, peerKey, key)
|
||||
if !ok {
|
||||
if validateSignature {
|
||||
return fmt.Errorf("unable to validate C1/S1 signature")
|
||||
}
|
||||
} else {
|
||||
c.Digest = digest
|
||||
}
|
||||
|
||||
c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3])
|
||||
c.Random = buf[8:]
|
||||
c.Version = uint32(buf[4])<<24 | uint32(buf[5])<<16 | uint32(buf[6])<<8 | uint32(buf[7])
|
||||
c.Data = buf[8:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c C1S1) readPointer(p int) int {
|
||||
return int(c.Data[p]) + int(c.Data[p+1]) + int(c.Data[p+2]) + int(c.Data[p+3])
|
||||
}
|
||||
|
||||
func (c C1S1) publicKeyPos1() int {
|
||||
return publicKeyChunkPos1 + (c.readPointer(publicKeyPointerPos1) % publicKeyChunkLength)
|
||||
}
|
||||
|
||||
func (c C1S1) publicKeyPos2() int {
|
||||
return publicKeyChunkPos2 + (c.readPointer(publicKeyPointerPos2) % publicKeyChunkLength)
|
||||
}
|
||||
|
||||
func (c C1S1) digestPos1() int {
|
||||
return digestChunkPos1 + (c.readPointer(digestPointerPos1) % digestChunkLength)
|
||||
}
|
||||
|
||||
func (c C1S1) digestPos2() int {
|
||||
return digestChunkPos2 + (c.readPointer(digestPointerPos2) % digestChunkLength)
|
||||
}
|
||||
|
||||
func (c C1S1) computeDigest(digestPos int, isS1 bool) []byte {
|
||||
// hash entire message except digest
|
||||
msg := make([]byte, c1s1Size-digestLength)
|
||||
msg[0] = byte(c.Time >> 24)
|
||||
msg[1] = byte(c.Time >> 16)
|
||||
msg[2] = byte(c.Time >> 8)
|
||||
msg[3] = byte(c.Time)
|
||||
msg[4] = byte(c.Version >> 24)
|
||||
msg[5] = byte(c.Version >> 16)
|
||||
msg[6] = byte(c.Version >> 8)
|
||||
msg[7] = byte(c.Version)
|
||||
copy(msg[8:], c.Data[:digestPos])
|
||||
copy(msg[8+digestPos:], c.Data[digestPos+digestLength:])
|
||||
|
||||
if isS1 {
|
||||
return hmacSha256(serverKeyS1, msg)
|
||||
}
|
||||
return hmacSha256(clientKeyC1, msg)
|
||||
}
|
||||
|
||||
func (c C1S1) validateDigest(isS1 bool) ([]byte, []byte, error) {
|
||||
digestPos := c.digestPos1()
|
||||
d1 := c.Data[digestPos : digestPos+digestLength]
|
||||
d2 := c.computeDigest(digestPos, isS1)
|
||||
|
||||
if bytes.Equal(d1, d2) {
|
||||
publicKeyPos := c.publicKeyPos1()
|
||||
publicKey := c.Data[publicKeyPos : publicKeyPos+dhKeyLength]
|
||||
return d1, publicKey, nil
|
||||
}
|
||||
|
||||
digestPos = c.digestPos2()
|
||||
d1 = c.Data[digestPos : digestPos+digestLength]
|
||||
d2 = c.computeDigest(digestPos, isS1)
|
||||
|
||||
if bytes.Equal(d1, d2) {
|
||||
publicKeyPos := c.publicKeyPos2()
|
||||
publicKey := c.Data[publicKeyPos : publicKeyPos+dhKeyLength]
|
||||
return d1, publicKey, nil
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("unable to validate C1/S1 digest")
|
||||
}
|
||||
|
||||
func (c C1S1) validate(isS1 bool) ([]byte, []byte, error) {
|
||||
digest, publicKey, err := c.validateDigest(isS1)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = dhValidatePublicKey(publicKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return digest, publicKey, nil
|
||||
}
|
||||
|
||||
func (c *C1S1) fillPlain() error {
|
||||
c.Data = make([]byte, c1s1Size-8)
|
||||
_, err := rand.Read(c.Data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *C1S1) fill(isS1 bool, publicKey []byte) ([]byte, error) {
|
||||
err := c.fillPlain()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var r [1]byte
|
||||
_, err = rand.Read(r[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var digestPos int
|
||||
var publicKeyPos int
|
||||
|
||||
if r[0] == 0 {
|
||||
digestPos = c.digestPos1()
|
||||
publicKeyPos = c.publicKeyPos1()
|
||||
} else {
|
||||
digestPos = c.digestPos2()
|
||||
publicKeyPos = c.publicKeyPos2()
|
||||
}
|
||||
|
||||
copy(c.Data[publicKeyPos:], publicKey)
|
||||
digest := c.computeDigest(digestPos, isS1)
|
||||
copy(c.Data[digestPos:], digest)
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
// Write writes a C1S1.
|
||||
func (c *C1S1) Write(w io.Writer, isC1 bool) error {
|
||||
buf := make([]byte, 1536)
|
||||
func (c C1S1) Write(w io.Writer) error {
|
||||
buf := make([]byte, c1s1Size)
|
||||
|
||||
buf[0] = byte(c.Time >> 24)
|
||||
buf[1] = byte(c.Time >> 16)
|
||||
buf[2] = byte(c.Time >> 8)
|
||||
buf[3] = byte(c.Time)
|
||||
copy(buf[4:], []byte{0, 0, 0, 0})
|
||||
|
||||
if c.Random == nil {
|
||||
_, err := rand.Read(buf[8:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Random = buf[8:]
|
||||
} else {
|
||||
copy(buf[8:], c.Random)
|
||||
}
|
||||
|
||||
// signature
|
||||
gap := hsCalcDigestPos(buf, 8)
|
||||
var peerKey []byte
|
||||
var key []byte
|
||||
if isC1 {
|
||||
peerKey = hsServerFullKey
|
||||
key = hsClientPartialKey
|
||||
} else {
|
||||
peerKey = hsClientFullKey
|
||||
key = hsServerPartialKey
|
||||
}
|
||||
digest := hsMakeDigest(key, buf, gap)
|
||||
copy(buf[gap:], digest)
|
||||
pos := hsFindDigest(buf, key, 8)
|
||||
c.Digest = hsMakeDigest(peerKey, buf[pos:pos+32], -1)
|
||||
buf[4] = byte(c.Version >> 24)
|
||||
buf[5] = byte(c.Version >> 16)
|
||||
buf[6] = byte(c.Version >> 8)
|
||||
buf[7] = byte(c.Version)
|
||||
copy(buf[8:], c.Data)
|
||||
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -7,132 +7,24 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var c1enc = append(
|
||||
[]byte{
|
||||
0x19, 0xf1, 0x27, 0xa3, 0x00, 0x00, 0x00, 0x00,
|
||||
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04,
|
||||
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x2d, 0x0a,
|
||||
0x37, 0x6f, 0x63, 0x2e, 0xa0, 0x21, 0xa0, 0xa4,
|
||||
0x81, 0xb1, 0x50, 0x21, 0x5a, 0x6d, 0x81, 0xad,
|
||||
0xf8, 0x44, 0x69, 0x13, 0xcc, 0x02, 0x8c, 0xd4,
|
||||
0x64, 0x43, 0xc9, 0x9f, 0xcf, 0xc6, 0x03, 0x04,
|
||||
},
|
||||
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)...,
|
||||
)
|
||||
var c1s1enc = bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4)
|
||||
|
||||
var s1enc = append(
|
||||
[]byte{
|
||||
0x19, 0xf1, 0x27, 0xa3, 0x00, 0x00, 0x00, 0x00,
|
||||
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04,
|
||||
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x95, 0xc1,
|
||||
0xb6, 0x2c, 0x99, 0xbe, 0xa0, 0x0c, 0x07, 0x98,
|
||||
0xb0, 0xf1, 0xbe, 0x54, 0x50, 0x63, 0xa1, 0x25,
|
||||
0x1c, 0x9a, 0xcd, 0x12, 0x10, 0x98, 0x74, 0x8b,
|
||||
0x18, 0x66, 0x8d, 0xef, 0xcf, 0x22, 0x03, 0x04,
|
||||
},
|
||||
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)...,
|
||||
)
|
||||
var c1s1dec = C1S1{
|
||||
Time: 16909060,
|
||||
Version: 16909060,
|
||||
Data: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4-2),
|
||||
}
|
||||
|
||||
func TestC1S1Read(t *testing.T) {
|
||||
for _, ca := range []struct {
|
||||
isC1 bool
|
||||
enc []byte
|
||||
dec C1S1
|
||||
}{
|
||||
{
|
||||
true,
|
||||
c1enc,
|
||||
C1S1{
|
||||
Time: 435234723,
|
||||
Random: append(
|
||||
[]byte{
|
||||
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04,
|
||||
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x2d, 0x0a,
|
||||
0x37, 0x6f, 0x63, 0x2e, 0xa0, 0x21, 0xa0, 0xa4,
|
||||
0x81, 0xb1, 0x50, 0x21, 0x5a, 0x6d, 0x81, 0xad,
|
||||
0xf8, 0x44, 0x69, 0x13, 0xcc, 0x02, 0x8c, 0xd4,
|
||||
0x64, 0x43, 0xc9, 0x9f, 0xcf, 0xc6, 0x03, 0x04,
|
||||
},
|
||||
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)...,
|
||||
),
|
||||
Digest: []byte{
|
||||
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
|
||||
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
|
||||
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
|
||||
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
false,
|
||||
s1enc,
|
||||
C1S1{
|
||||
Time: 435234723,
|
||||
Random: append(
|
||||
[]byte{
|
||||
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04,
|
||||
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x95, 0xc1,
|
||||
0xb6, 0x2c, 0x99, 0xbe, 0xa0, 0x0c, 0x07, 0x98,
|
||||
0xb0, 0xf1, 0xbe, 0x54, 0x50, 0x63, 0xa1, 0x25,
|
||||
0x1c, 0x9a, 0xcd, 0x12, 0x10, 0x98, 0x74, 0x8b,
|
||||
0x18, 0x66, 0x8d, 0xef, 0xcf, 0x22, 0x03, 0x04,
|
||||
},
|
||||
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)...,
|
||||
),
|
||||
Digest: []byte{
|
||||
0x0e, 0x8f, 0x96, 0x19, 0x19, 0xe6, 0xb7, 0xf2,
|
||||
0xac, 0x9a, 0xc8, 0x7e, 0x6e, 0xe9, 0xd4, 0x72,
|
||||
0xed, 0x82, 0x87, 0xf1, 0xfa, 0xbd, 0x93, 0xb8,
|
||||
0x7c, 0x48, 0x85, 0x03, 0x01, 0x7b, 0x54, 0xbe,
|
||||
},
|
||||
},
|
||||
},
|
||||
} {
|
||||
var c1s1 C1S1
|
||||
err := c1s1.Read((bytes.NewReader(ca.enc)), ca.isC1, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ca.dec, c1s1)
|
||||
}
|
||||
var c1s1 C1S1
|
||||
err := c1s1.Read((bytes.NewReader(c1s1enc)))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c1s1dec, c1s1)
|
||||
}
|
||||
|
||||
func TestC1S1Write(t *testing.T) {
|
||||
for _, ca := range []struct {
|
||||
isC1 bool
|
||||
enc []byte
|
||||
dec C1S1
|
||||
}{
|
||||
{
|
||||
true,
|
||||
c1enc,
|
||||
C1S1{
|
||||
Time: 435234723,
|
||||
Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382),
|
||||
Digest: []byte{
|
||||
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
|
||||
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
|
||||
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
|
||||
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
false,
|
||||
s1enc,
|
||||
C1S1{
|
||||
Time: 435234723,
|
||||
Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382),
|
||||
Digest: []byte{
|
||||
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
|
||||
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
|
||||
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
|
||||
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
|
||||
},
|
||||
},
|
||||
},
|
||||
} {
|
||||
var buf bytes.Buffer
|
||||
err := ca.dec.Write(&buf, ca.isC1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ca.enc, buf.Bytes())
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
err := c1s1dec.Write(&buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c1s1enc, buf.Bytes())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,44 +2,103 @@ package handshake
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
c2s2Size = c1s1Size
|
||||
c2s2DigestPos = c2s2Size - 8 - digestLength
|
||||
)
|
||||
|
||||
var (
|
||||
randomCrud = []byte{
|
||||
0xf0, 0xee, 0xc2, 0x4a, 0x80, 0x68, 0xbe, 0xe8,
|
||||
0x2e, 0x00, 0xd0, 0xd1, 0x02, 0x9e, 0x7e, 0x57,
|
||||
0x6e, 0xec, 0x5d, 0x2d, 0x29, 0x80, 0x6f, 0xab,
|
||||
0x93, 0xb8, 0xe6, 0x36, 0xcf, 0xeb, 0x31, 0xae,
|
||||
}
|
||||
clientKeyC2 = append([]byte(nil), append(clientKeyC1, randomCrud...)...)
|
||||
serverKeyS2 = append([]byte(nil), append(serverKeyS1, randomCrud...)...)
|
||||
)
|
||||
|
||||
// C2S2 is a C2 or S2 packet.
|
||||
type C2S2 struct {
|
||||
Time uint32
|
||||
Time2 uint32
|
||||
Random []byte
|
||||
Digest []byte
|
||||
Time uint32
|
||||
Time2 uint32
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Read reads a C2S2.
|
||||
func (c *C2S2) Read(r io.Reader, validateSignature bool) error {
|
||||
buf := make([]byte, 1536)
|
||||
func (c *C2S2) Read(r io.Reader) error {
|
||||
buf := make([]byte, c2s2Size)
|
||||
_, err := io.ReadFull(r, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if validateSignature {
|
||||
gap := len(buf) - 32
|
||||
digest := hsMakeDigest(c.Digest, buf, gap)
|
||||
if !bytes.Equal(buf[gap:gap+32], digest) {
|
||||
return fmt.Errorf("unable to validate C2/S2 signature")
|
||||
}
|
||||
}
|
||||
|
||||
c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3])
|
||||
c.Time2 = uint32(buf[4])<<24 | uint32(buf[5])<<16 | uint32(buf[6])<<8 | uint32(buf[7])
|
||||
c.Random = buf[8:]
|
||||
c.Data = buf[8:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c C2S2) computeDigest(isS2 bool, prevDigest []byte) []byte {
|
||||
// hash entire message except digest
|
||||
msg := make([]byte, c2s2Size-digestLength)
|
||||
msg[0] = byte(c.Time >> 24)
|
||||
msg[1] = byte(c.Time >> 16)
|
||||
msg[2] = byte(c.Time >> 8)
|
||||
msg[3] = byte(c.Time)
|
||||
msg[4] = byte(c.Time2 >> 24)
|
||||
msg[5] = byte(c.Time2 >> 16)
|
||||
msg[6] = byte(c.Time2 >> 8)
|
||||
msg[7] = byte(c.Time2)
|
||||
copy(msg[8:], c.Data[:c2s2DigestPos])
|
||||
|
||||
var key []byte
|
||||
if isS2 {
|
||||
key = hmacSha256(serverKeyS2, prevDigest)
|
||||
} else {
|
||||
key = hmacSha256(clientKeyC2, prevDigest)
|
||||
}
|
||||
|
||||
return hmacSha256(key, msg)
|
||||
}
|
||||
|
||||
func (c C2S2) validate(isS2 bool, prevDigest []byte) error {
|
||||
d1 := c.Data[c2s2DigestPos : c2s2DigestPos+digestLength]
|
||||
d2 := c.computeDigest(isS2, prevDigest)
|
||||
|
||||
if !bytes.Equal(d1, d2) {
|
||||
return fmt.Errorf("unable to validate C2/S2 digest")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *C2S2) fillPlain() error {
|
||||
c.Data = make([]byte, c2s2Size-8)
|
||||
_, err := rand.Read(c.Data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *C2S2) fill(isS2 bool, prevDigest []byte) error {
|
||||
err := c.fillPlain()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
digest := c.computeDigest(isS2, prevDigest)
|
||||
copy(c.Data[c2s2DigestPos:], digest)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes a C2S2.
|
||||
func (c C2S2) Write(w io.Writer) error {
|
||||
buf := make([]byte, 1536)
|
||||
buf := make([]byte, c2s2Size)
|
||||
|
||||
buf[0] = byte(c.Time >> 24)
|
||||
buf[1] = byte(c.Time >> 16)
|
||||
|
|
@ -49,15 +108,7 @@ func (c C2S2) Write(w io.Writer) error {
|
|||
buf[5] = byte(c.Time2 >> 16)
|
||||
buf[6] = byte(c.Time2 >> 8)
|
||||
buf[7] = byte(c.Time2)
|
||||
|
||||
copy(buf[8:], c.Random)
|
||||
|
||||
// signature
|
||||
if c.Digest != nil {
|
||||
gap := len(buf) - 32
|
||||
digest := hsMakeDigest(c.Digest, buf, gap)
|
||||
copy(buf[gap:], digest)
|
||||
}
|
||||
copy(buf[8:], c.Data)
|
||||
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -7,71 +7,22 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var c2s2enc = bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4)
|
||||
|
||||
var c2s2dec = C2S2{
|
||||
Time: 16909060,
|
||||
Time2: 16909060,
|
||||
Data: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1536/4-2),
|
||||
}
|
||||
|
||||
func TestC2S2Read(t *testing.T) {
|
||||
c2s2dec := C2S2{
|
||||
Time: 435234723,
|
||||
Time2: 7893542,
|
||||
Random: append(
|
||||
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 372),
|
||||
[]byte{
|
||||
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04,
|
||||
0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2,
|
||||
0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14,
|
||||
0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe,
|
||||
0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec,
|
||||
}...),
|
||||
Digest: []byte{
|
||||
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
|
||||
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
|
||||
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
|
||||
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
|
||||
},
|
||||
}
|
||||
|
||||
c2s2enc := append(append(
|
||||
[]byte{
|
||||
0x19, 0xf1, 0x27, 0xa3, 0x00, 0x78, 0x72, 0x26,
|
||||
},
|
||||
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 374)...,
|
||||
), []byte{
|
||||
0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2,
|
||||
0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14,
|
||||
0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe,
|
||||
0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec,
|
||||
}...)
|
||||
|
||||
var c2s2 C2S2
|
||||
c2s2.Digest = c2s2dec.Digest
|
||||
err := c2s2.Read((bytes.NewReader(c2s2enc)), true)
|
||||
err := c2s2.Read((bytes.NewReader(c2s2enc)))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c2s2dec, c2s2)
|
||||
}
|
||||
|
||||
func TestC2S2Write(t *testing.T) {
|
||||
c2s2dec := C2S2{
|
||||
Time: 435234723,
|
||||
Time2: 7893542,
|
||||
Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382),
|
||||
Digest: []byte{
|
||||
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
|
||||
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
|
||||
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
|
||||
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
|
||||
},
|
||||
}
|
||||
|
||||
c2s2enc := append(append(
|
||||
[]byte{
|
||||
0x19, 0xf1, 0x27, 0xa3, 0x00, 0x78, 0x72, 0x26,
|
||||
},
|
||||
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 374)...,
|
||||
), []byte{
|
||||
0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2,
|
||||
0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14,
|
||||
0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe,
|
||||
0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec,
|
||||
}...)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := c2s2dec.Write(&buf)
|
||||
require.NoError(t, err)
|
||||
|
|
|
|||
118
internal/rtmp/handshake/dh.go
Normal file
118
internal/rtmp/handshake/dh.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
const (
|
||||
dhKeyLength = 128
|
||||
)
|
||||
|
||||
var (
|
||||
p1024 = []byte{
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xc9, 0x0f, 0xda, 0xa2, 0x21, 0x68, 0xc2, 0x34,
|
||||
0xc4, 0xc6, 0x62, 0x8b, 0x80, 0xdc, 0x1c, 0xd1,
|
||||
0x29, 0x02, 0x4e, 0x08, 0x8a, 0x67, 0xcc, 0x74,
|
||||
0x02, 0x0b, 0xbe, 0xa6, 0x3b, 0x13, 0x9b, 0x22,
|
||||
0x51, 0x4a, 0x08, 0x79, 0x8e, 0x34, 0x04, 0xdd,
|
||||
0xef, 0x95, 0x19, 0xb3, 0xcd, 0x3a, 0x43, 0x1b,
|
||||
0x30, 0x2b, 0x0a, 0x6d, 0xf2, 0x5f, 0x14, 0x37,
|
||||
0x4f, 0xe1, 0x35, 0x6d, 0x6d, 0x51, 0xc2, 0x45,
|
||||
0xe4, 0x85, 0xb5, 0x76, 0x62, 0x5e, 0x7e, 0xc6,
|
||||
0xf4, 0x4c, 0x42, 0xe9, 0xa6, 0x37, 0xed, 0x6b,
|
||||
0x0b, 0xff, 0x5c, 0xb6, 0xf4, 0x06, 0xb7, 0xed,
|
||||
0xee, 0x38, 0x6b, 0xfb, 0x5a, 0x89, 0x9f, 0xa5,
|
||||
0xae, 0x9f, 0x24, 0x11, 0x7c, 0x4b, 0x1f, 0xe6,
|
||||
0x49, 0x28, 0x66, 0x51, 0xec, 0xe6, 0x53, 0x81,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
}
|
||||
q1024 = []byte{
|
||||
0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xe4, 0x87, 0xed, 0x51, 0x10, 0xb4, 0x61, 0x1a,
|
||||
0x62, 0x63, 0x31, 0x45, 0xc0, 0x6e, 0x0e, 0x68,
|
||||
0x94, 0x81, 0x27, 0x04, 0x45, 0x33, 0xe6, 0x3a,
|
||||
0x01, 0x05, 0xdf, 0x53, 0x1d, 0x89, 0xcd, 0x91,
|
||||
0x28, 0xa5, 0x04, 0x3c, 0xc7, 0x1a, 0x02, 0x6e,
|
||||
0xf7, 0xca, 0x8c, 0xd9, 0xe6, 0x9d, 0x21, 0x8d,
|
||||
0x98, 0x15, 0x85, 0x36, 0xf9, 0x2f, 0x8a, 0x1b,
|
||||
0xa7, 0xf0, 0x9a, 0xb6, 0xb6, 0xa8, 0xe1, 0x22,
|
||||
0xf2, 0x42, 0xda, 0xbb, 0x31, 0x2f, 0x3f, 0x63,
|
||||
0x7a, 0x26, 0x21, 0x74, 0xd3, 0x1b, 0xf6, 0xb5,
|
||||
0x85, 0xff, 0xae, 0x5b, 0x7a, 0x03, 0x5b, 0xf6,
|
||||
0xf7, 0x1c, 0x35, 0xfd, 0xad, 0x44, 0xcf, 0xd2,
|
||||
0xd7, 0x4f, 0x92, 0x08, 0xbe, 0x25, 0x8f, 0xf3,
|
||||
0x24, 0x94, 0x33, 0x28, 0xf6, 0x73, 0x29, 0xc0,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
}
|
||||
)
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/rfc2631#section-2.1.5
|
||||
func dhValidatePublicKey(key []byte) error {
|
||||
// 1. y >= 2 && y < p
|
||||
var y big.Int
|
||||
y.SetBytes(key)
|
||||
var two big.Int
|
||||
two.SetUint64(2)
|
||||
r := y.Cmp(&two)
|
||||
if r < 0 {
|
||||
return fmt.Errorf("key is < 2")
|
||||
}
|
||||
var p big.Int
|
||||
p.SetBytes(p1024)
|
||||
r = y.Cmp(&p)
|
||||
if r >= 0 {
|
||||
return fmt.Errorf("key is >= p")
|
||||
}
|
||||
|
||||
// 2. (y^q mod p) == 1
|
||||
var q big.Int
|
||||
q.SetBytes(q1024)
|
||||
var z big.Int
|
||||
z.Exp(&y, &q, &p)
|
||||
var one big.Int
|
||||
one.SetUint64(1)
|
||||
r = z.Cmp(&one)
|
||||
if r != 0 {
|
||||
return fmt.Errorf("y^q mod p is != 1")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func dhGenerateKeyPair() ([]byte, []byte, error) {
|
||||
priv := make([]byte, dhKeyLength)
|
||||
_, err := rand.Read(priv)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// y = g ^ x mod p
|
||||
var g big.Int
|
||||
g.SetUint64(2)
|
||||
var x big.Int
|
||||
x.SetBytes(priv)
|
||||
var p big.Int
|
||||
p.SetBytes(p1024)
|
||||
var y big.Int
|
||||
y.Exp(&g, &x, &p)
|
||||
pub := y.Bytes()
|
||||
|
||||
return priv, pub, nil
|
||||
}
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/rfc2631#section-2.1.1
|
||||
func dhComputeSharedSecret(priv []byte, pub []byte) []byte {
|
||||
// ZZ = (ya ^ xb) mod p
|
||||
var y big.Int
|
||||
y.SetBytes(pub)
|
||||
var x big.Int
|
||||
x.SetBytes(priv)
|
||||
var p big.Int
|
||||
p.SetBytes(p1024)
|
||||
var z big.Int
|
||||
z.Exp(&y, &x, &p)
|
||||
return z.Bytes()
|
||||
}
|
||||
|
|
@ -2,96 +2,297 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// DoClient performs a client-side handshake.
|
||||
func DoClient(rw io.ReadWriter, validateSignature bool) error {
|
||||
c0 := C0S0{}
|
||||
const (
|
||||
encryptedVersion = 3<<24 | 5<<16 | 1<<8 | 1
|
||||
)
|
||||
|
||||
func doClientEncrypted(rw io.ReadWriter) ([]byte, []byte, error) {
|
||||
var c0 C0S0
|
||||
|
||||
c0.Version = 6
|
||||
|
||||
err := c0.Write(rw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
localPrivateKey, localPublicKey, err := dhGenerateKeyPair()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var c1 C1S1
|
||||
|
||||
c1Digest, err := c1.fill(false, localPublicKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = c1.Write(rw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var s0 C0S0
|
||||
|
||||
err = s0.Read(rw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if s0.Version != 6 {
|
||||
return nil, nil, fmt.Errorf("server replied with unexpected version %d", s0.Version)
|
||||
}
|
||||
|
||||
var s1 C1S1
|
||||
|
||||
err = s1.Read(rw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
s1Digest, remotePublicKey, err := s1.validate(true)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var s2 C2S2
|
||||
|
||||
err = s2.Read(rw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = s2.validate(true, c1Digest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var c2 C2S2
|
||||
|
||||
err = c2.fill(false, s1Digest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = c2.Write(rw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sharedSecret := dhComputeSharedSecret(localPrivateKey, remotePublicKey)
|
||||
keyIn := hmacSha256(sharedSecret, localPublicKey)[:16]
|
||||
keyOut := hmacSha256(sharedSecret, remotePublicKey)[:16]
|
||||
return keyIn, keyOut, nil
|
||||
}
|
||||
|
||||
func doClientPlain(rw io.ReadWriter, strict bool) error {
|
||||
var c0 C0S0
|
||||
|
||||
c0.Version = 3
|
||||
|
||||
err := c0.Write(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c1 := C1S1{}
|
||||
err = c1.Write(rw, true)
|
||||
var c1 C1S1
|
||||
|
||||
err = c1.fillPlain()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s0 := C0S0{}
|
||||
err = c1.Write(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var s0 C0S0
|
||||
|
||||
err = s0.Read(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s1 := C1S1{}
|
||||
err = s1.Read(rw, false, validateSignature)
|
||||
if s0.Version != 3 {
|
||||
return fmt.Errorf("server replied with unexpected version %d", s0.Version)
|
||||
}
|
||||
|
||||
var s1 C1S1
|
||||
|
||||
err = s1.Read(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s2 := C2S2{
|
||||
Digest: c1.Digest,
|
||||
}
|
||||
err = s2.Read(rw, validateSignature)
|
||||
var s2 C2S2
|
||||
|
||||
err = s2.Read(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c2 := C2S2{
|
||||
Time: s1.Time,
|
||||
Random: s1.Random,
|
||||
Digest: s1.Digest,
|
||||
if strict && !bytes.Equal(s2.Data, c1.Data) {
|
||||
return fmt.Errorf("data in S2 does not correspond")
|
||||
}
|
||||
err = c2.Write(rw)
|
||||
|
||||
var c2 C2S2
|
||||
|
||||
c2.Data = s1.Data
|
||||
|
||||
return c2.Write(rw)
|
||||
}
|
||||
|
||||
// DoClient performs a client-side handshake.
|
||||
func DoClient(rw io.ReadWriter, encrypted bool, strict bool) ([]byte, []byte, error) {
|
||||
if encrypted {
|
||||
return doClientEncrypted(rw)
|
||||
}
|
||||
return nil, nil, doClientPlain(rw, strict)
|
||||
}
|
||||
|
||||
func doServerEncrypted(rw io.ReadWriter) ([]byte, []byte, error) {
|
||||
var c1 C1S1
|
||||
|
||||
err := c1.Read(rw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
c1Digest, remotePublicKey, err := c1.validate(false)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
localPrivateKey, localPublicKey, err := dhGenerateKeyPair()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var s0 C0S0
|
||||
|
||||
s0.Version = 6
|
||||
|
||||
err = s0.Write(rw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var s1 C1S1
|
||||
|
||||
s1.Version = encryptedVersion
|
||||
|
||||
s1Digest, err := s1.fill(true, localPublicKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = s1.Write(rw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var s2 C2S2
|
||||
|
||||
s2.Time2 = encryptedVersion
|
||||
|
||||
err = s2.fill(true, c1Digest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = s2.Write(rw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var c2 C2S2
|
||||
|
||||
err = c2.Read(rw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = c2.validate(false, s1Digest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sharedSecret := dhComputeSharedSecret(localPrivateKey, remotePublicKey)
|
||||
keyIn := hmacSha256(sharedSecret, localPublicKey)[:16]
|
||||
keyOut := hmacSha256(sharedSecret, remotePublicKey)[:16]
|
||||
return keyIn, keyOut, nil
|
||||
}
|
||||
|
||||
func doServerPlain(rw io.ReadWriter, strict bool) error {
|
||||
var c1 C1S1
|
||||
|
||||
err := c1.Read(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var s0 C0S0
|
||||
|
||||
s0.Version = 3
|
||||
|
||||
err = s0.Write(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var s1 C1S1
|
||||
|
||||
err = s1.fillPlain()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s1.Write(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var s2 C2S2
|
||||
|
||||
s2.Data = c1.Data
|
||||
|
||||
err = s2.Write(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var c2 C2S2
|
||||
|
||||
err = c2.Read(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if strict && !bytes.Equal(c2.Data, s1.Data) {
|
||||
return fmt.Errorf("data in C2 does not correspond")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoServer performs a server-side handshake.
|
||||
func DoServer(rw io.ReadWriter, validateSignature bool) error {
|
||||
err := C0S0{}.Read(rw)
|
||||
func DoServer(rw io.ReadWriter, strict bool) ([]byte, []byte, error) {
|
||||
var c0 C0S0
|
||||
|
||||
err := c0.Read(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
c1 := C1S1{}
|
||||
err = c1.Read(rw, true, validateSignature)
|
||||
if err != nil {
|
||||
return err
|
||||
if c0.Version == 6 {
|
||||
return doServerEncrypted(rw)
|
||||
}
|
||||
|
||||
s0 := C0S0{}
|
||||
err = s0.Write(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s1 := C1S1{}
|
||||
err = s1.Write(rw, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s2 := C2S2{
|
||||
Time: c1.Time,
|
||||
Random: c1.Random,
|
||||
Digest: c1.Digest,
|
||||
}
|
||||
err = s2.Write(rw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c2 := C2S2{Digest: s1.Digest}
|
||||
err = c2.Read(rw, validateSignature)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, nil, doServerPlain(rw, strict)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,91 +1,50 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testReadWriter struct {
|
||||
ch chan []byte
|
||||
}
|
||||
|
||||
func (rw *testReadWriter) Read(p []byte) (int, error) {
|
||||
in := <-rw.ch
|
||||
n := copy(p, in)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (rw *testReadWriter) Write(p []byte) (int, error) {
|
||||
rw.ch <- p
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:9122")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
for _, ca := range []string{"plain", "encrypted"} {
|
||||
t.Run(ca, func(t *testing.T) {
|
||||
rw := &testReadWriter{ch: make(chan []byte)}
|
||||
var serverInKey []byte
|
||||
var serverOutKey []byte
|
||||
done := make(chan struct{})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
var err error
|
||||
serverInKey, serverOutKey, err = DoServer(rw, true)
|
||||
require.NoError(t, err)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
clientInKey, clientOutKey, err := DoClient(rw, ca == "encrypted", true)
|
||||
require.NoError(t, err)
|
||||
<-done
|
||||
|
||||
err = DoServer(conn, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
close(done)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", "127.0.0.1:9122")
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = DoClient(conn, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
// when C1 signature is invalid, S2 must be equal to C1.
|
||||
func TestHandshakeFallback(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:9122")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = DoServer(conn, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
close(done)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", "127.0.0.1:9122")
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = C0S0{}.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
c1 := make([]byte, 1536)
|
||||
_, err = rand.Read(c1[8:])
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(c1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = C0S0{}.Read(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
s1 := C1S1{}
|
||||
err = s1.Read(conn, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
s2 := C2S2{}
|
||||
err = s2.Read(conn, false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c1[8:], s2.Random)
|
||||
|
||||
err = C2S2{
|
||||
Time: s1.Time,
|
||||
Random: s1.Random,
|
||||
Digest: s1.Digest,
|
||||
}.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
<-done
|
||||
if ca == "encrypted" {
|
||||
require.NotNil(t, serverInKey)
|
||||
require.Equal(t, serverInKey, clientOutKey)
|
||||
require.Equal(t, serverOutKey, clientInKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package message
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/bluenviron/mediamtx/internal/rtmp/bytecounter"
|
||||
"github.com/bluenviron/mediamtx/internal/rtmp/rawmessage"
|
||||
|
|
@ -116,9 +117,13 @@ type Reader struct {
|
|||
}
|
||||
|
||||
// NewReader allocates a Reader.
|
||||
func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader {
|
||||
func NewReader(
|
||||
r io.Reader,
|
||||
bcr *bytecounter.Reader,
|
||||
onAckNeeded func(uint32) error,
|
||||
) *Reader {
|
||||
return &Reader{
|
||||
r: rawmessage.NewReader(r, onAckNeeded),
|
||||
r: rawmessage.NewReader(r, bcr, onAckNeeded),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -268,7 +268,8 @@ var readWriterCases = []struct {
|
|||
func TestReader(t *testing.T) {
|
||||
for _, ca := range readWriterCases {
|
||||
t.Run(ca.name, func(t *testing.T) {
|
||||
r := NewReader(bytecounter.NewReader(bytes.NewReader(ca.enc)), nil)
|
||||
bc := bytecounter.NewReader(bytes.NewReader(ca.enc))
|
||||
r := NewReader(bc, bc, nil)
|
||||
dec, err := r.Read()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ca.dec, dec)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package message
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/bluenviron/mediamtx/internal/rtmp/bytecounter"
|
||||
)
|
||||
|
||||
|
|
@ -11,10 +13,14 @@ type ReadWriter struct {
|
|||
}
|
||||
|
||||
// NewReadWriter allocates a ReadWriter.
|
||||
func NewReadWriter(bc *bytecounter.ReadWriter, checkAcknowledge bool) *ReadWriter {
|
||||
w := NewWriter(bc.Writer, checkAcknowledge)
|
||||
func NewReadWriter(
|
||||
rw io.ReadWriter,
|
||||
bcrw *bytecounter.ReadWriter,
|
||||
checkAcknowledge bool,
|
||||
) *ReadWriter {
|
||||
w := NewWriter(rw, bcrw.Writer, checkAcknowledge)
|
||||
|
||||
r := NewReader(bc.Reader, func(count uint32) error {
|
||||
r := NewReader(rw, bcrw.Reader, func(count uint32) error {
|
||||
return w.Write(&Acknowledge{
|
||||
Value: count,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -27,19 +27,21 @@ func TestReadWriterAcknowledge(t *testing.T) {
|
|||
var buf1 bytes.Buffer
|
||||
var buf2 bytes.Buffer
|
||||
|
||||
rw1 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{
|
||||
bc1 := bytecounter.NewReadWriter(&duplexRW{
|
||||
Reader: &buf2,
|
||||
Writer: &buf1,
|
||||
}), true)
|
||||
})
|
||||
rw1 := NewReadWriter(bc1, bc1, true)
|
||||
err := rw1.Write(&Acknowledge{
|
||||
Value: 7863534,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rw2 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{
|
||||
bc2 := bytecounter.NewReadWriter(&duplexRW{
|
||||
Reader: &buf1,
|
||||
Writer: &buf2,
|
||||
}), true)
|
||||
})
|
||||
rw2 := NewReadWriter(bc2, bc2, true)
|
||||
_, err = rw2.Read()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
|
@ -48,19 +50,21 @@ func TestReadWriterPing(t *testing.T) {
|
|||
var buf1 bytes.Buffer
|
||||
var buf2 bytes.Buffer
|
||||
|
||||
rw1 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{
|
||||
bc1 := bytecounter.NewReadWriter(&duplexRW{
|
||||
Reader: &buf2,
|
||||
Writer: &buf1,
|
||||
}), true)
|
||||
})
|
||||
rw1 := NewReadWriter(bc1, bc1, true)
|
||||
err := rw1.Write(&UserControlPingRequest{
|
||||
ServerTime: 143424312,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rw2 := NewReadWriter(bytecounter.NewReadWriter(&duplexRW{
|
||||
bc2 := bytecounter.NewReadWriter(&duplexRW{
|
||||
Reader: &buf1,
|
||||
Writer: &buf2,
|
||||
}), true)
|
||||
})
|
||||
rw2 := NewReadWriter(bc2, bc2, true)
|
||||
_, err = rw2.Read()
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package message
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/bluenviron/mediamtx/internal/rtmp/bytecounter"
|
||||
"github.com/bluenviron/mediamtx/internal/rtmp/rawmessage"
|
||||
)
|
||||
|
|
@ -11,9 +13,13 @@ type Writer struct {
|
|||
}
|
||||
|
||||
// NewWriter allocates a Writer.
|
||||
func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer {
|
||||
func NewWriter(
|
||||
w io.Writer,
|
||||
bcw *bytecounter.Writer,
|
||||
checkAcknowledge bool,
|
||||
) *Writer {
|
||||
return &Writer{
|
||||
w: rawmessage.NewWriter(w, checkAcknowledge),
|
||||
w: rawmessage.NewWriter(w, bcw, checkAcknowledge),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ func TestWriter(t *testing.T) {
|
|||
for _, ca := range readWriterCases {
|
||||
t.Run(ca.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
r := NewWriter(bytecounter.NewWriter(&buf), true)
|
||||
bc := bytecounter.NewWriter(&buf)
|
||||
r := NewWriter(bc, bc, true)
|
||||
err := r.Write(ca.dec)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ca.enc, buf.Bytes())
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/bluenviron/mediamtx/internal/rtmp/bytecounter"
|
||||
|
|
@ -30,7 +31,7 @@ func (rc *readerChunkStream) readChunk(c chunk.Chunk, chunkBodySize uint32) erro
|
|||
|
||||
// check if an ack is needed
|
||||
if rc.mr.ackWindowSize != 0 {
|
||||
count := uint32(rc.mr.r.Count())
|
||||
count := uint32(rc.mr.bcr.Count())
|
||||
diff := count - rc.mr.lastAckCount
|
||||
|
||||
if diff > (rc.mr.ackWindowSize) {
|
||||
|
|
@ -208,7 +209,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
|
|||
|
||||
// Reader is a raw message reader.
|
||||
type Reader struct {
|
||||
r *bytecounter.Reader
|
||||
bcr *bytecounter.Reader
|
||||
onAckNeeded func(uint32) error
|
||||
|
||||
br *bufio.Reader
|
||||
|
|
@ -224,9 +225,13 @@ type Reader struct {
|
|||
}
|
||||
|
||||
// NewReader allocates a Reader.
|
||||
func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader {
|
||||
func NewReader(
|
||||
r io.Reader,
|
||||
bcr *bytecounter.Reader,
|
||||
onAckNeeded func(uint32) error,
|
||||
) *Reader {
|
||||
return &Reader{
|
||||
r: r,
|
||||
bcr: bcr,
|
||||
br: bufio.NewReader(r),
|
||||
onAckNeeded: onAckNeeded,
|
||||
chunkSize: 128,
|
||||
|
|
|
|||
|
|
@ -198,7 +198,8 @@ func TestReader(t *testing.T) {
|
|||
for _, ca := range cases {
|
||||
t.Run(ca.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
r := NewReader(bytecounter.NewReader(&buf), func(count uint32) error {
|
||||
br := bytecounter.NewReader(&buf)
|
||||
r := NewReader(br, br, func(count uint32) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
|
|
@ -224,7 +225,7 @@ func TestReaderAcknowledge(t *testing.T) {
|
|||
|
||||
var buf bytes.Buffer
|
||||
bc := bytecounter.NewReader(&buf)
|
||||
r := NewReader(bc, func(count uint32) error {
|
||||
r := NewReader(bc, bc, func(count uint32) error {
|
||||
close(onAckCalled)
|
||||
return nil
|
||||
})
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package rawmessage
|
|||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/bluenviron/mediamtx/internal/rtmp/bytecounter"
|
||||
|
|
@ -21,7 +22,7 @@ type writerChunkStream struct {
|
|||
func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error {
|
||||
// check if we received an acknowledge
|
||||
if wc.mw.checkAcknowledge && wc.mw.ackWindowSize != 0 {
|
||||
diff := uint32(wc.mw.w.Count()) - wc.mw.ackValue
|
||||
diff := uint32(wc.mw.bcw.Count()) - wc.mw.ackValue
|
||||
|
||||
if diff > (wc.mw.ackWindowSize * 3 / 2) {
|
||||
return fmt.Errorf("no acknowledge received within window")
|
||||
|
|
@ -148,7 +149,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
|
|||
|
||||
// Writer is a raw message writer.
|
||||
type Writer struct {
|
||||
w *bytecounter.Writer
|
||||
bcw *bytecounter.Writer
|
||||
bw *bufio.Writer
|
||||
checkAcknowledge bool
|
||||
chunkSize uint32
|
||||
|
|
@ -158,9 +159,13 @@ type Writer struct {
|
|||
}
|
||||
|
||||
// NewWriter allocates a Writer.
|
||||
func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer {
|
||||
func NewWriter(
|
||||
w io.Writer,
|
||||
bcw *bytecounter.Writer,
|
||||
checkAcknowledge bool,
|
||||
) *Writer {
|
||||
return &Writer{
|
||||
w: w,
|
||||
bcw: bcw,
|
||||
bw: bufio.NewWriter(w),
|
||||
checkAcknowledge: checkAcknowledge,
|
||||
chunkSize: 128,
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ func TestWriter(t *testing.T) {
|
|||
for _, ca := range cases {
|
||||
t.Run(ca.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
w := NewWriter(bytecounter.NewWriter(&buf), true)
|
||||
bc := bytecounter.NewWriter(&buf)
|
||||
w := NewWriter(bc, bc, true)
|
||||
|
||||
for _, msg := range ca.messages {
|
||||
err := w.Write(msg)
|
||||
|
|
@ -36,11 +37,11 @@ func TestWriterAcknowledge(t *testing.T) {
|
|||
for _, ca := range []string{"standard", "overflow"} {
|
||||
t.Run(ca, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
bcw := bytecounter.NewWriter(&buf)
|
||||
w := NewWriter(bcw, true)
|
||||
bc := bytecounter.NewWriter(&buf)
|
||||
w := NewWriter(bc, bc, true)
|
||||
|
||||
if ca == "overflow" {
|
||||
bcw.SetCount(4294967096)
|
||||
bc.SetCount(4294967096)
|
||||
w.ackValue = 4294967096
|
||||
}
|
||||
|
||||
|
|
|
|||
49
internal/rtmp/rc4_readwriter.go
Normal file
49
internal/rtmp/rc4_readwriter.go
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
package rtmp
|
||||
|
||||
import (
|
||||
"crypto/rc4"
|
||||
"io"
|
||||
)
|
||||
|
||||
type rc4ReadWriter struct {
|
||||
rw io.ReadWriter
|
||||
in *rc4.Cipher
|
||||
out *rc4.Cipher
|
||||
}
|
||||
|
||||
func newRC4ReadWriter(rw io.ReadWriter, keyIn []byte, keyOut []byte) (*rc4ReadWriter, error) {
|
||||
in, err := rc4.NewCipher(keyIn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out, err := rc4.NewCipher(keyOut)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := make([]byte, 1536)
|
||||
in.XORKeyStream(p, p)
|
||||
out.XORKeyStream(p, p)
|
||||
|
||||
return &rc4ReadWriter{
|
||||
rw: rw,
|
||||
in: in,
|
||||
out: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *rc4ReadWriter) Read(p []byte) (int, error) {
|
||||
n, err := r.rw.Read(p)
|
||||
if n == 0 {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
r.in.XORKeyStream(p[:n], p[:n])
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *rc4ReadWriter) Write(p []byte) (int, error) {
|
||||
r.out.XORKeyStream(p, p)
|
||||
return r.rw.Write(p)
|
||||
}
|
||||
|
|
@ -166,7 +166,7 @@ func TestReadTracks(t *testing.T) {
|
|||
t.Run(ca.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
bc := bytecounter.NewReadWriter(&buf)
|
||||
mrw := message.NewReadWriter(bc, true)
|
||||
mrw := message.NewReadWriter(bc, bc, true)
|
||||
|
||||
switch ca.name {
|
||||
case "video+audio":
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ func TestWriteTracks(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
bc := bytecounter.NewReadWriter(&buf)
|
||||
mrw := message.NewReadWriter(bc, true)
|
||||
mrw := message.NewReadWriter(bc, bc, true)
|
||||
|
||||
msg, err := mrw.Read()
|
||||
require.NoError(t, err)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue