diff --git a/go.mod b/go.mod index 8e26f09e..5b241d87 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.15 require ( github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect - github.com/aler9/gortsplib v0.0.0-20210424164934-262f28340026 + github.com/aler9/gortsplib v0.0.0-20210507133648-caab8c908245 github.com/asticode/go-astits v0.0.0-00010101000000-000000000000 github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.4.9 diff --git a/go.sum b/go.sum index 2482e68d..2a19f45e 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2c github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/aler9/go-astits v0.0.0-20210423195926-582b09ed7c04 h1:CXgQLsU4uxWAmsXNOjGLbj0A+0IlRcpZpMgI13fmVwo= github.com/aler9/go-astits v0.0.0-20210423195926-582b09ed7c04/go.mod h1:DkOWmBNQpnr9mv24KfZjq4JawCFX1FCqjLVGvO0DygQ= -github.com/aler9/gortsplib v0.0.0-20210424164934-262f28340026 h1:KQ8G/yC8r1aPSMvto+L0UQEgHWgU6d6H1pCk5JVm8w4= -github.com/aler9/gortsplib v0.0.0-20210424164934-262f28340026/go.mod h1:zVCg+TQX445hh1pC5QgAuuBvvXZMWLY1XYz626dGFqY= +github.com/aler9/gortsplib v0.0.0-20210507133648-caab8c908245 h1:07JnQQwggiBI522bixbZClaB9TZVXflIS2V+GZKkafs= +github.com/aler9/gortsplib v0.0.0-20210507133648-caab8c908245/go.mod h1:zVCg+TQX445hh1pC5QgAuuBvvXZMWLY1XYz626dGFqY= github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927 h1:95mXJ5fUCYpBRdSOnLAQAdJHHKxxxJrVCiaqDi965YQ= github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc= github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8= diff --git a/internal/clientrtmp/client.go b/internal/clientrtmp/client.go index c699e648..ce1aaeda 100644 --- a/internal/clientrtmp/client.go +++ b/internal/clientrtmp/client.go @@ -11,7 +11,6 @@ import ( "time" "github.com/aler9/gortsplib" - "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/rtpaac" @@ -36,23 +35,6 @@ const ( ptsOffset = 2 * time.Second ) -func ipEqualOrInRange(ip net.IP, ips []interface{}) bool { - for _, item := range ips { - switch titem := item.(type) { - case net.IP: - if titem.Equal(ip) { - return true - } - - case *net.IPNet: - if titem.Contains(ip) { - return true - } - } - } - return false -} - func pathNameAndQuery(inURL *url.URL) (string, url.Values) { // remove leading and trailing slashes inserted by OBS and some other clients tmp := strings.TrimRight(inURL.String(), "/") @@ -144,8 +126,8 @@ func (c *Client) Close() { close(c.terminate) } -// CloseRequest closes a Client. -func (c *Client) CloseRequest() { +// RequestClose closes a Client. +func (c *Client) RequestClose() { c.parent.OnClientClose(c) } @@ -206,7 +188,14 @@ func (c *Client) runRead() { pathName, query := pathNameAndQuery(c.conn.URL()) sres := make(chan readpublisher.SetupPlayRes) - c.pathMan.OnReadPublisherSetupPlay(readpublisher.SetupPlayReq{c, pathName, query, sres}) //nolint:govet + c.pathMan.OnReadPublisherSetupPlay(readpublisher.SetupPlayReq{ + Author: c, + PathName: pathName, + IP: c.ip(), + ValidateCredentials: func(authMethods []headers.AuthMethod, pathUser string, pathPass string) error { + return c.validateCredentials(authMethods, pathUser, pathPass, query) + }, + Res: sres}) res := <-sres if res.Err != nil { @@ -424,7 +413,16 @@ func (c *Client) runPublish() { pathName, query := pathNameAndQuery(c.conn.URL()) resc := make(chan readpublisher.AnnounceRes) - c.pathMan.OnReadPublisherAnnounce(readpublisher.AnnounceReq{c, pathName, tracks, query, resc}) //nolint:govet + c.pathMan.OnReadPublisherAnnounce(readpublisher.AnnounceReq{ + Author: c, + PathName: pathName, + Tracks: tracks, + IP: c.ip(), + ValidateCredentials: func(authMethods []headers.AuthMethod, pathUser string, pathPass string) error { + return c.validateCredentials(authMethods, pathUser, pathPass, query) + }, + Res: resc, + }) res := <-resc if res.Err != nil { @@ -466,7 +464,7 @@ func (c *Client) runPublish() { go func() { readerDone <- func() error { resc := make(chan readpublisher.RecordRes) - path.OnReadPublisherRecord(readpublisher.RecordReq{ReadPublisher: c, Res: resc}) + path.OnReadPublisherRecord(readpublisher.RecordReq{Author: c, Res: resc}) res := <-resc if res.Err != nil { @@ -599,32 +597,16 @@ func (c *Client) runPublish() { } } -// Authenticate performs an authentication. -func (c *Client) Authenticate(authMethods []headers.AuthMethod, - pathName string, ips []interface{}, - user string, pass string, req interface{}) error { +func (c *Client) validateCredentials( + authMethods []headers.AuthMethod, + pathUser string, + pathPass string, + query url.Values, +) error { - // validate ip - if ips != nil { - ip := c.ip() - - if !ipEqualOrInRange(ip, ips) { - c.log(logger.Info, "ERR: ip '%s' not allowed", ip) - - return readpublisher.ErrAuthCritical{&base.Response{ //nolint:govet - StatusCode: base.StatusUnauthorized, - }} - } - } - - // validate user - if user != "" { - values := req.(url.Values) - - if values.Get("user") != user || - values.Get("pass") != pass { - return readpublisher.ErrAuthCritical{nil} //nolint:govet - } + if query.Get("user") != pathUser || + query.Get("pass") != pathPass { + return readpublisher.ErrAuthCritical{} } return nil diff --git a/internal/clientrtsp/client.go b/internal/clientrtsp/client.go index 6b32190a..747de8b6 100644 --- a/internal/clientrtsp/client.go +++ b/internal/clientrtsp/client.go @@ -2,11 +2,8 @@ package clientrtsp import ( "errors" - "fmt" "io" "net" - "strconv" - "sync" "sync/atomic" "time" @@ -20,42 +17,28 @@ import ( "github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/readpublisher" "github.com/aler9/rtsp-simple-server/internal/stats" - "github.com/aler9/rtsp-simple-server/internal/streamproc" ) const ( - sessionID = "12345678" pauseAfterAuthError = 2 * time.Second ) -func ipEqualOrInRange(ip net.IP, ips []interface{}) bool { - for _, item := range ips { - switch titem := item.(type) { - case net.IP: - if titem.Equal(ip) { - return true - } +var errTerminated = errors.New("terminated") - case *net.IPNet: - if titem.Contains(ip) { - return true - } - } - } - return false +func isTeardownErr(err error) bool { + _, ok := err.(liberrors.ErrServerSessionTeardown) + return ok } // PathMan is implemented by pathman.PathMan. type PathMan interface { OnReadPublisherDescribe(readpublisher.DescribeReq) - OnReadPublisherSetupPlay(readpublisher.SetupPlayReq) - OnReadPublisherAnnounce(readpublisher.AnnounceReq) } // Parent is implemented by serverrtsp.Server. type Parent interface { Log(logger.Level, string, ...interface{}) - OnClientClose(*Client) + // OnClientClose(*Client) } // Client is a RTSP client. @@ -64,43 +47,27 @@ type Client struct { readTimeout time.Duration runOnConnect string runOnConnectRestart bool - protocols map[gortsplib.StreamProtocol]struct{} - wg *sync.WaitGroup + pathMan PathMan stats *stats.Stats conn *gortsplib.ServerConn - pathMan PathMan parent Parent - path readpublisher.Path + onConnectCmd *externalcmd.Cmd authUser string authPass string authValidator *auth.Validator authFailures int - - // read - setuppedTracks map[int]*gortsplib.Track - onReadCmd *externalcmd.Cmd - - // publish - sp *streamproc.StreamProc - onPublishCmd *externalcmd.Cmd - - // in - terminate chan struct{} } // New allocates a Client. func New( - isTLS bool, rtspAddress string, readTimeout time.Duration, runOnConnect string, runOnConnectRestart bool, - protocols map[gortsplib.StreamProtocol]struct{}, - wg *sync.WaitGroup, + pathMan PathMan, stats *stats.Stats, conn *gortsplib.ServerConn, - pathMan PathMan, parent Parent) *Client { c := &Client{ @@ -108,460 +75,177 @@ func New( readTimeout: readTimeout, runOnConnect: runOnConnect, runOnConnectRestart: runOnConnectRestart, - protocols: protocols, - wg: wg, + pathMan: pathMan, stats: stats, conn: conn, - pathMan: pathMan, parent: parent, - terminate: make(chan struct{}), } atomic.AddInt64(c.stats.CountClients, 1) c.log(logger.Info, "connected") - c.wg.Add(1) - go c.run() + if c.runOnConnect != "" { + _, port, _ := net.SplitHostPort(c.rtspAddress) + c.onConnectCmd = externalcmd.New(c.runOnConnect, c.runOnConnectRestart, externalcmd.Environment{ + Path: "", + Port: port, + }) + } return c } // Close closes a Client. -func (c *Client) Close() { +func (c *Client) Close(err error) { + if err != io.EOF && err != errTerminated && !isTeardownErr(err) { + c.log(logger.Info, "ERR: %v", err) + } + atomic.AddInt64(c.stats.CountClients, -1) c.log(logger.Info, "disconnected") - close(c.terminate) + + if c.onConnectCmd != nil { + c.onConnectCmd.Close() + } } -// CloseRequest closes a Client. -func (c *Client) CloseRequest() { - c.parent.OnClientClose(c) -} - -// IsReadPublisher implements readpublisher.ReadPublisher. -func (c *Client) IsReadPublisher() {} - -// IsSource implements source.Source. -func (c *Client) IsSource() {} - func (c *Client) log(level logger.Level, format string, args ...interface{}) { c.parent.Log(level, "[client %s] "+format, append([]interface{}{c.conn.NetConn().RemoteAddr().String()}, args...)...) } +// Conn returns the RTSP connection. +func (c *Client) Conn() *gortsplib.ServerConn { + return c.conn +} + func (c *Client) ip() net.IP { return c.conn.NetConn().RemoteAddr().(*net.TCPAddr).IP } -var errTerminated = errors.New("terminated") - -func (c *Client) run() { - defer c.wg.Done() - - if c.runOnConnect != "" { - _, port, _ := net.SplitHostPort(c.rtspAddress) - onConnectCmd := externalcmd.New(c.runOnConnect, c.runOnConnectRestart, externalcmd.Environment{ - Path: "", - Port: port, - }) - defer onConnectCmd.Close() - } - - onRequest := func(req *base.Request) { - c.log(logger.Debug, "[c->s] %v", req) - } - - onResponse := func(res *base.Response) { - c.log(logger.Debug, "[s->c] %v", res) - } - - onDescribe := func(ctx *gortsplib.ServerConnDescribeCtx) (*base.Response, []byte, error) { - resc := make(chan readpublisher.DescribeRes) - c.pathMan.OnReadPublisherDescribe(readpublisher.DescribeReq{c, ctx.Path, ctx.Req, resc}) //nolint:govet - res := <-resc - - if res.Err != nil { - switch terr := res.Err.(type) { - case readpublisher.ErrAuthNotCritical: - return terr.Response, nil, nil - - case readpublisher.ErrAuthCritical: - // wait some seconds to stop brute force attacks - select { - case <-time.After(pauseAfterAuthError): - case <-c.terminate: - } - return terr.Response, nil, errTerminated - - case readpublisher.ErrNoOnePublishing: - return &base.Response{ - StatusCode: base.StatusNotFound, - }, nil, res.Err - - default: - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, nil, res.Err - } - } - - if res.Redirect != "" { - return &base.Response{ - StatusCode: base.StatusMovedPermanently, - Header: base.Header{ - "Location": base.HeaderValue{res.Redirect}, - }, - }, nil, nil - } - - return &base.Response{ - StatusCode: base.StatusOK, - }, res.SDP, nil - } - - onAnnounce := func(ctx *gortsplib.ServerConnAnnounceCtx) (*base.Response, error) { - resc := make(chan readpublisher.AnnounceRes) - c.pathMan.OnReadPublisherAnnounce(readpublisher.AnnounceReq{c, ctx.Path, ctx.Tracks, ctx.Req, resc}) //nolint:govet - res := <-resc - - if res.Err != nil { - switch terr := res.Err.(type) { - case readpublisher.ErrAuthNotCritical: - return terr.Response, nil - - case readpublisher.ErrAuthCritical: - // wait some seconds to stop brute force attacks - select { - case <-time.After(pauseAfterAuthError): - case <-c.terminate: - } - return terr.Response, errTerminated - - default: - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, res.Err - } - } - - c.path = res.Path - - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onSetup := func(ctx *gortsplib.ServerConnSetupCtx) (*base.Response, error) { - if ctx.Transport.Protocol == gortsplib.StreamProtocolUDP { - if _, ok := c.protocols[gortsplib.StreamProtocolUDP]; !ok { - return &base.Response{ - StatusCode: base.StatusUnsupportedTransport, - }, nil - } - } else { - if _, ok := c.protocols[gortsplib.StreamProtocolTCP]; !ok { - return &base.Response{ - StatusCode: base.StatusUnsupportedTransport, - }, nil - } - } - - switch c.conn.State() { - case gortsplib.ServerConnStateInitial, gortsplib.ServerConnStatePrePlay: // play - resc := make(chan readpublisher.SetupPlayRes) - c.pathMan.OnReadPublisherSetupPlay(readpublisher.SetupPlayReq{c, ctx.Path, ctx.Req, resc}) //nolint:govet - res := <-resc - - if res.Err != nil { - switch terr := res.Err.(type) { - case readpublisher.ErrAuthNotCritical: - return terr.Response, nil - - case readpublisher.ErrAuthCritical: - // wait some seconds to stop brute force attacks - select { - case <-time.After(pauseAfterAuthError): - case <-c.terminate: - } - return terr.Response, errTerminated - - case readpublisher.ErrNoOnePublishing: - return &base.Response{ - StatusCode: base.StatusNotFound, - }, res.Err - - default: - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, res.Err - } - } - - c.path = res.Path - - if ctx.TrackID >= len(res.Tracks) { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("track %d does not exist", ctx.TrackID) - } - - if c.setuppedTracks == nil { - c.setuppedTracks = make(map[int]*gortsplib.Track) - } - c.setuppedTracks[ctx.TrackID] = res.Tracks[ctx.TrackID] - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{sessionID}, - }, - }, nil - } - - onPlay := func(ctx *gortsplib.ServerConnPlayCtx) (*base.Response, error) { - h := base.Header{ - "Session": base.HeaderValue{sessionID}, - } - - if c.conn.State() == gortsplib.ServerConnStatePrePlay { - if ctx.Path != c.path.Name() { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), ctx.Path) - } - - res := c.playStart() - - // add RTP-Info - var ri headers.RTPInfo - for trackID, ti := range res.TrackInfos { - if ti.LastTimeNTP == 0 { - continue - } - - track, ok := c.setuppedTracks[trackID] - if !ok { - continue - } - - u := &base.URL{ - Scheme: ctx.Req.URL.Scheme, - User: ctx.Req.URL.User, - Host: ctx.Req.URL.Host, - Path: "/" + c.path.Name() + "/trackID=" + strconv.FormatInt(int64(trackID), 10), - } - - clockRate, _ := track.ClockRate() - ts := uint32(uint64(ti.LastTimeRTP) + - uint64(time.Since(time.Unix(ti.LastTimeNTP, 0)).Seconds()*float64(clockRate))) - lsn := ti.LastSequenceNumber - - ri = append(ri, &headers.RTPInfoEntry{ - URL: u.String(), - SequenceNumber: &lsn, - Timestamp: &ts, - }) - } - if len(ri) > 0 { - h["RTP-Info"] = ri.Write() - } - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: h, - }, nil - } - - onRecord := func(ctx *gortsplib.ServerConnRecordCtx) (*base.Response, error) { - if ctx.Path != c.path.Name() { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), ctx.Path) - } - - err := c.recordStart() - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, err - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{sessionID}, - }, - }, nil - } - - onPause := func(ctx *gortsplib.ServerConnPauseCtx) (*base.Response, error) { - switch c.conn.State() { - case gortsplib.ServerConnStatePlay: - c.playStop() - res := make(chan struct{}) - c.path.OnReadPublisherPause(readpublisher.PauseReq{c, res}) //nolint:govet - <-res - - case gortsplib.ServerConnStateRecord: - c.recordStop() - res := make(chan struct{}) - c.path.OnReadPublisherPause(readpublisher.PauseReq{c, res}) //nolint:govet - <-res - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{sessionID}, - }, - }, nil - } - - onFrame := func(trackID int, streamType gortsplib.StreamType, payload []byte) { - if c.conn.State() != gortsplib.ServerConnStateRecord { - return - } - - c.sp.OnFrame(trackID, streamType, payload) - } - - readDone := c.conn.Read(gortsplib.ServerConnReadHandlers{ - OnRequest: onRequest, - OnResponse: onResponse, - OnDescribe: onDescribe, - OnAnnounce: onAnnounce, - OnSetup: onSetup, - OnPlay: onPlay, - OnRecord: onRecord, - OnPause: onPause, - OnFrame: onFrame, - }) - - select { - case err := <-readDone: - c.conn.Close() - - if err != io.EOF && err != errTerminated { - if _, ok := err.(liberrors.ErrServerTeardown); !ok { - c.log(logger.Info, "ERR: %s", err) - } - } - - switch c.conn.State() { - case gortsplib.ServerConnStatePlay: - c.playStop() - - case gortsplib.ServerConnStateRecord: - c.recordStop() - } - - if c.path != nil { - res := make(chan struct{}) - c.path.OnReadPublisherRemove(readpublisher.RemoveReq{c, res}) //nolint:govet - <-res - c.path = nil - } - - c.parent.OnClientClose(c) - <-c.terminate - - case <-c.terminate: - c.conn.Close() - <-readDone - - switch c.conn.State() { - case gortsplib.ServerConnStatePlay: - c.playStop() - - case gortsplib.ServerConnStateRecord: - c.recordStop() - } - - if c.path != nil { - res := make(chan struct{}) - c.path.OnReadPublisherRemove(readpublisher.RemoveReq{c, res}) //nolint:govet - <-res - c.path = nil - } - } +// OnRequest is called by serverrtsp.Server. +func (c *Client) OnRequest(req *base.Request) { + c.log(logger.Debug, "[c->s] %v", req) } -// Authenticate performs an authentication. -func (c *Client) Authenticate(authMethods []headers.AuthMethod, - pathName string, ips []interface{}, - user string, pass string, req interface{}) error { +// OnResponse is called by serverrtsp.Server. +func (c *Client) OnResponse(res *base.Response) { + c.log(logger.Debug, "[s->c] %v", res) +} - // validate ip - if ips != nil { - ip := c.ip() +// OnDescribe is called by serverrtsp.Server. +func (c *Client) OnDescribe(ctx *gortsplib.ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) { + resc := make(chan readpublisher.DescribeRes) + c.pathMan.OnReadPublisherDescribe(readpublisher.DescribeReq{ + PathName: ctx.Path, + URL: ctx.Req.URL, + IP: c.ip(), + ValidateCredentials: func(authMethods []headers.AuthMethod, pathUser string, pathPass string) error { + return c.ValidateCredentials(authMethods, pathUser, pathPass, ctx.Path, ctx.Req) + }, + Res: resc, + }) + res := <-resc - if !ipEqualOrInRange(ip, ips) { - c.log(logger.Info, "ERR: ip '%s' not allowed", ip) + if res.Err != nil { + switch terr := res.Err.(type) { + case readpublisher.ErrAuthNotCritical: + return terr.Response, nil, nil - return readpublisher.ErrAuthCritical{&base.Response{ //nolint:govet - StatusCode: base.StatusUnauthorized, - }} + case readpublisher.ErrAuthCritical: + c.log(logger.Info, "ERR: %v", terr.Message) + + // wait some seconds to stop brute force attacks + <-time.After(pauseAfterAuthError) + return terr.Response, nil, errTerminated + + case readpublisher.ErrNoOnePublishing: + return &base.Response{ + StatusCode: base.StatusNotFound, + }, nil, res.Err + + default: + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, nil, res.Err } } - // validate user - if user != "" { - reqRTSP := req.(*base.Request) + if res.Redirect != "" { + return &base.Response{ + StatusCode: base.StatusMovedPermanently, + Header: base.Header{ + "Location": base.HeaderValue{res.Redirect}, + }, + }, nil, nil + } - // reset authValidator every time the credentials change - if c.authValidator == nil || c.authUser != user || c.authPass != pass { - c.authUser = user - c.authPass = pass - c.authValidator = auth.NewValidator(user, pass, authMethods) + return &base.Response{ + StatusCode: base.StatusOK, + }, res.SDP, nil +} + +// ValidateCredentials allows to validate the credentials of a path. +func (c *Client) ValidateCredentials( + authMethods []headers.AuthMethod, + pathUser string, + pathPass string, + pathName string, + req *base.Request, +) error { + + // reset authValidator every time the credentials change + if c.authValidator == nil || c.authUser != pathUser || c.authPass != pathPass { + c.authUser = pathUser + c.authPass = pathPass + c.authValidator = auth.NewValidator(pathUser, pathPass, authMethods) + } + + // VLC strips the control attribute + // provide an alternative URL without the control attribute + altURL := func() *base.URL { + if req.Method != base.Setup { + return nil } + return &base.URL{ + Scheme: req.URL.Scheme, + Host: req.URL.Host, + Path: "/" + pathName + "/", + } + }() - // VLC strips the control attribute - // provide an alternative URL without the control attribute - altURL := func() *base.URL { - if reqRTSP.Method != base.Setup { - return nil - } - return &base.URL{ - Scheme: reqRTSP.URL.Scheme, - Host: reqRTSP.URL.Host, - Path: "/" + pathName + "/", - } - }() + err := c.authValidator.ValidateHeader(req.Header["Authorization"], + req.Method, req.URL, altURL) + if err != nil { + c.authFailures++ - err := c.authValidator.ValidateHeader(reqRTSP.Header["Authorization"], - reqRTSP.Method, reqRTSP.URL, altURL) - if err != nil { - c.authFailures++ - - // vlc with login prompt sends 4 requests: - // 1) without credentials - // 2) with password but without username - // 3) without credentials - // 4) with password and username - // therefore we must allow up to 3 failures - if c.authFailures > 3 { - c.log(logger.Info, "ERR: unauthorized: %s", err) - - return readpublisher.ErrAuthCritical{&base.Response{ //nolint:govet + // vlc with login prompt sends 4 requests: + // 1) without credentials + // 2) with password but without username + // 3) without credentials + // 4) with password and username + // therefore we must allow up to 3 failures + if c.authFailures > 3 { + return readpublisher.ErrAuthCritical{ + Message: "unauthorized: " + err.Error(), + Response: &base.Response{ StatusCode: base.StatusUnauthorized, Header: base.Header{ "WWW-Authenticate": c.authValidator.GenerateHeader(), }, - }} - } - - if c.authFailures > 1 { - c.log(logger.Debug, "WARN: unauthorized: %s", err) - } - - return readpublisher.ErrAuthNotCritical{&base.Response{ //nolint:govet - StatusCode: base.StatusUnauthorized, - Header: base.Header{ - "WWW-Authenticate": c.authValidator.GenerateHeader(), }, - }} + } } + + if c.authFailures > 1 { + c.log(logger.Debug, "WARN: unauthorized: %s", err) + } + + return readpublisher.ErrAuthNotCritical{&base.Response{ //nolint:govet + StatusCode: base.StatusUnauthorized, + Header: base.Header{ + "WWW-Authenticate": c.authValidator.GenerateHeader(), + }, + }} } // login successful, reset authFailures @@ -569,88 +253,3 @@ func (c *Client) Authenticate(authMethods []headers.AuthMethod, return nil } - -func (c *Client) playStart() readpublisher.PlayRes { - resc := make(chan readpublisher.PlayRes) - c.path.OnReadPublisherPlay(readpublisher.PlayReq{c, resc}) //nolint:govet - res := <-resc - - tracksLen := len(c.conn.SetuppedTracks()) - - c.log(logger.Info, "is reading from path '%s', %d %s with %s", - c.path.Name(), - tracksLen, - func() string { - if tracksLen == 1 { - return "track" - } - return "tracks" - }(), - *c.conn.StreamProtocol()) - - if c.path.Conf().RunOnRead != "" { - _, port, _ := net.SplitHostPort(c.rtspAddress) - c.onReadCmd = externalcmd.New(c.path.Conf().RunOnRead, c.path.Conf().RunOnReadRestart, externalcmd.Environment{ - Path: c.path.Name(), - Port: port, - }) - } - - return res -} - -func (c *Client) playStop() { - if c.path.Conf().RunOnRead != "" { - c.onReadCmd.Close() - } -} - -func (c *Client) recordStart() error { - resc := make(chan readpublisher.RecordRes) - c.path.OnReadPublisherRecord(readpublisher.RecordReq{ReadPublisher: c, Res: resc}) - res := <-resc - - if res.Err != nil { - return res.Err - } - - c.sp = res.SP - - tracksLen := len(c.conn.AnnouncedTracks()) - - c.log(logger.Info, "is publishing to path '%s', %d %s with %s", - c.path.Name(), - tracksLen, - func() string { - if tracksLen == 1 { - return "track" - } - return "tracks" - }(), - *c.conn.StreamProtocol()) - - if c.path.Conf().RunOnPublish != "" { - _, port, _ := net.SplitHostPort(c.rtspAddress) - c.onPublishCmd = externalcmd.New(c.path.Conf().RunOnPublish, c.path.Conf().RunOnPublishRestart, externalcmd.Environment{ - Path: c.path.Name(), - Port: port, - }) - } - - return nil -} - -func (c *Client) recordStop() { - if c.path.Conf().RunOnPublish != "" { - c.onPublishCmd.Close() - } -} - -// OnFrame implements path.Reader. -func (c *Client) OnFrame(trackID int, streamType gortsplib.StreamType, payload []byte) { - if _, ok := c.conn.SetuppedTracks()[trackID]; !ok { - return - } - - c.conn.WriteFrame(trackID, streamType, payload) -} diff --git a/internal/conf/path.go b/internal/conf/path.go index f2bae434..a0707bb0 100644 --- a/internal/conf/path.go +++ b/internal/conf/path.go @@ -83,12 +83,12 @@ type PathConf struct { // authentication PublishUser string `yaml:"publishUser"` PublishPass string `yaml:"publishPass"` - PublishIps []string `yaml:"publishIps"` - PublishIpsParsed []interface{} `yaml:"-" json:"-"` + PublishIPs []string `yaml:"publishIps"` + PublishIPsParsed []interface{} `yaml:"-" json:"-"` ReadUser string `yaml:"readUser"` ReadPass string `yaml:"readPass"` - ReadIps []string `yaml:"readIps"` - ReadIpsParsed []interface{} `yaml:"-" json:"-"` + ReadIPs []string `yaml:"readIps"` + ReadIPsParsed []interface{} `yaml:"-" json:"-"` // custom commands RunOnInit string `yaml:"runOnInit"` @@ -260,12 +260,12 @@ func (pconf *PathConf) fillAndCheck(name string) error { return fmt.Errorf("publish password contains unsupported characters (supported are %s)", userPassSupportedChars) } } - if len(pconf.PublishIps) == 0 { - pconf.PublishIps = nil + if len(pconf.PublishIPs) == 0 { + pconf.PublishIPs = nil } var err error - pconf.PublishIpsParsed, err = func() ([]interface{}, error) { - if len(pconf.PublishIps) == 0 { + pconf.PublishIPsParsed, err = func() ([]interface{}, error) { + if len(pconf.PublishIPs) == 0 { return nil, nil } @@ -273,7 +273,7 @@ func (pconf *PathConf) fillAndCheck(name string) error { return nil, fmt.Errorf("'publishIps' is useless when source is not 'record', since the stream is not provided by a publisher, but by a fixed source") } - return parseIPCidrList(pconf.PublishIps) + return parseIPCidrList(pconf.PublishIPs) }() if err != nil { return err @@ -292,11 +292,11 @@ func (pconf *PathConf) fillAndCheck(name string) error { return fmt.Errorf("read password contains unsupported characters (supported are %s)", userPassSupportedChars) } } - if len(pconf.ReadIps) == 0 { - pconf.ReadIps = nil + if len(pconf.ReadIPs) == 0 { + pconf.ReadIPs = nil } - pconf.ReadIpsParsed, err = func() ([]interface{}, error) { - return parseIPCidrList(pconf.ReadIps) + pconf.ReadIPsParsed, err = func() ([]interface{}, error) { + return parseIPCidrList(pconf.ReadIPs) }() if err != nil { return err diff --git a/internal/confwatcher/confwatcher.go b/internal/confwatcher/confwatcher.go index 8bfea33e..720d5c8b 100644 --- a/internal/confwatcher/confwatcher.go +++ b/internal/confwatcher/confwatcher.go @@ -105,7 +105,7 @@ outer: close(w.signal) } -// Watch returns a channel that is called when the configuration file has changed. +// Watch returns a channel that is called after the configuration file has changed. func (w *ConfWatcher) Watch() chan struct{} { return w.signal } diff --git a/internal/converterhls/converter.go b/internal/converterhls/converter.go index e9bd1760..9534a939 100644 --- a/internal/converterhls/converter.go +++ b/internal/converterhls/converter.go @@ -13,7 +13,6 @@ import ( "time" "github.com/aler9/gortsplib" - "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/rtpaac" "github.com/aler9/gortsplib/pkg/rtph264" @@ -189,8 +188,8 @@ func (c *Converter) Close() { close(c.terminate) } -// CloseRequest closes a Converter. -func (c *Converter) CloseRequest() { +// RequestClose closes a Converter. +func (c *Converter) RequestClose() { c.parent.OnConverterClose(c) } @@ -222,7 +221,13 @@ func (c *Converter) run() { err := func() error { pres := make(chan readpublisher.SetupPlayRes) - c.pathMan.OnReadPublisherSetupPlay(readpublisher.SetupPlayReq{c, c.pathName, nil, pres}) //nolint:govet + c.pathMan.OnReadPublisherSetupPlay(readpublisher.SetupPlayReq{ + Author: c, + PathName: c.pathName, + IP: nil, + ValidateCredentials: nil, + Res: pres, + }) res := <-pres if res.Err != nil { @@ -522,10 +527,10 @@ func (c *Converter) runRequestHandler(done chan struct{}) { conf := c.path.Conf() - if conf.ReadIpsParsed != nil { + if conf.ReadIPsParsed != nil { tmp, _, _ := net.SplitHostPort(req.Req.RemoteAddr) ip := net.ParseIP(tmp) - if !ipEqualOrInRange(ip, conf.ReadIpsParsed) { + if !ipEqualOrInRange(ip, conf.ReadIPsParsed) { c.log(logger.Info, "ERR: ip '%s' not allowed", ip) req.W.WriteHeader(http.StatusUnauthorized) req.Res <- nil @@ -597,13 +602,6 @@ func (c *Converter) OnRequest(req Request) { c.request <- req } -// Authenticate performs an authentication. -func (c *Converter) Authenticate(authMethods []headers.AuthMethod, - pathName string, ips []interface{}, - user string, pass string, req interface{}) error { - return nil -} - // OnFrame implements path.Reader. func (c *Converter) OnFrame(trackID int, streamType gortsplib.StreamType, payload []byte) { if streamType == gortsplib.StreamTypeRTP { diff --git a/internal/path/path.go b/internal/path/path.go index 5986e6ad..744b0866 100644 --- a/internal/path/path.go +++ b/internal/path/path.go @@ -38,14 +38,14 @@ type sourceRedirect struct{} func (*sourceRedirect) IsSource() {} -type clientState int +type readPublisherState int const ( - clientStatePrePlay clientState = iota - clientStatePlay - clientStatePreRecord - clientStateRecord - clientStatePreRemove + readPublisherStatePrePlay readPublisherState = iota + readPublisherStatePlay + readPublisherStatePreRecord + readPublisherStateRecord + readPublisherStatePreRemove ) type sourceState int @@ -70,7 +70,7 @@ type Path struct { stats *stats.Stats parent Parent - readPublishers map[readpublisher.ReadPublisher]clientState + readPublishers map[readpublisher.ReadPublisher]readPublisherState readPublishersWg sync.WaitGroup describeRequests []readpublisher.DescribeReq setupPlayRequests []readpublisher.SetupPlayReq @@ -92,13 +92,13 @@ type Path struct { // in extSourceSetReady chan source.ExtSetReadyReq extSourceSetNotReady chan source.ExtSetNotReadyReq - clientDescribe chan readpublisher.DescribeReq - clientSetupPlay chan readpublisher.SetupPlayReq - clientAnnounce chan readpublisher.AnnounceReq - clientPlay chan readpublisher.PlayReq - clientRecord chan readpublisher.RecordReq - clientPause chan readpublisher.PauseReq - clientRemove chan readpublisher.RemoveReq + describeReq chan readpublisher.DescribeReq + setupPlayReq chan readpublisher.SetupPlayReq + announceReq chan readpublisher.AnnounceReq + playReq chan readpublisher.PlayReq + recordReq chan readpublisher.RecordReq + pauseReq chan readpublisher.PauseReq + removeReq chan readpublisher.RemoveReq terminate chan struct{} } @@ -128,7 +128,7 @@ func New( wg: wg, stats: stats, parent: parent, - readPublishers: make(map[readpublisher.ReadPublisher]clientState), + readPublishers: make(map[readpublisher.ReadPublisher]readPublisherState), readers: newReadersMap(), describeTimer: newEmptyTimer(), sourceCloseTimer: newEmptyTimer(), @@ -136,13 +136,13 @@ func New( closeTimer: newEmptyTimer(), extSourceSetReady: make(chan source.ExtSetReadyReq), extSourceSetNotReady: make(chan source.ExtSetNotReadyReq), - clientDescribe: make(chan readpublisher.DescribeReq), - clientSetupPlay: make(chan readpublisher.SetupPlayReq), - clientAnnounce: make(chan readpublisher.AnnounceReq), - clientPlay: make(chan readpublisher.PlayReq), - clientRecord: make(chan readpublisher.RecordReq), - clientPause: make(chan readpublisher.PauseReq), - clientRemove: make(chan readpublisher.RemoveReq), + describeReq: make(chan readpublisher.DescribeReq), + setupPlayReq: make(chan readpublisher.SetupPlayReq), + announceReq: make(chan readpublisher.AnnounceReq), + playReq: make(chan readpublisher.PlayReq), + recordReq: make(chan readpublisher.RecordReq), + pauseReq: make(chan readpublisher.PauseReq), + removeReq: make(chan readpublisher.RemoveReq), terminate: make(chan struct{}), } @@ -233,35 +233,35 @@ outer: pa.onSourceSetNotReady() close(req.Res) - case req := <-pa.clientDescribe: + case req := <-pa.describeReq: pa.onReadPublisherDescribe(req) - case req := <-pa.clientSetupPlay: + case req := <-pa.setupPlayReq: pa.onReadPublisherSetupPlay(req) - case req := <-pa.clientAnnounce: + case req := <-pa.announceReq: pa.onReadPublisherAnnounce(req) - case req := <-pa.clientPlay: + case req := <-pa.playReq: pa.onReadPublisherPlay(req) - case req := <-pa.clientRecord: + case req := <-pa.recordReq: pa.onReadPublisherRecord(req) - case req := <-pa.clientPause: + case req := <-pa.pauseReq: pa.onReadPublisherPause(req) - case req := <-pa.clientRemove: - if _, ok := pa.readPublishers[req.ReadPublisher]; !ok { + case req := <-pa.removeReq: + if _, ok := pa.readPublishers[req.Author]; !ok { close(req.Res) continue } - if pa.readPublishers[req.ReadPublisher] != clientStatePreRemove { - pa.removeReadPublisher(req.ReadPublisher) + if pa.readPublishers[req.Author] != readPublisherStatePreRemove { + pa.removeReadPublisher(req.Author) } - delete(pa.readPublishers, req.ReadPublisher) + delete(pa.readPublishers, req.Author) pa.readPublishersWg.Done() close(req.Res) @@ -300,29 +300,29 @@ outer: } for c, state := range pa.readPublishers { - if state != clientStatePreRemove { + if state != readPublisherStatePreRemove { switch state { - case clientStatePlay: + case readPublisherStatePlay: atomic.AddInt64(pa.stats.CountReaders, -1) pa.readers.remove(c) - case clientStateRecord: + case readPublisherStateRecord: atomic.AddInt64(pa.stats.CountPublishers, -1) } - c.CloseRequest() + c.RequestClose() } } pa.readPublishersWg.Wait() close(pa.extSourceSetReady) close(pa.extSourceSetNotReady) - close(pa.clientDescribe) - close(pa.clientSetupPlay) - close(pa.clientAnnounce) - close(pa.clientPlay) - close(pa.clientRecord) - close(pa.clientPause) - close(pa.clientRemove) + close(pa.describeReq) + close(pa.setupPlayReq) + close(pa.announceReq) + close(pa.playReq) + close(pa.recordReq) + close(pa.pauseReq) + close(pa.removeReq) } func (pa *Path) exhaustChannels() { @@ -341,48 +341,48 @@ func (pa *Path) exhaustChannels() { } close(req.Res) - case req, ok := <-pa.clientDescribe: + case req, ok := <-pa.describeReq: if !ok { return } req.Res <- readpublisher.DescribeRes{nil, "", fmt.Errorf("terminated")} //nolint:govet - case req, ok := <-pa.clientSetupPlay: + case req, ok := <-pa.setupPlayReq: if !ok { return } req.Res <- readpublisher.SetupPlayRes{nil, nil, fmt.Errorf("terminated")} //nolint:govet - case req, ok := <-pa.clientAnnounce: + case req, ok := <-pa.announceReq: if !ok { return } req.Res <- readpublisher.AnnounceRes{nil, fmt.Errorf("terminated")} //nolint:govet - case req, ok := <-pa.clientPlay: + case req, ok := <-pa.playReq: if !ok { return } close(req.Res) - case req, ok := <-pa.clientRecord: + case req, ok := <-pa.recordReq: if !ok { return } close(req.Res) - case req, ok := <-pa.clientPause: + case req, ok := <-pa.pauseReq: if !ok { return } close(req.Res) - case req, ok := <-pa.clientRemove: + case req, ok := <-pa.removeReq: if !ok { return } - if _, ok := pa.readPublishers[req.ReadPublisher]; !ok { + if _, ok := pa.readPublishers[req.Author]; !ok { close(req.Res) continue } @@ -428,7 +428,7 @@ func (pa *Path) startExternalSource() { func (pa *Path) hasReadPublishers() bool { for _, state := range pa.readPublishers { - if state != clientStatePreRemove { + if state != readPublisherStatePreRemove { return true } } @@ -437,28 +437,28 @@ func (pa *Path) hasReadPublishers() bool { func (pa *Path) hasReadPublishersNotSources() bool { for c, state := range pa.readPublishers { - if state != clientStatePreRemove && c != pa.source { + if state != readPublisherStatePreRemove && c != pa.source { return true } } return false } -func (pa *Path) addReadPublisher(c readpublisher.ReadPublisher, state clientState) { +func (pa *Path) addReadPublisher(c readpublisher.ReadPublisher, state readPublisherState) { pa.readPublishers[c] = state pa.readPublishersWg.Add(1) } func (pa *Path) removeReadPublisher(c readpublisher.ReadPublisher) { state := pa.readPublishers[c] - pa.readPublishers[c] = clientStatePreRemove + pa.readPublishers[c] = readPublisherStatePreRemove switch state { - case clientStatePlay: + case readPublisherStatePlay: atomic.AddInt64(pa.stats.CountReaders, -1) pa.readers.remove(c) - case clientStateRecord: + case readPublisherStateRecord: atomic.AddInt64(pa.stats.CountPublishers, -1) pa.onSourceSetNotReady() } @@ -468,9 +468,9 @@ func (pa *Path) removeReadPublisher(c readpublisher.ReadPublisher) { // close all readPublishers that are reading or waiting to read for oc, state := range pa.readPublishers { - if state != clientStatePreRemove { + if state != readPublisherStatePreRemove { pa.removeReadPublisher(oc) - oc.CloseRequest() + oc.RequestClose() } } } @@ -508,9 +508,9 @@ func (pa *Path) onSourceSetNotReady() { // close all readPublishers that are reading or waiting to read for c, state := range pa.readPublishers { - if c != pa.source && state != clientStatePreRemove { + if c != pa.source && state != readPublisherStatePreRemove { pa.removeReadPublisher(c) - c.CloseRequest() + c.RequestClose() } } } @@ -557,11 +557,6 @@ func (pa *Path) fixedPublisherStart() { } func (pa *Path) onReadPublisherDescribe(req readpublisher.DescribeReq) { - if _, ok := pa.readPublishers[req.ReadPublisher]; ok { - req.Res <- readpublisher.DescribeRes{nil, "", fmt.Errorf("already subscribed")} //nolint:govet - return - } - pa.fixedPublisherStart() pa.scheduleClose() @@ -584,9 +579,9 @@ func (pa *Path) onReadPublisherDescribe(req readpublisher.DescribeReq) { fallbackURL := func() string { if strings.HasPrefix(pa.conf.Fallback, "/") { ur := base.URL{ - Scheme: req.Data.URL.Scheme, - User: req.Data.URL.User, - Host: req.Data.URL.Host, + Scheme: req.URL.Scheme, + User: req.URL.User, + Host: req.URL.Host, Path: pa.conf.Fallback, } return ur.String() @@ -622,7 +617,7 @@ func (pa *Path) onReadPublisherSetupPlay(req readpublisher.SetupPlayReq) { } func (pa *Path) onReadPublisherSetupPlayPost(req readpublisher.SetupPlayReq) { - if _, ok := pa.readPublishers[req.ReadPublisher]; !ok { + if _, ok := pa.readPublishers[req.Author]; !ok { // prevent on-demand source from closing if pa.sourceCloseTimerStarted { pa.sourceCloseTimer = newEmptyTimer() @@ -635,7 +630,7 @@ func (pa *Path) onReadPublisherSetupPlayPost(req readpublisher.SetupPlayReq) { pa.runOnDemandCloseTimerStarted = false } - pa.addReadPublisher(req.ReadPublisher, clientStatePrePlay) + pa.addReadPublisher(req.Author, readPublisherStatePrePlay) } req.Res <- readpublisher.SetupPlayRes{pa, pa.sourceTracks, nil} //nolint:govet @@ -643,14 +638,14 @@ func (pa *Path) onReadPublisherSetupPlayPost(req readpublisher.SetupPlayReq) { func (pa *Path) onReadPublisherPlay(req readpublisher.PlayReq) { atomic.AddInt64(pa.stats.CountReaders, 1) - pa.readPublishers[req.ReadPublisher] = clientStatePlay - pa.readers.add(req.ReadPublisher) + pa.readPublishers[req.Author] = readPublisherStatePlay + pa.readers.add(req.Author) req.Res <- readpublisher.PlayRes{TrackInfos: pa.sp.TrackInfos()} } func (pa *Path) onReadPublisherAnnounce(req readpublisher.AnnounceReq) { - if _, ok := pa.readPublishers[req.ReadPublisher]; ok { + if _, ok := pa.readPublishers[req.Author]; ok { req.Res <- readpublisher.AnnounceRes{nil, fmt.Errorf("already publishing or reading")} //nolint:govet return } @@ -669,7 +664,7 @@ func (pa *Path) onReadPublisherAnnounce(req readpublisher.AnnounceReq) { pa.Log(logger.Info, "disconnecting existing publisher") curPublisher := pa.source.(readpublisher.ReadPublisher) pa.removeReadPublisher(curPublisher) - curPublisher.CloseRequest() + curPublisher.RequestClose() // prevent path closure if pa.closeTimerStarted { @@ -679,21 +674,21 @@ func (pa *Path) onReadPublisherAnnounce(req readpublisher.AnnounceReq) { } } - pa.addReadPublisher(req.ReadPublisher, clientStatePreRecord) + pa.addReadPublisher(req.Author, readPublisherStatePreRecord) - pa.source = req.ReadPublisher + pa.source = req.Author pa.sourceTracks = req.Tracks req.Res <- readpublisher.AnnounceRes{pa, nil} //nolint:govet } func (pa *Path) onReadPublisherRecord(req readpublisher.RecordReq) { - if state, ok := pa.readPublishers[req.ReadPublisher]; !ok || state != clientStatePreRecord { + if state, ok := pa.readPublishers[req.Author]; !ok || state != readPublisherStatePreRecord { req.Res <- readpublisher.RecordRes{SP: nil, Err: fmt.Errorf("not recording anymore")} return } atomic.AddInt64(pa.stats.CountPublishers, 1) - pa.readPublishers[req.ReadPublisher] = clientStateRecord + pa.readPublishers[req.Author] = readPublisherStateRecord pa.onSourceSetReady() pa.sp = streamproc.New(pa, len(pa.sourceTracks)) @@ -702,20 +697,20 @@ func (pa *Path) onReadPublisherRecord(req readpublisher.RecordReq) { } func (pa *Path) onReadPublisherPause(req readpublisher.PauseReq) { - state, ok := pa.readPublishers[req.ReadPublisher] + state, ok := pa.readPublishers[req.Author] if !ok { close(req.Res) return } - if state == clientStatePlay { + if state == readPublisherStatePlay { atomic.AddInt64(pa.stats.CountReaders, -1) - pa.readPublishers[req.ReadPublisher] = clientStatePrePlay - pa.readers.remove(req.ReadPublisher) + pa.readPublishers[req.Author] = readPublisherStatePrePlay + pa.readers.remove(req.Author) - } else if state == clientStateRecord { + } else if state == readPublisherStateRecord { atomic.AddInt64(pa.stats.CountPublishers, -1) - pa.readPublishers[req.ReadPublisher] = clientStatePreRecord + pa.readPublishers[req.Author] = readPublisherStatePreRecord pa.onSourceSetNotReady() } @@ -796,37 +791,37 @@ func (pa *Path) OnExtSourceSetNotReady(req source.ExtSetNotReadyReq) { // OnPathManDescribe is called by pathman.PathMan. func (pa *Path) OnPathManDescribe(req readpublisher.DescribeReq) { - pa.clientDescribe <- req + pa.describeReq <- req } // OnPathManSetupPlay is called by pathman.PathMan. func (pa *Path) OnPathManSetupPlay(req readpublisher.SetupPlayReq) { - pa.clientSetupPlay <- req + pa.setupPlayReq <- req } // OnPathManAnnounce is called by pathman.PathMan. func (pa *Path) OnPathManAnnounce(req readpublisher.AnnounceReq) { - pa.clientAnnounce <- req + pa.announceReq <- req } // OnReadPublisherRemove is called by a readpublisher. func (pa *Path) OnReadPublisherRemove(req readpublisher.RemoveReq) { - pa.clientRemove <- req + pa.removeReq <- req } // OnReadPublisherPlay is called by a readpublisher. func (pa *Path) OnReadPublisherPlay(req readpublisher.PlayReq) { - pa.clientPlay <- req + pa.playReq <- req } // OnReadPublisherRecord is called by a readpublisher. func (pa *Path) OnReadPublisherRecord(req readpublisher.RecordReq) { - pa.clientRecord <- req + pa.recordReq <- req } // OnReadPublisherPause is called by a readpublisher. func (pa *Path) OnReadPublisherPause(req readpublisher.PauseReq) { - pa.clientPause <- req + pa.pauseReq <- req } // OnSPFrame is called by streamproc.StreamProc. diff --git a/internal/pathman/pathman.go b/internal/pathman/pathman.go index febf89a8..8fd1fbbe 100644 --- a/internal/pathman/pathman.go +++ b/internal/pathman/pathman.go @@ -2,9 +2,11 @@ package pathman import ( "fmt" + "net" "sync" "time" + "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/rtsp-simple-server/internal/conf" @@ -14,6 +16,23 @@ import ( "github.com/aler9/rtsp-simple-server/internal/stats" ) +func ipEqualOrInRange(ip net.IP, ips []interface{}) bool { + for _, item := range ips { + switch titem := item.(type) { + case net.IP: + if titem.Equal(ip) { + return true + } + + case *net.IPNet: + if titem.Contains(ip) { + return true + } + } + } + return false +} + // Parent is implemented by program. type Parent interface { Log(logger.Level, string, ...interface{}) @@ -149,13 +168,14 @@ outer: continue } - err = req.ReadPublisher.Authenticate( - pm.authMethods, + err = pm.authenticate( + req.IP, + req.ValidateCredentials, req.PathName, - pathConf.ReadIpsParsed, + pathConf.ReadIPsParsed, pathConf.ReadUser, pathConf.ReadPass, - req.Data) + ) if err != nil { req.Res <- readpublisher.DescribeRes{nil, "", err} //nolint:govet continue @@ -175,13 +195,14 @@ outer: continue } - err = req.ReadPublisher.Authenticate( - pm.authMethods, + err = pm.authenticate( + req.IP, + req.ValidateCredentials, req.PathName, - pathConf.ReadIpsParsed, + pathConf.ReadIPsParsed, pathConf.ReadUser, pathConf.ReadPass, - req.Data) + ) if err != nil { req.Res <- readpublisher.SetupPlayRes{nil, nil, err} //nolint:govet continue @@ -201,13 +222,14 @@ outer: continue } - err = req.ReadPublisher.Authenticate( - pm.authMethods, + err = pm.authenticate( + req.IP, + req.ValidateCredentials, req.PathName, - pathConf.PublishIpsParsed, + pathConf.PublishIPsParsed, pathConf.PublishUser, pathConf.PublishPass, - req.Data) + ) if err != nil { req.Res <- readpublisher.AnnounceRes{nil, err} //nolint:govet continue @@ -339,3 +361,35 @@ func (pm *PathManager) OnReadPublisherAnnounce(req readpublisher.AnnounceReq) { func (pm *PathManager) OnReadPublisherSetupPlay(req readpublisher.SetupPlayReq) { pm.clientSetupPlay <- req } + +func (pm *PathManager) authenticate( + ip net.IP, + validateCredentials func(authMethods []headers.AuthMethod, pathUser string, pathPass string) error, + pathName string, + pathIPs []interface{}, + pathUser string, + pathPass string, +) error { + + // validate ip + if pathIPs != nil && ip != nil { + if !ipEqualOrInRange(ip, pathIPs) { + return readpublisher.ErrAuthCritical{ + Message: fmt.Sprintf("IP '%s' not allowed", ip), + Response: &base.Response{ + StatusCode: base.StatusUnauthorized, + }, + } + } + } + + // validate user + if pathUser != "" && validateCredentials != nil { + err := validateCredentials(pm.authMethods, pathUser, pathPass) + if err != nil { + return err + } + } + + return nil +} diff --git a/internal/readpublisher/readpublisher.go b/internal/readpublisher/readpublisher.go index c099a281..81774b4e 100644 --- a/internal/readpublisher/readpublisher.go +++ b/internal/readpublisher/readpublisher.go @@ -2,6 +2,7 @@ package readpublisher import ( "fmt" + "net" "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/base" @@ -11,6 +12,16 @@ import ( "github.com/aler9/rtsp-simple-server/internal/streamproc" ) +// Path is implemented by path.Path. +type Path interface { + Name() string + Conf() *conf.PathConf + OnReadPublisherRemove(RemoveReq) + OnReadPublisherPlay(PlayReq) + OnReadPublisherRecord(RecordReq) + OnReadPublisherPause(PauseReq) +} + // ErrNoOnePublishing is a "no one is publishing" error. type ErrNoOnePublishing struct { PathName string @@ -33,7 +44,8 @@ func (ErrAuthNotCritical) Error() string { // ErrAuthCritical is a critical authentication error. type ErrAuthCritical struct { - *base.Response + Message string + Response *base.Response } // Error implements the error interface. @@ -41,14 +53,13 @@ func (ErrAuthCritical) Error() string { return "critical authentication error" } -// Path is implemented by path.Path. -type Path interface { - Name() string - Conf() *conf.PathConf - OnReadPublisherRemove(RemoveReq) - OnReadPublisherPlay(PlayReq) - OnReadPublisherRecord(RecordReq) - OnReadPublisherPause(PauseReq) +// ReadPublisher is an entity that can read/publish from/to a path. +type ReadPublisher interface { + IsReadPublisher() + IsSource() + Close() + RequestClose() + OnFrame(int, gortsplib.StreamType, []byte) } // DescribeRes is a describe response. @@ -60,10 +71,11 @@ type DescribeRes struct { // DescribeReq is a describe request. type DescribeReq struct { - ReadPublisher ReadPublisher - PathName string - Data *base.Request - Res chan DescribeRes + PathName string + URL *base.URL + IP net.IP + ValidateCredentials func(authMethods []headers.AuthMethod, pathUser string, pathPass string) error + Res chan DescribeRes } // SetupPlayRes is a setup/play response. @@ -75,10 +87,11 @@ type SetupPlayRes struct { // SetupPlayReq is a setup/play request. type SetupPlayReq struct { - ReadPublisher ReadPublisher - PathName string - Data interface{} - Res chan SetupPlayRes + Author ReadPublisher + PathName string + IP net.IP + ValidateCredentials func(authMethods []headers.AuthMethod, pathUser string, pathPass string) error + Res chan SetupPlayRes } // AnnounceRes is a announce response. @@ -89,17 +102,18 @@ type AnnounceRes struct { // AnnounceReq is a announce request. type AnnounceReq struct { - ReadPublisher ReadPublisher - PathName string - Tracks gortsplib.Tracks - Data interface{} - Res chan AnnounceRes + Author ReadPublisher + PathName string + Tracks gortsplib.Tracks + IP net.IP + ValidateCredentials func(authMethods []headers.AuthMethod, pathUser string, pathPass string) error + Res chan AnnounceRes } // RemoveReq is a remove request. type RemoveReq struct { - ReadPublisher ReadPublisher - Res chan struct{} + Author ReadPublisher + Res chan struct{} } // PlayRes is a play response. @@ -109,8 +123,8 @@ type PlayRes struct { // PlayReq is a play request. type PlayReq struct { - ReadPublisher ReadPublisher - Res chan PlayRes + Author ReadPublisher + Res chan PlayRes } // RecordRes is a record response. @@ -121,24 +135,12 @@ type RecordRes struct { // RecordReq is a record request. type RecordReq struct { - ReadPublisher ReadPublisher - Res chan RecordRes + Author ReadPublisher + Res chan RecordRes } // PauseReq is a pause request. type PauseReq struct { - ReadPublisher ReadPublisher - Res chan struct{} -} - -// ReadPublisher is an entity that can read/publish from/to a path. -type ReadPublisher interface { - IsReadPublisher() - IsSource() - Close() - CloseRequest() - Authenticate([]headers.AuthMethod, - string, []interface{}, - string, string, interface{}) error - OnFrame(int, gortsplib.StreamType, []byte) + Author ReadPublisher + Res chan struct{} } diff --git a/internal/serverrtsp/server.go b/internal/serverrtsp/server.go index b224a402..73858327 100644 --- a/internal/serverrtsp/server.go +++ b/internal/serverrtsp/server.go @@ -11,6 +11,7 @@ import ( "github.com/aler9/rtsp-simple-server/internal/clientrtsp" "github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/pathman" + "github.com/aler9/rtsp-simple-server/internal/sessionrtsp" "github.com/aler9/rtsp-simple-server/internal/stats" ) @@ -31,13 +32,13 @@ type Server struct { pathMan *pathman.PathManager parent Parent - srv *gortsplib.Server - wg sync.WaitGroup - clients map[*clientrtsp.Client]struct{} + srv *gortsplib.Server + mutex sync.RWMutex + clients map[*gortsplib.ServerConn]*clientrtsp.Client + sessions map[*gortsplib.ServerSession]*sessionrtsp.Session // in - clientClose chan *clientrtsp.Client - terminate chan struct{} + terminate chan struct{} // out done chan struct{} @@ -64,32 +65,6 @@ func New( pathMan *pathman.PathManager, parent Parent) (*Server, error) { - conf := gortsplib.ServerConf{ - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - ReadBufferCount: readBufferCount, - ReadBufferSize: readBufferSize, - } - - if useUDP { - conf.UDPRTPAddress = rtpAddress - conf.UDPRTCPAddress = rtcpAddress - } - - if isTLS { - cert, err := tls.LoadX509KeyPair(serverCert, serverKey) - if err != nil { - return nil, err - } - - conf.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} - } - - srv, err := conf.Serve(address) - if err != nil { - return nil, err - } - s := &Server{ readTimeout: readTimeout, isTLS: isTLS, @@ -98,19 +73,45 @@ func New( stats: stats, pathMan: pathMan, parent: parent, - srv: srv, - clients: make(map[*clientrtsp.Client]struct{}), - clientClose: make(chan *clientrtsp.Client), + clients: make(map[*gortsplib.ServerConn]*clientrtsp.Client), + sessions: make(map[*gortsplib.ServerSession]*sessionrtsp.Session), terminate: make(chan struct{}), done: make(chan struct{}), } - if conf.UDPRTPAddress != "" { - s.Log(logger.Info, "UDP/RTP listener opened on %s", conf.UDPRTPAddress) + s.srv = &gortsplib.Server{ + Handler: s, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + ReadBufferCount: readBufferCount, + ReadBufferSize: readBufferSize, } - if conf.UDPRTCPAddress != "" { - s.Log(logger.Info, "UDP/RTCP listener opened on %s", conf.UDPRTCPAddress) + if useUDP { + s.srv.UDPRTPAddress = rtpAddress + s.srv.UDPRTCPAddress = rtcpAddress + } + + if isTLS { + cert, err := tls.LoadX509KeyPair(serverCert, serverKey) + if err != nil { + return nil, err + } + + s.srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} + } + + err := s.srv.Start(address) + if err != nil { + return nil, err + } + + if s.srv.UDPRTPAddress != "" { + s.Log(logger.Info, "UDP/RTP listener opened on %s", s.srv.UDPRTPAddress) + } + + if s.srv.UDPRTCPAddress != "" { + s.Log(logger.Info, "UDP/RTCP listener opened on %s", s.srv.UDPRTCPAddress) } s.Log(logger.Info, "TCP listener opened on %s", address) @@ -140,97 +141,158 @@ func (s *Server) Close() { func (s *Server) run() { defer close(s.done) - s.wg.Add(1) - connNew := make(chan *gortsplib.ServerConn) - acceptErr := make(chan error) + serverDone := make(chan struct{}) + serverErr := make(chan error) go func() { - defer s.wg.Done() - acceptErr <- func() error { - for { - conn, err := s.srv.Accept() - if err != nil { - return err - } - - connNew <- conn - } - }() + defer close(serverDone) + serverErr <- s.srv.Wait() }() outer: - for { - select { - case err := <-acceptErr: - s.Log(logger.Warn, "ERR: %s", err) - break outer + select { + case err := <-serverErr: + s.Log(logger.Warn, "ERR: %s", err) + break outer - case conn := <-connNew: - c := clientrtsp.New( - s.isTLS, - s.rtspAddress, - s.readTimeout, - s.runOnConnect, - s.runOnConnectRestart, - s.protocols, - &s.wg, - s.stats, - conn, - s.pathMan, - s) - s.clients[c] = struct{}{} - - case c := <-s.clientClose: - if _, ok := s.clients[c]; !ok { - continue - } - s.doClientClose(c) - - case <-s.terminate: - break outer - } + case <-s.terminate: + break outer } go func() { - for { - select { - case _, ok := <-acceptErr: - if !ok { - return - } - - case conn, ok := <-connNew: - if !ok { - return - } - conn.Close() - - case _, ok := <-s.clientClose: - if !ok { - return - } - } + for range serverErr { } }() s.srv.Close() - for c := range s.clients { - s.doClientClose(c) - } + <-serverDone - s.wg.Wait() - - close(acceptErr) - close(connNew) - close(s.clientClose) + close(serverErr) } -func (s *Server) doClientClose(c *clientrtsp.Client) { - delete(s.clients, c) - c.Close() +// OnConnOpen implements gortsplib.ServerHandlerOnConnOpenCtx. +func (s *Server) OnConnOpen(sc *gortsplib.ServerConn) { + c := clientrtsp.New( + s.rtspAddress, + s.readTimeout, + s.runOnConnect, + s.runOnConnectRestart, + s.pathMan, + s.stats, + sc, + s) + + s.mutex.Lock() + s.clients[sc] = c + s.mutex.Unlock() } -// OnClientClose is called by clientrtsp.Client. -func (s *Server) OnClientClose(c *clientrtsp.Client) { - s.clientClose <- c +// OnConnClose implements gortsplib.ServerHandlerOnConnCloseCtx. +func (s *Server) OnConnClose(sc *gortsplib.ServerConn, err error) { + s.mutex.Lock() + c := s.clients[sc] + delete(s.clients, sc) + s.mutex.Unlock() + + c.Close(err) +} + +// OnRequest implements gortsplib.ServerHandlerOnRequestCtx. +func (s *Server) OnRequest(sc *gortsplib.ServerConn, req *base.Request) { + s.mutex.Lock() + c := s.clients[sc] + s.mutex.Unlock() + + c.OnRequest(req) +} + +// OnResponse implements gortsplib.ServerHandlerOnResponseCtx. +func (s *Server) OnResponse(sc *gortsplib.ServerConn, res *base.Response) { + s.mutex.Lock() + c := s.clients[sc] + s.mutex.Unlock() + + c.OnResponse(res) +} + +// OnSessionOpen implements gortsplib.ServerHandlerOnSessionOpenCtx. +func (s *Server) OnSessionOpen(ss *gortsplib.ServerSession) { + se := sessionrtsp.New( + s.rtspAddress, + s.protocols, + ss, + s.pathMan, + s) + + s.mutex.Lock() + s.sessions[ss] = se + s.mutex.Unlock() +} + +// OnSessionClose implements gortsplib.ServerHandlerOnSessionCloseCtx. +func (s *Server) OnSessionClose(ss *gortsplib.ServerSession, err error) { + s.mutex.Lock() + se := s.sessions[ss] + delete(s.sessions, ss) + s.mutex.Unlock() + + se.Close() +} + +// OnDescribe implements gortsplib.ServerHandlerOnDescribeCtx. +func (s *Server) OnDescribe(ctx *gortsplib.ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) { + s.mutex.RLock() + c := s.clients[ctx.Conn] + s.mutex.RUnlock() + return c.OnDescribe(ctx) +} + +// OnAnnounce implements gortsplib.ServerHandlerOnAnnounceCtx. +func (s *Server) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (*base.Response, error) { + s.mutex.RLock() + c := s.clients[ctx.Conn] + se := s.sessions[ctx.Session] + s.mutex.RUnlock() + return se.OnAnnounce(c, ctx) +} + +// OnSetup implements gortsplib.ServerHandlerOnSetupCtx. +func (s *Server) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, error) { + s.mutex.RLock() + c := s.clients[ctx.Conn] + se := s.sessions[ctx.Session] + s.mutex.RUnlock() + return se.OnSetup(c, ctx) +} + +// OnPlay implements gortsplib.ServerHandlerOnPlayCtx. +func (s *Server) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Response, error) { + s.mutex.RLock() + se := s.sessions[ctx.Session] + s.mutex.RUnlock() + return se.OnPlay(ctx) +} + +// OnRecord implements gortsplib.ServerHandlerOnRecordCtx. +func (s *Server) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*base.Response, error) { + s.mutex.RLock() + se := s.sessions[ctx.Session] + s.mutex.RUnlock() + return se.OnRecord(ctx) +} + +// OnPause implements gortsplib.ServerHandlerOnPauseCtx. +func (s *Server) OnPause(ctx *gortsplib.ServerHandlerOnPauseCtx) (*base.Response, error) { + s.mutex.RLock() + se := s.sessions[ctx.Session] + s.mutex.RUnlock() + return se.OnPause(ctx) +} + +// OnFrame implements gortsplib.ServerHandlerOnFrameCtx. +func (s *Server) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) { + s.mutex.RLock() + se := s.sessions[ctx.Session] + s.mutex.RUnlock() + se.OnIncomingFrame(ctx) } diff --git a/internal/sessionrtsp/session.go b/internal/sessionrtsp/session.go new file mode 100644 index 00000000..28733b98 --- /dev/null +++ b/internal/sessionrtsp/session.go @@ -0,0 +1,393 @@ +package sessionrtsp + +import ( + "errors" + "fmt" + "net" + "strconv" + "time" + + "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/headers" + + "github.com/aler9/rtsp-simple-server/internal/clientrtsp" + "github.com/aler9/rtsp-simple-server/internal/externalcmd" + "github.com/aler9/rtsp-simple-server/internal/logger" + "github.com/aler9/rtsp-simple-server/internal/readpublisher" + "github.com/aler9/rtsp-simple-server/internal/streamproc" +) + +const ( + pauseAfterAuthError = 2 * time.Second +) + +var errTerminated = errors.New("terminated") + +// PathMan is implemented by pathman.PathMan. +type PathMan interface { + OnReadPublisherSetupPlay(readpublisher.SetupPlayReq) + OnReadPublisherAnnounce(readpublisher.AnnounceReq) +} + +// Parent is implemented by serverrtsp.Server. +type Parent interface { + Log(logger.Level, string, ...interface{}) +} + +// Session is a RTSP session. +type Session struct { + rtspAddress string + protocols map[gortsplib.StreamProtocol]struct{} + ss *gortsplib.ServerSession + pathMan PathMan + parent Parent + + path readpublisher.Path + setuppedTracks map[int]*gortsplib.Track // read + onReadCmd *externalcmd.Cmd // read + sp *streamproc.StreamProc // publish + onPublishCmd *externalcmd.Cmd // publish +} + +// New allocates a Session. +func New( + rtspAddress string, + protocols map[gortsplib.StreamProtocol]struct{}, + ss *gortsplib.ServerSession, + pathMan PathMan, + parent Parent) *Session { + + s := &Session{ + rtspAddress: rtspAddress, + protocols: protocols, + ss: ss, + pathMan: pathMan, + parent: parent, + } + + s.log(logger.Info, "created") + + return s +} + +// Close closes a Session. +func (s *Session) Close() { + s.log(logger.Info, "destroyed") + + switch s.ss.State() { + case gortsplib.ServerSessionStatePlay: + if s.onReadCmd != nil { + s.onReadCmd.Close() + } + + case gortsplib.ServerSessionStateRecord: + if s.onPublishCmd != nil { + s.onPublishCmd.Close() + } + } + + if s.path != nil { + res := make(chan struct{}) + s.path.OnReadPublisherRemove(readpublisher.RemoveReq{s, res}) //nolint:govet + <-res + s.path = nil + } +} + +// RequestClose closes a Session. +func (s *Session) RequestClose() { + s.ss.Close() +} + +// IsReadPublisher implements readpublisher.ReadPublisher. +func (s *Session) IsReadPublisher() {} + +// IsSource implements source.Source. +func (s *Session) IsSource() {} + +func (s *Session) log(level logger.Level, format string, args ...interface{}) { + s.parent.Log(level, "[session %s] "+format, append([]interface{}{"TODO"}, args...)...) +} + +// OnAnnounce is called by serverrtsp.Server. +func (s *Session) OnAnnounce(c *clientrtsp.Client, ctx *gortsplib.ServerHandlerOnAnnounceCtx) (*base.Response, error) { + resc := make(chan readpublisher.AnnounceRes) + s.pathMan.OnReadPublisherAnnounce(readpublisher.AnnounceReq{ + Author: s, + PathName: ctx.Path, + Tracks: ctx.Tracks, + IP: ctx.Conn.NetConn().RemoteAddr().(*net.TCPAddr).IP, + ValidateCredentials: func(authMethods []headers.AuthMethod, pathUser string, pathPass string) error { + return c.ValidateCredentials(authMethods, pathUser, pathPass, ctx.Path, ctx.Req) + }, + Res: resc, + }) + res := <-resc + + if res.Err != nil { + switch terr := res.Err.(type) { + case readpublisher.ErrAuthNotCritical: + return terr.Response, nil + + case readpublisher.ErrAuthCritical: + s.log(logger.Info, "ERR: %v", terr.Message) + + // wait some seconds to stop brute force attacks + <-time.After(pauseAfterAuthError) + return terr.Response, errTerminated + + default: + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, res.Err + } + } + + s.path = res.Path + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil +} + +// OnSetup is called by serverrtsp.Server. +func (s *Session) OnSetup(c *clientrtsp.Client, ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, error) { + if ctx.Transport.Protocol == gortsplib.StreamProtocolUDP { + if _, ok := s.protocols[gortsplib.StreamProtocolUDP]; !ok { + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil + } + } else { + if _, ok := s.protocols[gortsplib.StreamProtocolTCP]; !ok { + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil + } + } + + switch s.ss.State() { + case gortsplib.ServerSessionStateInitial, gortsplib.ServerSessionStatePrePlay: // play + resc := make(chan readpublisher.SetupPlayRes) + s.pathMan.OnReadPublisherSetupPlay(readpublisher.SetupPlayReq{ + Author: s, + PathName: ctx.Path, + IP: ctx.Conn.NetConn().RemoteAddr().(*net.TCPAddr).IP, + ValidateCredentials: func(authMethods []headers.AuthMethod, pathUser string, pathPass string) error { + return c.ValidateCredentials(authMethods, pathUser, pathPass, ctx.Path, ctx.Req) + }, + Res: resc, + }) + res := <-resc + + if res.Err != nil { + switch terr := res.Err.(type) { + case readpublisher.ErrAuthNotCritical: + return terr.Response, nil + + case readpublisher.ErrAuthCritical: + s.log(logger.Info, "ERR: %v", terr.Message) + + // wait some seconds to stop brute force attacks + <-time.After(pauseAfterAuthError) + return terr.Response, errTerminated + + case readpublisher.ErrNoOnePublishing: + return &base.Response{ + StatusCode: base.StatusNotFound, + }, res.Err + + default: + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, res.Err + } + } + + s.path = res.Path + + if ctx.TrackID >= len(res.Tracks) { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("track %d does not exist", ctx.TrackID) + } + + if s.setuppedTracks == nil { + s.setuppedTracks = make(map[int]*gortsplib.Track) + } + s.setuppedTracks[ctx.TrackID] = res.Tracks[ctx.TrackID] + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil +} + +// OnPlay is called by serverrtsp.Server. +func (s *Session) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Response, error) { + h := make(base.Header) + + if s.ss.State() == gortsplib.ServerSessionStatePrePlay { + if ctx.Path != s.path.Name() { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("path has changed, was '%s', now is '%s'", s.path.Name(), ctx.Path) + } + + resc := make(chan readpublisher.PlayRes) + s.path.OnReadPublisherPlay(readpublisher.PlayReq{s, resc}) //nolint:govet + res := <-resc + + tracksLen := len(s.ss.SetuppedTracks()) + + s.log(logger.Info, "is reading from path '%s', %d %s with %s", + s.path.Name(), + tracksLen, + func() string { + if tracksLen == 1 { + return "track" + } + return "tracks" + }(), + *s.ss.StreamProtocol()) + + if s.path.Conf().RunOnRead != "" { + _, port, _ := net.SplitHostPort(s.rtspAddress) + s.onReadCmd = externalcmd.New(s.path.Conf().RunOnRead, s.path.Conf().RunOnReadRestart, externalcmd.Environment{ + Path: s.path.Name(), + Port: port, + }) + } + + // add RTP-Info + var ri headers.RTPInfo + for trackID, ti := range res.TrackInfos { + if ti.LastTimeNTP == 0 { + continue + } + + track, ok := s.setuppedTracks[trackID] + if !ok { + continue + } + + u := &base.URL{ + Scheme: ctx.Req.URL.Scheme, + User: ctx.Req.URL.User, + Host: ctx.Req.URL.Host, + Path: "/" + s.path.Name() + "/trackID=" + strconv.FormatInt(int64(trackID), 10), + } + + clockRate, _ := track.ClockRate() + ts := uint32(uint64(ti.LastTimeRTP) + + uint64(time.Since(time.Unix(ti.LastTimeNTP, 0)).Seconds()*float64(clockRate))) + lsn := ti.LastSequenceNumber + + ri = append(ri, &headers.RTPInfoEntry{ + URL: u.String(), + SequenceNumber: &lsn, + Timestamp: &ts, + }) + } + if len(ri) > 0 { + h["RTP-Info"] = ri.Write() + } + } + + return &base.Response{ + StatusCode: base.StatusOK, + Header: h, + }, nil +} + +// OnRecord is called by serverrtsp.Server. +func (s *Session) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*base.Response, error) { + if ctx.Path != s.path.Name() { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("path has changed, was '%s', now is '%s'", s.path.Name(), ctx.Path) + } + + resc := make(chan readpublisher.RecordRes) + s.path.OnReadPublisherRecord(readpublisher.RecordReq{Author: s, Res: resc}) + res := <-resc + + if res.Err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, res.Err + } + + s.sp = res.SP + + tracksLen := len(s.ss.AnnouncedTracks()) + + s.log(logger.Info, "is publishing to path '%s', %d %s with %s", + s.path.Name(), + tracksLen, + func() string { + if tracksLen == 1 { + return "track" + } + return "tracks" + }(), + *s.ss.StreamProtocol()) + + if s.path.Conf().RunOnPublish != "" { + _, port, _ := net.SplitHostPort(s.rtspAddress) + s.onPublishCmd = externalcmd.New(s.path.Conf().RunOnPublish, s.path.Conf().RunOnPublishRestart, externalcmd.Environment{ + Path: s.path.Name(), + Port: port, + }) + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil +} + +// OnPause is called by serverrtsp.Server. +func (s *Session) OnPause(ctx *gortsplib.ServerHandlerOnPauseCtx) (*base.Response, error) { + switch s.ss.State() { + case gortsplib.ServerSessionStatePlay: + if s.onReadCmd != nil { + s.onReadCmd.Close() + } + + res := make(chan struct{}) + s.path.OnReadPublisherPause(readpublisher.PauseReq{s, res}) //nolint:govet + <-res + + case gortsplib.ServerSessionStateRecord: + if s.onPublishCmd != nil { + s.onPublishCmd.Close() + } + + res := make(chan struct{}) + s.path.OnReadPublisherPause(readpublisher.PauseReq{s, res}) //nolint:govet + <-res + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil +} + +// OnFrame implements path.Reader. +func (s *Session) OnFrame(trackID int, streamType gortsplib.StreamType, payload []byte) { + if _, ok := s.ss.SetuppedTracks()[trackID]; !ok { + return + } + + s.ss.WriteFrame(trackID, streamType, payload) +} + +// OnIncomingFrame is called by serverrtsp.Server. +func (s *Session) OnIncomingFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) { + if s.ss.State() != gortsplib.ServerSessionStateRecord { + return + } + + s.sp.OnFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) +} diff --git a/internal/sourcertsp/source.go b/internal/sourcertsp/source.go index c89fe5c0..df4e9c0b 100644 --- a/internal/sourcertsp/source.go +++ b/internal/sourcertsp/source.go @@ -129,7 +129,7 @@ func (s *Source) runInner() bool { go func() { defer close(dialDone) - conf := gortsplib.ClientConf{ + client := &gortsplib.Client{ StreamProtocol: s.proto, TLSConfig: &tls.Config{ InsecureSkipVerify: true, @@ -158,7 +158,8 @@ func (s *Source) runInner() bool { s.log(logger.Debug, "s->c %v", res) }, } - conn, err = conf.DialRead(s.ur) + + conn, err = client.DialRead(s.ur) }() select { diff --git a/main_clientrtsp_test.go b/main_clientrtsp_test.go index 615a1d86..654153ab 100644 --- a/main_clientrtsp_test.go +++ b/main_clientrtsp_test.go @@ -532,7 +532,7 @@ func TestClientRTSPNonCompliantFrameSize(t *testing.T) { track, err := gortsplib.NewTrackH264(96, []byte("123456"), []byte("123456")) require.NoError(t, err) - conf := gortsplib.ClientConf{ + client := &gortsplib.Client{ StreamProtocol: func() *gortsplib.StreamProtocol { v := gortsplib.StreamProtocolTCP return &v @@ -540,12 +540,12 @@ func TestClientRTSPNonCompliantFrameSize(t *testing.T) { ReadBufferSize: 4500, } - source, err := conf.DialPublish("rtsp://"+ownDockerIP+":8554/teststream", + source, err := client.DialPublish("rtsp://"+ownDockerIP+":8554/teststream", gortsplib.Tracks{track}) require.NoError(t, err) defer source.Close() - dest, err := conf.DialRead("rtsp://" + ownDockerIP + ":8554/teststream") + dest, err := client.DialRead("rtsp://" + ownDockerIP + ":8554/teststream") require.NoError(t, err) defer dest.Close() @@ -579,7 +579,7 @@ func TestClientRTSPNonCompliantFrameSize(t *testing.T) { track, err := gortsplib.NewTrackH264(96, []byte("123456"), []byte("123456")) require.NoError(t, err) - conf := gortsplib.ClientConf{ + client := &gortsplib.Client{ StreamProtocol: func() *gortsplib.StreamProtocol { v := gortsplib.StreamProtocolTCP return &v @@ -587,7 +587,7 @@ func TestClientRTSPNonCompliantFrameSize(t *testing.T) { ReadBufferSize: 4500, } - source, err := conf.DialPublish("rtsp://"+ownDockerIP+":8554/teststream", + source, err := client.DialPublish("rtsp://"+ownDockerIP+":8554/teststream", gortsplib.Tracks{track}) require.NoError(t, err) defer source.Close() @@ -606,7 +606,7 @@ func TestClientRTSPNonCompliantFrameSize(t *testing.T) { time.Sleep(100 * time.Millisecond) - dest, err := conf.DialRead("rtsp://" + ownDockerIP + ":8555/teststream") + dest, err := client.DialRead("rtsp://" + ownDockerIP + ":8555/teststream") require.NoError(t, err) defer dest.Close() @@ -827,7 +827,7 @@ func TestClientRTSPRunOnDemand(t *testing.T) { doneFile := filepath.Join(os.TempDir(), "ondemand_done") onDemandFile, err := writeTempFile([]byte(fmt.Sprintf(`#!/bin/sh trap 'touch %s; [ -z "$(jobs -p)" ] || kill $(jobs -p)' INT -ffmpeg -hide_banner -loglevel error -re -i testimages/ffmpeg/emptyvideo.mkv -c copy -f rtsp rtsp://localhost:$RTSP_PORT/$RTSP_PATH & +(ffmpeg -hide_banner -loglevel error -re -i testimages/ffmpeg/emptyvideo.mkv -c copy -f rtsp rtsp://localhost:$RTSP_PORT/$RTSP_PATH; sleep 86400) & wait `, doneFile))) require.NoError(t, err)