From 342c257df5189200d7fc789828a1022739c2edef Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Fri, 5 Jul 2024 22:17:40 +0200 Subject: [PATCH] srt: process connection requests in parallel (#3382) (#3534) --- go.mod | 2 + go.sum | 4 +- internal/servers/srt/conn.go | 106 ++++++---------------- internal/servers/srt/listener.go | 17 +--- internal/servers/srt/server.go | 27 ++---- internal/staticsources/srt/source_test.go | 17 ++-- 6 files changed, 50 insertions(+), 123 deletions(-) diff --git a/go.mod b/go.mod index 964c9309..ebf4fad5 100644 --- a/go.mod +++ b/go.mod @@ -82,3 +82,5 @@ replace code.cloudfoundry.org/bytefmt => github.com/cloudfoundry/bytefmt v0.0.0- replace github.com/pion/ice/v2 => github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9 replace github.com/pion/webrtc/v3 => github.com/aler9/webrtc/v3 v3.0.0-20240610104456-eaec24056d06 + +replace github.com/datarhei/gosrt => github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7 diff --git a/go.sum b/go.sum index 94d74fe0..786bbb70 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/alecthomas/kong v0.9.0 h1:G5diXxc85KvoV2f0ZRVuMsi45IrBgx9zDNGNj165aPA github.com/alecthomas/kong v0.9.0/go.mod h1:Y47y5gKfHp1hDc7CH7OeXgLIpp+Q2m1Ni0L5s3bI8Os= github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7 h1:4WE1Nez3YyD1CgJfWlnyp+uLLPZOKD5ywWPvwbf/Jp4= +github.com/aler9/gosrt v0.0.0-20240705192040-d4bc5eaa3ee7/go.mod h1:fsOWdLSHUHShHjgi/46h6wjtdQrtnSdAQFnlas8ONxs= github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9 h1:Vax9SzYE68ZYLwFaK7lnCV2ZhX9/YqAJX6xxROPRqEM= github.com/aler9/ice/v2 v2.0.0-20240608212222-2eebc68350c9/go.mod h1:KXJJcZK7E8WzrBEYnV4UtqEZsGeWfHxsNqhVcVvgjxw= github.com/aler9/webrtc/v3 v3.0.0-20240610104456-eaec24056d06 h1:WtKhXOpd8lgTeXF3RQVOzkNRuy83ygvWEpMYD2aoY3Q= @@ -37,8 +39,6 @@ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJ github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/datarhei/gosrt v0.6.0 h1:HrrXAw90V78ok4WMIhX6se1aTHPCn82Sg2hj+PhdmGc= -github.com/datarhei/gosrt v0.6.0/go.mod h1:fsOWdLSHUHShHjgi/46h6wjtdQrtnSdAQFnlas8ONxs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/servers/srt/conn.go b/internal/servers/srt/conn.go index 971fb5af..01bfeecd 100644 --- a/internal/servers/srt/conn.go +++ b/internal/servers/srt/conn.go @@ -74,9 +74,6 @@ type conn struct { pathName string query string sconn srt.Conn - - chNew chan srtNewConnReq - chSetConn chan srt.Conn } func (c *conn) initialize() { @@ -84,8 +81,6 @@ func (c *conn) initialize() { c.created = time.Now() c.uuid = uuid.New() - c.chNew = make(chan srtNewConnReq) - c.chSetConn = make(chan srt.Conn) c.Log(logger.Info, "opened") @@ -130,36 +125,20 @@ func (c *conn) run() { //nolint:dupl } func (c *conn) runInner() error { - var req srtNewConnReq - select { - case req = <-c.chNew: - case <-c.ctx.Done(): - return errors.New("terminated") - } - - answerSent, err := c.runInner2(req) - - if !answerSent { - req.res <- nil - } - - return err -} - -func (c *conn) runInner2(req srtNewConnReq) (bool, error) { var streamID streamID - err := streamID.unmarshal(req.connReq.StreamId()) + err := streamID.unmarshal(c.connReq.StreamId()) if err != nil { - return false, fmt.Errorf("invalid stream ID '%s': %w", req.connReq.StreamId(), err) + c.connReq.Reject(srt.REJ_PEER) + return fmt.Errorf("invalid stream ID '%s': %w", c.connReq.StreamId(), err) } if streamID.mode == streamIDModePublish { - return c.runPublish(req, &streamID) + return c.runPublish(&streamID) } - return c.runRead(req, &streamID) + return c.runRead(&streamID) } -func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) { +func (c *conn) runPublish(streamID *streamID) error { path, err := c.pathManager.AddPublisher(defs.PathAddPublisherReq{ Author: c, AccessRequest: defs.PathAccessRequest{ @@ -178,21 +157,24 @@ func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) { if errors.As(err, &terr) { // wait some seconds to mitigate brute force attacks <-time.After(auth.PauseAfterError) - return false, terr + c.connReq.Reject(srt.REJ_PEER) + return terr } - return false, err + c.connReq.Reject(srt.REJ_PEER) + return err } defer path.RemovePublisher(defs.PathRemovePublisherReq{Author: c}) - err = srtCheckPassphrase(req.connReq, path.SafeConf().SRTPublishPassphrase) + err = srtCheckPassphrase(c.connReq, path.SafeConf().SRTPublishPassphrase) if err != nil { - return false, err + c.connReq.Reject(srt.REJ_PEER) + return err } - sconn, err := c.exchangeRequestWithConn(req) + sconn, err := c.connReq.Accept() if err != nil { - return true, err + return err } c.mutex.Lock() @@ -210,12 +192,12 @@ func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) { select { case err := <-readerErr: sconn.Close() - return true, err + return err case <-c.ctx.Done(): sconn.Close() <-readerErr - return true, errors.New("terminated") + return errors.New("terminated") } } @@ -256,7 +238,7 @@ func (c *conn) runPublishReader(sconn srt.Conn, path defs.Path) error { } } -func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) { +func (c *conn) runRead(streamID *streamID) error { path, stream, err := c.pathManager.AddReader(defs.PathAddReaderReq{ Author: c, AccessRequest: defs.PathAccessRequest{ @@ -274,21 +256,24 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) { if errors.As(err, &terr) { // wait some seconds to mitigate brute force attacks <-time.After(auth.PauseAfterError) - return false, err + c.connReq.Reject(srt.REJ_PEER) + return terr } - return false, err + c.connReq.Reject(srt.REJ_PEER) + return err } defer path.RemoveReader(defs.PathRemoveReaderReq{Author: c}) - err = srtCheckPassphrase(req.connReq, path.SafeConf().SRTReadPassphrase) + err = srtCheckPassphrase(c.connReq, path.SafeConf().SRTReadPassphrase) if err != nil { - return false, err + c.connReq.Reject(srt.REJ_PEER) + return err } - sconn, err := c.exchangeRequestWithConn(req) + sconn, err := c.connReq.Accept() if err != nil { - return true, err + return err } defer sconn.Close() @@ -307,7 +292,7 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) { err = mpegts.FromStream(stream, writer, bw, sconn, time.Duration(c.writeTimeout)) if err != nil { - return true, err + return err } c.Log(logger.Info, "is reading from path '%s', %s", @@ -331,41 +316,10 @@ func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) { select { case <-c.ctx.Done(): - return true, fmt.Errorf("terminated") + return fmt.Errorf("terminated") case err := <-writer.Error(): - return true, err - } -} - -func (c *conn) exchangeRequestWithConn(req srtNewConnReq) (srt.Conn, error) { - req.res <- c - - select { - case sconn := <-c.chSetConn: - return sconn, nil - - case <-c.ctx.Done(): - return nil, errors.New("terminated") - } -} - -// new is called by srtListener through srtServer. -func (c *conn) new(req srtNewConnReq) *conn { - select { - case c.chNew <- req: - return <-req.res - - case <-c.ctx.Done(): - return nil - } -} - -// setConn is called by srtListener . -func (c *conn) setConn(sconn srt.Conn) { - select { - case c.chSetConn <- sconn: - case <-c.ctx.Done(): + return err } } diff --git a/internal/servers/srt/listener.go b/internal/servers/srt/listener.go index 38b9e297..000ba895 100644 --- a/internal/servers/srt/listener.go +++ b/internal/servers/srt/listener.go @@ -27,24 +27,11 @@ func (l *listener) run() { func (l *listener) runInner() error { for { - var sconn *conn - conn, _, err := l.ln.Accept(func(req srt.ConnRequest) srt.ConnType { - sconn = l.parent.newConnRequest(req) - if sconn == nil { - return srt.REJECT - } - - // currently it's the same to return SUBSCRIBE or PUBLISH - return srt.SUBSCRIBE - }) + req, err := l.ln.Accept2() if err != nil { return err } - if conn == nil { - continue - } - - sconn.setConn(conn) + l.parent.newConnRequest(req) } } diff --git a/internal/servers/srt/server.go b/internal/servers/srt/server.go index 2c1977b8..33b26f29 100644 --- a/internal/servers/srt/server.go +++ b/internal/servers/srt/server.go @@ -26,11 +26,6 @@ func srtMaxPayloadSize(u int) int { return ((u - 16) / 188) * 188 // 16 = SRT header, 188 = MPEG-TS packet } -type srtNewConnReq struct { - connReq srt.ConnRequest - res chan *conn -} - type serverAPIConnsListRes struct { data *defs.APISRTConnList err error @@ -90,7 +85,7 @@ type Server struct { conns map[*conn]struct{} // in - chNewConnRequest chan srtNewConnReq + chNewConnRequest chan srt.ConnRequest chAcceptErr chan error chCloseConn chan *conn chAPIConnsList chan serverAPIConnsListReq @@ -113,7 +108,7 @@ func (s *Server) Initialize() error { s.ctx, s.ctxCancel = context.WithCancel(context.Background()) s.conns = make(map[*conn]struct{}) - s.chNewConnRequest = make(chan srtNewConnReq) + s.chNewConnRequest = make(chan srt.ConnRequest) s.chAcceptErr = make(chan error) s.chCloseConn = make(chan *conn) s.chAPIConnsList = make(chan serverAPIConnsListReq) @@ -165,7 +160,7 @@ outer: writeTimeout: s.WriteTimeout, writeQueueSize: s.WriteQueueSize, udpMaxPayloadSize: s.UDPMaxPayloadSize, - connReq: req.connReq, + connReq: req, runOnConnect: s.RunOnConnect, runOnConnectRestart: s.RunOnConnectRestart, runOnDisconnect: s.RunOnDisconnect, @@ -176,7 +171,6 @@ outer: } c.initialize() s.conns[c] = struct{}{} - req.res <- c case c := <-s.chCloseConn: delete(s.conns, c) @@ -236,20 +230,11 @@ func (s *Server) findConnByUUID(uuid uuid.UUID) *conn { } // newConnRequest is called by srtListener. -func (s *Server) newConnRequest(connReq srt.ConnRequest) *conn { - req := srtNewConnReq{ - connReq: connReq, - res: make(chan *conn), - } - +func (s *Server) newConnRequest(connReq srt.ConnRequest) { select { - case s.chNewConnRequest <- req: - c := <-req.res - - return c.new(req) - + case s.chNewConnRequest <- connReq: case <-s.ctx.Done(): - return nil + connReq.Reject(srt.REJ_CLOSE) } } diff --git a/internal/staticsources/srt/source_test.go b/internal/staticsources/srt/source_test.go index f4257ec7..230aeeb3 100644 --- a/internal/staticsources/srt/source_test.go +++ b/internal/staticsources/srt/source_test.go @@ -20,16 +20,15 @@ func TestSource(t *testing.T) { defer ln.Close() go func() { - conn, _, err := ln.Accept(func(req srt.ConnRequest) srt.ConnType { - require.Equal(t, "sidname", req.StreamId()) - err := req.SetPassphrase("ttest1234567") - if err != nil { - return srt.REJECT - } - return srt.SUBSCRIBE - }) + req, err := ln.Accept2() + require.NoError(t, err) + + require.Equal(t, "sidname", req.StreamId()) + err = req.SetPassphrase("ttest1234567") + require.NoError(t, err) + + conn, err := req.Accept() require.NoError(t, err) - require.NotNil(t, conn) defer conn.Close() track := &mpegts.Track{