1
0
Fork 0
forked from External/mediamtx

srt: process connection requests in parallel (#3382) (#3534)

This commit is contained in:
Alessandro Ros 2024-07-05 22:17:40 +02:00 committed by GitHub
parent c4987d020a
commit 342c257df5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 50 additions and 123 deletions

2
go.mod
View file

@ -82,3 +82,5 @@ replace code.cloudfoundry.org/bytefmt => github.com/cloudfoundry/bytefmt v0.0.0-
replace github.com/pion/ice/v2 => github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9 replace github.com/pion/ice/v2 => github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9
replace github.com/pion/webrtc/v3 => github.com/aler9/webrtc/v3 v3.0.0-20240610104456-eaec24056d06 replace github.com/pion/webrtc/v3 => github.com/aler9/webrtc/v3 v3.0.0-20240610104456-eaec24056d06
replace github.com/datarhei/gosrt => github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7

4
go.sum
View file

@ -10,6 +10,8 @@ github.com/alecthomas/kong v0.9.0 h1:G5diXxc85KvoV2f0ZRVuMsi45IrBgx9zDNGNj165aPA
github.com/alecthomas/kong v0.9.0/go.mod h1:Y47y5gKfHp1hDc7CH7OeXgLIpp+Q2m1Ni0L5s3bI8Os= github.com/alecthomas/kong v0.9.0/go.mod h1:Y47y5gKfHp1hDc7CH7OeXgLIpp+Q2m1Ni0L5s3bI8Os=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7 h1:4WE1Nez3YyD1CgJfWlnyp+uLLPZOKD5ywWPvwbf/Jp4=
github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7/go.mod h1:fsOWdLSHUHShHjgi/46h6wjtdQrtnSdAQFnlas8ONxs=
github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9 h1:Vax9SzYE68ZYLwFaK7lnCV2ZhX9/YqAJX6xxROPRqEM= github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9 h1:Vax9SzYE68ZYLwFaK7lnCV2ZhX9/YqAJX6xxROPRqEM=
github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9/go.mod h1:KXJJcZK7E8WzrBEYnV4UtqEZsGeWfHxsNqhVcVvgjxw= github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9/go.mod h1:KXJJcZK7E8WzrBEYnV4UtqEZsGeWfHxsNqhVcVvgjxw=
github.com/aler9/webrtc/v3 v3.0.0-20240610104456-eaec24056d06 h1:WtKhXOpd8lgTeXF3RQVOzkNRuy83ygvWEpMYD2aoY3Q= github.com/aler9/webrtc/v3 v3.0.0-20240610104456-eaec24056d06 h1:WtKhXOpd8lgTeXF3RQVOzkNRuy83ygvWEpMYD2aoY3Q=
@ -37,8 +39,6 @@ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJ
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/datarhei/gosrt v0.6.0 h1:HrrXAw90V78ok4WMIhX6se1aTHPCn82Sg2hj+PhdmGc=
github.com/datarhei/gosrt v0.6.0/go.mod h1:fsOWdLSHUHShHjgi/46h6wjtdQrtnSdAQFnlas8ONxs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

View file

@ -74,9 +74,6 @@ type conn struct {
pathName string pathName string
query string query string
sconn srt.Conn sconn srt.Conn
chNew chan srtNewConnReq
chSetConn chan srt.Conn
} }
func (c *conn) initialize() { func (c *conn) initialize() {
@ -84,8 +81,6 @@ func (c *conn) initialize() {
c.created = time.Now() c.created = time.Now()
c.uuid = uuid.New() c.uuid = uuid.New()
c.chNew = make(chan srtNewConnReq)
c.chSetConn = make(chan srt.Conn)
c.Log(logger.Info, "opened") c.Log(logger.Info, "opened")
@ -130,36 +125,20 @@ func (c *conn) run() { //nolint:dupl
} }
func (c *conn) runInner() error { func (c *conn) runInner() error {
var req srtNewConnReq
select {
case req = <-c.chNew:
case <-c.ctx.Done():
return errors.New("terminated")
}
answerSent, err := c.runInner2(req)
if !answerSent {
req.res <- nil
}
return err
}
func (c *conn) runInner2(req srtNewConnReq) (bool, error) {
var streamID streamID var streamID streamID
err := streamID.unmarshal(req.connReq.StreamId()) err := streamID.unmarshal(c.connReq.StreamId())
if err != nil { if err != nil {
return false, fmt.Errorf("invalid stream ID '%s': %w", req.connReq.StreamId(), err) c.connReq.Reject(srt.REJ_PEER)
return fmt.Errorf("invalid stream ID '%s': %w", c.connReq.StreamId(), err)
} }
if streamID.mode == streamIDModePublish { if streamID.mode == streamIDModePublish {
return c.runPublish(req, &streamID) return c.runPublish(&streamID)
} }
return c.runRead(req, &streamID) return c.runRead(&streamID)
} }
func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) { func (c *conn) runPublish(streamID *streamID) error {
path, err := c.pathManager.AddPublisher(defs.PathAddPublisherReq{ path, err := c.pathManager.AddPublisher(defs.PathAddPublisherReq{
Author: c, Author: c,
AccessRequest: defs.PathAccessRequest{ AccessRequest: defs.PathAccessRequest{
@ -178,21 +157,24 @@ func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) {
if errors.As(err, &terr) { if errors.As(err, &terr) {
// wait some seconds to mitigate brute force attacks // wait some seconds to mitigate brute force attacks
<-time.After(auth.PauseAfterError) <-time.After(auth.PauseAfterError)
return false, terr c.connReq.Reject(srt.REJ_PEER)
return terr
} }
return false, err c.connReq.Reject(srt.REJ_PEER)
return err
} }
defer path.RemovePublisher(defs.PathRemovePublisherReq{Author: c}) defer path.RemovePublisher(defs.PathRemovePublisherReq{Author: c})
err = srtCheckPassphrase(req.connReq, path.SafeConf().SRTPublishPassphrase) err = srtCheckPassphrase(c.connReq, path.SafeConf().SRTPublishPassphrase)
if err != nil { if err != nil {
return false, err c.connReq.Reject(srt.REJ_PEER)
return err
} }
sconn, err := c.exchangeRequestWithConn(req) sconn, err := c.connReq.Accept()
if err != nil { if err != nil {
return true, err return err
} }
c.mutex.Lock() c.mutex.Lock()
@ -210,12 +192,12 @@ func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) {
select { select {
case err := <-readerErr: case err := <-readerErr:
sconn.Close() sconn.Close()
return true, err return err
case <-c.ctx.Done(): case <-c.ctx.Done():
sconn.Close() sconn.Close()
<-readerErr <-readerErr
return true, errors.New("terminated") return errors.New("terminated")
} }
} }
@ -256,7 +238,7 @@ func (c *conn) runPublishReader(sconn srt.Conn, path defs.Path) error {
} }
} }
func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) { func (c *conn) runRead(streamID *streamID) error {
path, stream, err := c.pathManager.AddReader(defs.PathAddReaderReq{ path, stream, err := c.pathManager.AddReader(defs.PathAddReaderReq{
Author: c, Author: c,
AccessRequest: defs.PathAccessRequest{ AccessRequest: defs.PathAccessRequest{
@ -274,21 +256,24 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) {
if errors.As(err, &terr) { if errors.As(err, &terr) {
// wait some seconds to mitigate brute force attacks // wait some seconds to mitigate brute force attacks
<-time.After(auth.PauseAfterError) <-time.After(auth.PauseAfterError)
return false, err c.connReq.Reject(srt.REJ_PEER)
return terr
} }
return false, err c.connReq.Reject(srt.REJ_PEER)
return err
} }
defer path.RemoveReader(defs.PathRemoveReaderReq{Author: c}) defer path.RemoveReader(defs.PathRemoveReaderReq{Author: c})
err = srtCheckPassphrase(req.connReq, path.SafeConf().SRTReadPassphrase) err = srtCheckPassphrase(c.connReq, path.SafeConf().SRTReadPassphrase)
if err != nil { if err != nil {
return false, err c.connReq.Reject(srt.REJ_PEER)
return err
} }
sconn, err := c.exchangeRequestWithConn(req) sconn, err := c.connReq.Accept()
if err != nil { if err != nil {
return true, err return err
} }
defer sconn.Close() defer sconn.Close()
@ -307,7 +292,7 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) {
err = mpegts.FromStream(stream, writer, bw, sconn, time.Duration(c.writeTimeout)) err = mpegts.FromStream(stream, writer, bw, sconn, time.Duration(c.writeTimeout))
if err != nil { if err != nil {
return true, err return err
} }
c.Log(logger.Info, "is reading from path '%s', %s", c.Log(logger.Info, "is reading from path '%s', %s",
@ -331,41 +316,10 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) {
select { select {
case <-c.ctx.Done(): case <-c.ctx.Done():
return true, fmt.Errorf("terminated") return fmt.Errorf("terminated")
case err := <-writer.Error(): case err := <-writer.Error():
return true, err return err
}
}
func (c *conn) exchangeRequestWithConn(req srtNewConnReq) (srt.Conn, error) {
req.res <- c
select {
case sconn := <-c.chSetConn:
return sconn, nil
case <-c.ctx.Done():
return nil, errors.New("terminated")
}
}
// new is called by srtListener through srtServer.
func (c *conn) new(req srtNewConnReq) *conn {
select {
case c.chNew <- req:
return <-req.res
case <-c.ctx.Done():
return nil
}
}
// setConn is called by srtListener .
func (c *conn) setConn(sconn srt.Conn) {
select {
case c.chSetConn <- sconn:
case <-c.ctx.Done():
} }
} }

View file

@ -27,24 +27,11 @@ func (l *listener) run() {
func (l *listener) runInner() error { func (l *listener) runInner() error {
for { for {
var sconn *conn req, err := l.ln.Accept2()
conn, _, err := l.ln.Accept(func(req srt.ConnRequest) srt.ConnType {
sconn = l.parent.newConnRequest(req)
if sconn == nil {
return srt.REJECT
}
// currently it's the same to return SUBSCRIBE or PUBLISH
return srt.SUBSCRIBE
})
if err != nil { if err != nil {
return err return err
} }
if conn == nil { l.parent.newConnRequest(req)
continue
}
sconn.setConn(conn)
} }
} }

View file

@ -26,11 +26,6 @@ func srtMaxPayloadSize(u int) int {
return ((u - 16) / 188) * 188 // 16 = SRT header, 188 = MPEG-TS packet return ((u - 16) / 188) * 188 // 16 = SRT header, 188 = MPEG-TS packet
} }
type srtNewConnReq struct {
connReq srt.ConnRequest
res chan *conn
}
type serverAPIConnsListRes struct { type serverAPIConnsListRes struct {
data *defs.APISRTConnList data *defs.APISRTConnList
err error err error
@ -90,7 +85,7 @@ type Server struct {
conns map[*conn]struct{} conns map[*conn]struct{}
// in // in
chNewConnRequest chan srtNewConnReq chNewConnRequest chan srt.ConnRequest
chAcceptErr chan error chAcceptErr chan error
chCloseConn chan *conn chCloseConn chan *conn
chAPIConnsList chan serverAPIConnsListReq chAPIConnsList chan serverAPIConnsListReq
@ -113,7 +108,7 @@ func (s *Server) Initialize() error {
s.ctx, s.ctxCancel = context.WithCancel(context.Background()) s.ctx, s.ctxCancel = context.WithCancel(context.Background())
s.conns = make(map[*conn]struct{}) s.conns = make(map[*conn]struct{})
s.chNewConnRequest = make(chan srtNewConnReq) s.chNewConnRequest = make(chan srt.ConnRequest)
s.chAcceptErr = make(chan error) s.chAcceptErr = make(chan error)
s.chCloseConn = make(chan *conn) s.chCloseConn = make(chan *conn)
s.chAPIConnsList = make(chan serverAPIConnsListReq) s.chAPIConnsList = make(chan serverAPIConnsListReq)
@ -165,7 +160,7 @@ outer:
writeTimeout: s.WriteTimeout, writeTimeout: s.WriteTimeout,
writeQueueSize: s.WriteQueueSize, writeQueueSize: s.WriteQueueSize,
udpMaxPayloadSize: s.UDPMaxPayloadSize, udpMaxPayloadSize: s.UDPMaxPayloadSize,
connReq: req.connReq, connReq: req,
runOnConnect: s.RunOnConnect, runOnConnect: s.RunOnConnect,
runOnConnectRestart: s.RunOnConnectRestart, runOnConnectRestart: s.RunOnConnectRestart,
runOnDisconnect: s.RunOnDisconnect, runOnDisconnect: s.RunOnDisconnect,
@ -176,7 +171,6 @@ outer:
} }
c.initialize() c.initialize()
s.conns[c] = struct{}{} s.conns[c] = struct{}{}
req.res <- c
case c := <-s.chCloseConn: case c := <-s.chCloseConn:
delete(s.conns, c) delete(s.conns, c)
@ -236,20 +230,11 @@ func (s *Server) findConnByUUID(uuid uuid.UUID) *conn {
} }
// newConnRequest is called by srtListener. // newConnRequest is called by srtListener.
func (s *Server) newConnRequest(connReq srt.ConnRequest) *conn { func (s *Server) newConnRequest(connReq srt.ConnRequest) {
req := srtNewConnReq{
connReq: connReq,
res: make(chan *conn),
}
select { select {
case s.chNewConnRequest <- req: case s.chNewConnRequest <- connReq:
c := <-req.res
return c.new(req)
case <-s.ctx.Done(): case <-s.ctx.Done():
return nil connReq.Reject(srt.REJ_CLOSE)
} }
} }

View file

@ -20,16 +20,15 @@ func TestSource(t *testing.T) {
defer ln.Close() defer ln.Close()
go func() { go func() {
conn, _, err := ln.Accept(func(req srt.ConnRequest) srt.ConnType { req, err := ln.Accept2()
require.Equal(t, "sidname", req.StreamId()) require.NoError(t, err)
err := req.SetPassphrase("ttest1234567")
if err != nil { require.Equal(t, "sidname", req.StreamId())
return srt.REJECT err = req.SetPassphrase("ttest1234567")
} require.NoError(t, err)
return srt.SUBSCRIBE
}) conn, err := req.Accept()
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, conn)
defer conn.Close() defer conn.Close()
track := &mpegts.Track{ track := &mpegts.Track{