diff --git a/go.mod b/go.mod index 0bcddcf9..f3f843e2 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.15 require ( github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect - github.com/aler9/gortsplib v0.0.0-20201105100708-34389c06cd57 + github.com/aler9/gortsplib v0.0.0-20201107181327-316a40cdf5af github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.4.9 github.com/notedit/rtmp v0.0.2 diff --git a/go.sum b/go.sum index 72152a9a..fb5af8a2 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= -github.com/aler9/gortsplib v0.0.0-20201105100708-34389c06cd57 h1:Vm3vZJsk99jVPNklTRE5gyxzlgFvuoDyEmB1kVTOIdk= -github.com/aler9/gortsplib v0.0.0-20201105100708-34389c06cd57/go.mod h1:dRaVvesTIz9a2RbIp7WCfO1dFR0xnkpgSlN3eIn6VfI= +github.com/aler9/gortsplib v0.0.0-20201107181327-316a40cdf5af h1:lcTfGNVO8SmxJkJEY5cu3oCgtKMq6TbH03u0OJfDfWU= +github.com/aler9/gortsplib v0.0.0-20201107181327-316a40cdf5af/go.mod h1:6yKsTNIrCapRz90WHQtyFV/rKK0TT+QapxUXNqSJi9M= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/internal/client/client.go b/internal/client/client.go index 06aa778e..d08525cb 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -26,6 +26,7 @@ import ( const ( checkStreamInterval = 5 * time.Second receiverReportInterval = 10 * time.Second + sessionId = "12345678" ) type readRequestPair struct { @@ -86,6 +87,7 @@ type Path interface { OnClientRemove(*Client) OnClientPlay(*Client) OnClientRecord(*Client) + OnClientPause(*Client) OnFrame(int, gortsplib.StreamType, []byte) } @@ -165,8 +167,6 @@ func New( parent: parent, state: stateInitial, streamTracks: make(map[int]*streamTrack), - describeData: make(chan describeData), - tcpFrame: make(chan *base.InterleavedFrame), terminate: make(chan struct{}), } @@ -199,10 +199,11 @@ func (c *Client) zone() string { return c.conn.NetConn().RemoteAddr().(*net.TCPAddr).Zone } -var errRunTerminate = errors.New("terminate") -var errRunWaitingDescribe = errors.New("wait description") -var errRunPlay = errors.New("play") -var errRunRecord = errors.New("record") +var errStateTerminate = errors.New("terminate") +var errStateWaitingDescribe = errors.New("wait description") +var errStatePlay = errors.New("play") +var errStateRecord = errors.New("record") +var errStateInitial = errors.New("initial") func (c *Client) run() { defer c.wg.Done() @@ -230,20 +231,6 @@ func (c *Client) run() { if onConnectCmd != nil { onConnectCmd.Close() } - - close(c.describeData) - close(c.tcpFrame) -} - -func (c *Client) writeResError(cseq base.HeaderValue, code base.StatusCode, err error) { - c.log("ERR: %s", err) - - c.conn.WriteResponse(&base.Response{ - StatusCode: code, - Header: base.Header{ - "CSeq": cseq, - }, - }) } type errAuthNotCritical struct { @@ -333,13 +320,37 @@ func (c *Client) Authenticate(authMethods []headers.AuthMethod, ips []interface{ return nil } +func (c *Client) checkState(allowed map[state]struct{}) error { + if _, ok := allowed[c.state]; ok { + return nil + } + + var allowedList []state + for s := range allowed { + allowedList = append(allowedList, s) + } + return fmt.Errorf("client must be in state %v, while is in state %v", + allowedList, c.state) +} + +func (c *Client) writeResError(cseq base.HeaderValue, code base.StatusCode, err error) { + c.log("ERR: %s", err) + + c.conn.WriteResponse(&base.Response{ + StatusCode: code, + Header: base.Header{ + "CSeq": cseq, + }, + }) +} + func (c *Client) handleRequest(req *base.Request) error { c.log(string(req.Method)) cseq, ok := req.Header["CSeq"] if !ok || len(cseq) != 1 { c.writeResError(nil, base.StatusBadRequest, fmt.Errorf("cseq missing")) - return errRunTerminate + return errStateTerminate } switch req.Method { @@ -374,32 +385,38 @@ func (c *Client) handleRequest(req *base.Request) error { return nil case base.DESCRIBE: - if c.state != stateInitial { - c.writeResError(cseq, base.StatusBadRequest, - fmt.Errorf("client is in state '%s' instead of '%s'", c.state, stateInitial)) - return errRunTerminate + err := c.checkState(map[state]struct{}{ + stateInitial: {}, + }) + if err != nil { + c.writeResError(cseq, base.StatusBadRequest, err) + return errStateTerminate } basePath, ok := req.URL.BasePath() if !ok { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unable to find base path (%s)", req.URL)) - return errRunTerminate + return errStateTerminate } + c.describeData = make(chan describeData) + path, err := c.parent.OnClientDescribe(c, basePath, req) if err != nil { switch terr := err.(type) { case errAuthNotCritical: + close(c.describeData) c.conn.WriteResponse(terr.Response) return nil case errAuthCritical: + close(c.describeData) c.conn.WriteResponse(terr.Response) - return errRunTerminate + return errStateTerminate default: c.writeResError(cseq, base.StatusBadRequest, err) - return errRunTerminate + return errStateTerminate } } @@ -408,41 +425,43 @@ func (c *Client) handleRequest(req *base.Request) error { c.describeCSeq = cseq c.describeUrl = req.URL.String() - return errRunWaitingDescribe + return errStateWaitingDescribe case base.ANNOUNCE: - if c.state != stateInitial { - c.writeResError(cseq, base.StatusBadRequest, - fmt.Errorf("client is in state '%s' instead of '%s'", c.state, stateInitial)) - return errRunTerminate + err := c.checkState(map[state]struct{}{ + stateInitial: {}, + }) + if err != nil { + c.writeResError(cseq, base.StatusBadRequest, err) + return errStateTerminate } ct, ok := req.Header["Content-Type"] if !ok || len(ct) != 1 { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("Content-Type header missing")) - return errRunTerminate + return errStateTerminate } if ct[0] != "application/sdp" { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unsupported Content-Type '%s'", ct)) - return errRunTerminate + return errStateTerminate } tracks, err := gortsplib.ReadTracks(req.Content) if err != nil { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("invalid SDP: %s", err)) - return errRunTerminate + return errStateTerminate } if len(tracks) == 0 { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("no tracks defined")) - return errRunTerminate + return errStateTerminate } basePath, ok := req.URL.BasePath() if !ok { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unable to find base path (%s)", req.URL)) - return errRunTerminate + return errStateTerminate } path, err := c.parent.OnClientAnnounce(c, basePath, tracks, req) @@ -454,11 +473,11 @@ func (c *Client) handleRequest(req *base.Request) error { case errAuthCritical: c.conn.WriteResponse(terr.Response) - return errRunTerminate + return errStateTerminate default: c.writeResError(cseq, base.StatusBadRequest, err) - return errRunTerminate + return errStateTerminate } } @@ -477,48 +496,48 @@ func (c *Client) handleRequest(req *base.Request) error { th, err := headers.ReadTransport(req.Header["Transport"]) if err != nil { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header: %s", err)) - return errRunTerminate + return errStateTerminate } - if th.Cast != nil && *th.Cast == gortsplib.StreamMulticast { + if th.Delivery != nil && *th.Delivery == base.StreamDeliveryMulticast { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("multicast is not supported")) - return errRunTerminate + return errStateTerminate } basePath, controlPath, ok := req.URL.BasePathControlAttr() if !ok { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unable to find control attribute (%s)", req.URL)) - return errRunTerminate + return errStateTerminate } switch c.state { // play case stateInitial, statePrePlay: - if th.Mode != nil && *th.Mode != gortsplib.TransportModePlay { + if th.Mode != nil && *th.Mode != headers.TransportModePlay { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header must contain mode=play or not contain a mode")) - return errRunTerminate + return errStateTerminate } if c.path != nil && basePath != c.path.Name() { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath)) - return errRunTerminate + return errStateTerminate } if !strings.HasPrefix(controlPath, "trackID=") { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("invalid control attribute (%s)", controlPath)) - return errRunTerminate + return errStateTerminate } tmp, err := strconv.ParseInt(controlPath[len("trackID="):], 10, 64) if err != nil || tmp < 0 { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("invalid track id (%s)", controlPath)) - return errRunTerminate + return errStateTerminate } trackId := int(tmp) if _, ok := c.streamTracks[trackId]; ok { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("track %d has already been setup", trackId)) - return errRunTerminate + return errStateTerminate } // play with UDP @@ -530,12 +549,12 @@ func (c *Client) handleRequest(req *base.Request) error { if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolUDP { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("can't receive tracks with different protocols")) - return errRunTerminate + return errStateTerminate } if th.ClientPorts == nil { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header does not have valid client ports (%v)", req.Header["Transport"])) - return errRunTerminate + return errStateTerminate } path, err := c.parent.OnClientSetupPlay(c, basePath, trackId, req) @@ -547,11 +566,11 @@ func (c *Client) handleRequest(req *base.Request) error { case errAuthCritical: c.conn.WriteResponse(terr.Response) - return errRunTerminate + return errStateTerminate default: c.writeResError(cseq, base.StatusBadRequest, err) - return errRunTerminate + return errStateTerminate } } @@ -566,8 +585,8 @@ func (c *Client) handleRequest(req *base.Request) error { th := &headers.Transport{ Protocol: gortsplib.StreamProtocolUDP, - Cast: func() *gortsplib.StreamCast { - v := gortsplib.StreamUnicast + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast return &v }(), ClientPorts: th.ClientPorts, @@ -579,7 +598,7 @@ func (c *Client) handleRequest(req *base.Request) error { Header: base.Header{ "CSeq": cseq, "Transport": th.Write(), - "Session": base.HeaderValue{"12345678"}, + "Session": base.HeaderValue{sessionId}, }, }) return nil @@ -593,7 +612,7 @@ func (c *Client) handleRequest(req *base.Request) error { if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolTCP { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("can't receive tracks with different protocols")) - return errRunTerminate + return errStateTerminate } path, err := c.parent.OnClientSetupPlay(c, basePath, trackId, req) @@ -605,11 +624,11 @@ func (c *Client) handleRequest(req *base.Request) error { case errAuthCritical: c.conn.WriteResponse(terr.Response) - return errRunTerminate + return errStateTerminate default: c.writeResError(cseq, base.StatusBadRequest, err) - return errRunTerminate + return errStateTerminate } } @@ -634,7 +653,7 @@ func (c *Client) handleRequest(req *base.Request) error { Header: base.Header{ "CSeq": cseq, "Transport": th.Write(), - "Session": base.HeaderValue{"12345678"}, + "Session": base.HeaderValue{sessionId}, }, }) return nil @@ -642,15 +661,15 @@ func (c *Client) handleRequest(req *base.Request) error { // record case statePreRecord: - if th.Mode == nil || *th.Mode != gortsplib.TransportModeRecord { + if th.Mode == nil || *th.Mode != headers.TransportModeRecord { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header does not contain mode=record")) - return errRunTerminate + return errStateTerminate } // after ANNOUNCE, c.path is already set if basePath != c.path.Name() { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath)) - return errRunTerminate + return errStateTerminate } // record with UDP @@ -662,17 +681,17 @@ func (c *Client) handleRequest(req *base.Request) error { if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolUDP { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("can't publish tracks with different protocols")) - return errRunTerminate + return errStateTerminate } if th.ClientPorts == nil { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header does not have valid client ports (%s)", req.Header["Transport"])) - return errRunTerminate + return errStateTerminate } if len(c.streamTracks) >= c.path.SourceTrackCount() { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("all the tracks have already been setup")) - return errRunTerminate + return errStateTerminate } c.streamProtocol = gortsplib.StreamProtocolUDP @@ -683,8 +702,8 @@ func (c *Client) handleRequest(req *base.Request) error { th := &headers.Transport{ Protocol: gortsplib.StreamProtocolUDP, - Cast: func() *gortsplib.StreamCast { - v := gortsplib.StreamUnicast + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast return &v }(), ClientPorts: th.ClientPorts, @@ -696,7 +715,7 @@ func (c *Client) handleRequest(req *base.Request) error { Header: base.Header{ "CSeq": cseq, "Transport": th.Write(), - "Session": base.HeaderValue{"12345678"}, + "Session": base.HeaderValue{sessionId}, }, }) return nil @@ -710,24 +729,24 @@ func (c *Client) handleRequest(req *base.Request) error { if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolTCP { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("can't publish tracks with different protocols")) - return errRunTerminate + return errStateTerminate } interleavedIds := [2]int{len(c.streamTracks) * 2, 1 + len(c.streamTracks)*2} if th.InterleavedIds == nil { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header does not contain the interleaved field")) - return errRunTerminate + return errStateTerminate } if (*th.InterleavedIds)[0] != interleavedIds[0] || (*th.InterleavedIds)[1] != interleavedIds[1] { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("wrong interleaved ids, expected %v, got %v", interleavedIds, *th.InterleavedIds)) - return errRunTerminate + return errStateTerminate } if len(c.streamTracks) >= c.path.SourceTrackCount() { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("all the tracks have already been setup")) - return errRunTerminate + return errStateTerminate } c.streamProtocol = gortsplib.StreamProtocolTCP @@ -746,7 +765,7 @@ func (c *Client) handleRequest(req *base.Request) error { Header: base.Header{ "CSeq": cseq, "Transport": ht.Write(), - "Session": base.HeaderValue{"12345678"}, + "Session": base.HeaderValue{sessionId}, }, }) return nil @@ -754,20 +773,22 @@ func (c *Client) handleRequest(req *base.Request) error { default: c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("client is in state '%s'", c.state)) - return errRunTerminate + return errStateTerminate } case base.PLAY: - if c.state != statePrePlay { - c.writeResError(cseq, base.StatusBadRequest, - fmt.Errorf("client is in state '%s' instead of '%s'", c.state, statePrePlay)) - return errRunTerminate + err := c.checkState(map[state]struct{}{ + statePrePlay: {}, + }) + if err != nil { + c.writeResError(cseq, base.StatusBadRequest, err) + return errStateTerminate } basePath, ok := req.URL.BasePath() if !ok { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unable to find base path (%s)", req.URL)) - return errRunTerminate + return errStateTerminate } // path can end with a slash, remove it @@ -775,12 +796,12 @@ func (c *Client) handleRequest(req *base.Request) error { if basePath != c.path.Name() { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath)) - return errRunTerminate + return errStateTerminate } if len(c.streamTracks) == 0 { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("no tracks have been setup")) - return errRunTerminate + return errStateTerminate } // write response before setting state @@ -790,23 +811,24 @@ func (c *Client) handleRequest(req *base.Request) error { StatusCode: base.StatusOK, Header: base.Header{ "CSeq": cseq, - "Session": base.HeaderValue{"12345678"}, + "Session": base.HeaderValue{sessionId}, }, }) - - return errRunPlay + return errStatePlay case base.RECORD: - if c.state != statePreRecord { - c.writeResError(cseq, base.StatusBadRequest, - fmt.Errorf("client is in state '%s' instead of '%s'", c.state, statePreRecord)) - return errRunTerminate + err := c.checkState(map[state]struct{}{ + statePreRecord: {}, + }) + if err != nil { + c.writeResError(cseq, base.StatusBadRequest, err) + return errStateTerminate } basePath, ok := req.URL.BasePath() if !ok { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unable to find base path (%s)", req.URL)) - return errRunTerminate + return errStateTerminate } // path can end with a slash, remove it @@ -814,31 +836,49 @@ func (c *Client) handleRequest(req *base.Request) error { if basePath != c.path.Name() { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath)) - return errRunTerminate + return errStateTerminate } if len(c.streamTracks) != c.path.SourceTrackCount() { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("not all tracks have been setup")) - return errRunTerminate + return errStateTerminate } c.conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "CSeq": cseq, - "Session": base.HeaderValue{"12345678"}, + "Session": base.HeaderValue{sessionId}, }, }) + return errStateRecord - return errRunRecord + case base.PAUSE: + err := c.checkState(map[state]struct{}{ + statePlay: {}, + stateRecord: {}, + }) + if err != nil { + c.writeResError(cseq, base.StatusBadRequest, err) + return errStateTerminate + } + + c.conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "CSeq": cseq, + "Session": base.HeaderValue{sessionId}, + }, + }) + return errStateInitial case base.TEARDOWN: // close connection silently - return errRunTerminate + return errStateTerminate default: c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unhandled method '%s'", req.Method)) - return errRunTerminate + return errStateTerminate } } @@ -863,18 +903,18 @@ func (c *Client) runInitial() bool { select { case err := <-readDone: switch err { - case errRunWaitingDescribe: + case errStateWaitingDescribe: return c.runWaitingDescribe() - case errRunPlay: + case errStatePlay: return c.runPlay() - case errRunRecord: + case errStateRecord: return c.runRecord() default: c.conn.Close() - if err != io.EOF && err != errRunTerminate { + if err != io.EOF && err != errStateTerminate { c.log("ERR: %s", err) } @@ -896,6 +936,8 @@ func (c *Client) runWaitingDescribe() bool { c.path.OnClientRemove(c) c.path = nil + close(c.describeData) + c.state = stateInitial if res.err != nil { @@ -926,18 +968,29 @@ func (c *Client) runWaitingDescribe() bool { return true case <-c.terminate: + ch := c.describeData go func() { - for range c.describeData { + for range ch { } }() + c.path.OnClientRemove(c) + c.path = nil + + close(c.describeData) + c.conn.Close() return false } } func (c *Client) runPlay() bool { + if c.streamProtocol == gortsplib.StreamProtocolTCP { + c.tcpFrame = make(chan *base.InterleavedFrame) + } + // start sending frames only after replying to the PLAY request + c.state = statePlay c.path.OnClientPlay(c) c.log("is reading from path '%s', %d %s with %s", c.path.Name(), len(c.streamTracks), func() string { @@ -955,20 +1008,21 @@ func (c *Client) runPlay() bool { }) } + var ret bool if c.streamProtocol == gortsplib.StreamProtocolUDP { - c.runPlayUDP() + ret = c.runPlayUDP() } else { - c.runPlayTCP() + ret = c.runPlayTCP() } if onReadCmd != nil { onReadCmd.Close() } - return false + return ret } -func (c *Client) runPlayUDP() { +func (c *Client) runPlayUDP() bool { readDone := make(chan error) go func() { for { @@ -988,23 +1042,36 @@ func (c *Client) runPlayUDP() { select { case err := <-readDone: - c.conn.Close() - if err != io.EOF && err != errRunTerminate { - c.log("ERR: %s", err) + if err == errStateInitial { + c.state = statePrePlay + c.path.OnClientPause(c) + return true + + } else { + c.path.OnClientRemove(c) + c.path = nil + + c.conn.Close() + if err != io.EOF && err != errStateTerminate { + c.log("ERR: %s", err) + } + + c.parent.OnClientClose(c) + <-c.terminate + return false } - c.parent.OnClientClose(c) - <-c.terminate - return - case <-c.terminate: + c.path.OnClientRemove(c) + c.path = nil + c.conn.Close() <-readDone - return + return false } } -func (c *Client) runPlayTCP() { +func (c *Client) runPlayTCP() bool { readRequest := make(chan readRequestPair) defer close(readRequest) @@ -1038,22 +1105,42 @@ func (c *Client) runPlayTCP() { // responses must be written in the same routine of frames case req := <-readRequest: req.res <- c.handleRequest(req.req) - close(req.res) case err := <-readDone: - c.conn.Close() - if err != io.EOF && err != errRunTerminate { - c.log("ERR: %s", err) - } + if err == errStateInitial { + ch := c.tcpFrame + go func() { + for range ch { + } + }() - go func() { - for range c.tcpFrame { + c.state = statePrePlay + c.path.OnClientPause(c) + + close(c.tcpFrame) + return true + + } else { + ch := c.tcpFrame + go func() { + for range ch { + } + }() + + c.path.OnClientRemove(c) + c.path = nil + + close(c.tcpFrame) + + c.conn.Close() + if err != io.EOF && err != errStateTerminate { + c.log("ERR: %s", err) } - }() - c.parent.OnClientClose(c) - <-c.terminate - return + c.parent.OnClientClose(c) + <-c.terminate + return false + } case frame := <-c.tcpFrame: c.conn.WriteFrameTCP(frame.TrackId, frame.StreamType, frame.Content) @@ -1065,19 +1152,26 @@ func (c *Client) runPlayTCP() { } }() + ch := c.tcpFrame go func() { - for range c.tcpFrame { + for range ch { } }() + c.path.OnClientRemove(c) + c.path = nil + + close(c.tcpFrame) + c.conn.Close() <-readDone - return + return false } } } func (c *Client) runRecord() bool { + c.state = stateRecord c.path.OnClientRecord(c) c.log("is publishing to path '%s', %d %s with %s", c.path.Name(), len(c.streamTracks), func() string { @@ -1132,27 +1226,21 @@ func (c *Client) runRecord() bool { }) } + var ret bool if c.streamProtocol == gortsplib.StreamProtocolUDP { - c.runRecordUDP() + ret = c.runRecordUDP() } else { - c.runRecordTCP() + ret = c.runRecordTCP() } if onPublishCmd != nil { onPublishCmd.Close() } - if c.streamProtocol == gortsplib.StreamProtocolUDP { - for _, track := range c.streamTracks { - c.serverUdpRtp.RemovePublisher(c.ip(), track.rtpPort, c) - c.serverUdpRtcp.RemovePublisher(c.ip(), track.rtcpPort, c) - } - } - - return false + return ret } -func (c *Client) runRecordUDP() { +func (c *Client) runRecordUDP() bool { readDone := make(chan error) go func() { for { @@ -1179,14 +1267,34 @@ func (c *Client) runRecordUDP() { for { select { case err := <-readDone: - c.conn.Close() - if err != io.EOF && err != errRunTerminate { - c.log("ERR: %s", err) - } + if err == errStateInitial { + for _, track := range c.streamTracks { + c.serverUdpRtp.RemovePublisher(c.ip(), track.rtpPort, c) + c.serverUdpRtcp.RemovePublisher(c.ip(), track.rtcpPort, c) + } - c.parent.OnClientClose(c) - <-c.terminate - return + c.state = statePreRecord + c.path.OnClientPause(c) + return true + + } else { + for _, track := range c.streamTracks { + c.serverUdpRtp.RemovePublisher(c.ip(), track.rtpPort, c) + c.serverUdpRtcp.RemovePublisher(c.ip(), track.rtcpPort, c) + } + + c.path.OnClientRemove(c) + c.path = nil + + c.conn.Close() + if err != io.EOF && err != errStateTerminate { + c.log("ERR: %s", err) + } + + c.parent.OnClientClose(c) + <-c.terminate + return false + } case <-checkStreamTicker.C: now := time.Now() @@ -1195,13 +1303,21 @@ func (c *Client) runRecordUDP() { last := time.Unix(atomic.LoadInt64(lastUnix), 0) if now.Sub(last) >= c.readTimeout { + for _, track := range c.streamTracks { + c.serverUdpRtp.RemovePublisher(c.ip(), track.rtpPort, c) + c.serverUdpRtcp.RemovePublisher(c.ip(), track.rtcpPort, c) + } + + c.path.OnClientRemove(c) + c.path = nil + c.log("ERR: no packets received recently (maybe there's a firewall/NAT in between)") c.conn.Close() <-readDone c.parent.OnClientClose(c) <-c.terminate - return + return false } } @@ -1216,14 +1332,22 @@ func (c *Client) runRecordUDP() { } case <-c.terminate: + for _, track := range c.streamTracks { + c.serverUdpRtp.RemovePublisher(c.ip(), track.rtpPort, c) + c.serverUdpRtcp.RemovePublisher(c.ip(), track.rtcpPort, c) + } + + c.path.OnClientRemove(c) + c.path = nil + c.conn.Close() <-readDone - return + return false } } } -func (c *Client) runRecordTCP() { +func (c *Client) runRecordTCP() bool { readRequest := make(chan readRequestPair) defer close(readRequest) @@ -1266,14 +1390,24 @@ func (c *Client) runRecordTCP() { req.res <- c.handleRequest(req.req) case err := <-readDone: - c.conn.Close() - if err != io.EOF && err != errRunTerminate { - c.log("ERR: %s", err) - } + if err == errStateInitial { + c.state = statePreRecord + c.path.OnClientPause(c) + return true - c.parent.OnClientClose(c) - <-c.terminate - return + } else { + c.path.OnClientRemove(c) + c.path = nil + + c.conn.Close() + if err != io.EOF && err != errStateTerminate { + c.log("ERR: %s", err) + } + + c.parent.OnClientClose(c) + <-c.terminate + return false + } case <-receiverReportTicker.C: for trackId := range c.streamTracks { @@ -1288,9 +1422,12 @@ func (c *Client) runRecordTCP() { } }() + c.path.OnClientRemove(c) + c.path = nil + c.conn.Close() <-readDone - return + return false } } } diff --git a/internal/clientman/clientman.go b/internal/clientman/clientman.go index 163227c8..41279e63 100644 --- a/internal/clientman/clientman.go +++ b/internal/clientman/clientman.go @@ -6,7 +6,6 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/base" - "github.com/aler9/gortsplib/headers" "github.com/aler9/rtsp-simple-server/internal/client" "github.com/aler9/rtsp-simple-server/internal/pathman" @@ -27,7 +26,7 @@ type ClientManager struct { writeTimeout time.Duration runOnConnect string runOnConnectRestart bool - protocols map[headers.StreamProtocol]struct{} + protocols map[base.StreamProtocol]struct{} stats *stats.Stats serverUdpRtp *serverudp.Server serverUdpRtcp *serverudp.Server @@ -53,7 +52,7 @@ func New( writeTimeout time.Duration, runOnConnect string, runOnConnectRestart bool, - protocols map[headers.StreamProtocol]struct{}, + protocols map[base.StreamProtocol]struct{}, stats *stats.Stats, serverUdpRtp *serverudp.Server, serverUdpRtcp *serverudp.Server, diff --git a/internal/conf/conf.go b/internal/conf/conf.go index 38ef44e2..5a403196 100644 --- a/internal/conf/conf.go +++ b/internal/conf/conf.go @@ -13,7 +13,7 @@ import ( "github.com/aler9/rtsp-simple-server/internal/loghandler" ) -// Conf is the program configuration. +// Conf is the main program configuration. type Conf struct { Protocols []string `yaml:"protocols"` ProtocolsParsed map[gortsplib.StreamProtocol]struct{} `yaml:"-" json:"-"` diff --git a/internal/path/path.go b/internal/path/path.go index fcdee065..66630484 100644 --- a/internal/path/path.go +++ b/internal/path/path.go @@ -113,6 +113,11 @@ type clientRecordReq struct { client *client.Client } +type clientPauseReq struct { + res chan struct{} + client *client.Client +} + type clientState int const ( @@ -170,6 +175,7 @@ type Path struct { clientSetupPlay chan ClientSetupPlayReq // from program clientPlay chan clientPlayReq // from client clientRecord chan clientRecordReq // from client + clientPause chan clientPauseReq // from client clientRemove chan clientRemoveReq // from client terminate chan struct{} } @@ -209,6 +215,7 @@ func New( clientSetupPlay: make(chan ClientSetupPlayReq), clientPlay: make(chan clientPlayReq), clientRecord: make(chan clientRecordReq), + clientPause: make(chan clientPauseReq), clientRemove: make(chan clientRemoveReq), terminate: make(chan struct{}), } @@ -326,6 +333,10 @@ outer: pa.onClientRecord(req.client) close(req.res) + case req := <-pa.clientPause: + pa.onClientPause(req.client) + close(req.res) + case req := <-pa.clientRemove: if _, ok := pa.clients[req.client]; !ok { close(req.res) @@ -389,6 +400,7 @@ outer: close(pa.clientSetupPlay) close(pa.clientPlay) close(pa.clientRecord) + close(pa.clientPause) close(pa.clientRemove) } @@ -436,6 +448,12 @@ func (pa *Path) exhaustChannels() { } close(req.res) + case req, ok := <-pa.clientPause: + if !ok { + return + } + close(req.res) + case req, ok := <-pa.clientRemove: if !ok { return @@ -677,6 +695,7 @@ func (pa *Path) onClientPlay(c *client.Client) { atomic.AddInt64(pa.stats.CountReaders, 1) pa.clients[c] = clientStatePlay + pa.readers.add(c) } @@ -713,6 +732,26 @@ func (pa *Path) onClientRecord(c *client.Client) { pa.onSourceSetReady() } +func (pa *Path) onClientPause(c *client.Client) { + state, ok := pa.clients[c] + if !ok { + return + } + + if state == clientStatePlay { + atomic.AddInt64(pa.stats.CountReaders, -1) + pa.clients[c] = clientStatePrePlay + + pa.readers.remove(c) + + } else if state == clientStateRecord { + atomic.AddInt64(pa.stats.CountPublishers, -1) + pa.clients[c] = clientStatePreRecord + + pa.onSourceSetNotReady() + } +} + func (pa *Path) scheduleSourceClose() { if !pa.hasExternalSource() || !pa.conf.SourceOnDemand || pa.source == nil { return @@ -826,6 +865,13 @@ func (pa *Path) OnClientRecord(c *client.Client) { <-res } +// OnClientPause is called by client.Client. +func (pa *Path) OnClientPause(c *client.Client) { + res := make(chan struct{}) + pa.clientPause <- clientPauseReq{res, c} + <-res +} + // OnFrame is called by a source or by a client.Client. func (pa *Path) OnFrame(trackId int, streamType gortsplib.StreamType, buf []byte) { pa.readers.forwardFrame(trackId, streamType, buf) diff --git a/main_test.go b/main_test.go index 49fc2a70..fd34186b 100644 --- a/main_test.go +++ b/main_test.go @@ -288,6 +288,77 @@ func TestPublish(t *testing.T) { } } +func TestPublishPause(t *testing.T) { + for _, conf := range []struct { + proto string + }{ + {"udp"}, + {"tcp"}, + } { + t.Run(conf.proto, func(t *testing.T) { + p, err := testProgram("") + require.NoError(t, err) + defer p.close() + + time.Sleep(1 * time.Second) + + track, err := gortsplib.NewTrackH264(0, []byte("\x00\x00\x00\x00\x00"), + []byte("\x00\x00\x00\x00\x00")) + require.NoError(t, err) + + switch conf.proto { + case "udp": + conn, err := gortsplib.DialPublish("rtsp://"+ownDockerIp+":8554/teststream", + gortsplib.StreamProtocolUDP, gortsplib.Tracks{track}) + require.NoError(t, err) + defer conn.Close() + + for i := 0; i < 2; i++ { + err := conn.WriteFrameUDP(track.Id, gortsplib.StreamTypeRtp, []byte("\x00\x00\x00\x00")) + require.NoError(t, err) + } + + _, err = conn.Pause() + require.NoError(t, err) + + time.Sleep(1 * time.Second) + + _, err = conn.Record() + require.NoError(t, err) + + for i := 0; i < 2; i++ { + err := conn.WriteFrameUDP(track.Id, gortsplib.StreamTypeRtp, []byte("\x00\x00\x00\x00")) + require.NoError(t, err) + } + + case "tcp": + conn, err := gortsplib.DialPublish("rtsp://"+ownDockerIp+":8554/teststream", + gortsplib.StreamProtocolTCP, gortsplib.Tracks{track}) + require.NoError(t, err) + defer conn.Close() + + for i := 0; i < 2; i++ { + err := conn.WriteFrameTCP(track.Id, gortsplib.StreamTypeRtp, []byte("\x00\x00\x00\x00")) + require.NoError(t, err) + } + + _, err = conn.Pause() + require.NoError(t, err) + + time.Sleep(1 * time.Second) + + _, err = conn.Record() + require.NoError(t, err) + + for i := 0; i < 2; i++ { + err := conn.WriteFrameTCP(track.Id, gortsplib.StreamTypeRtp, []byte("\x00\x00\x00\x00")) + require.NoError(t, err) + } + } + }) + } +} + func TestRead(t *testing.T) { for _, conf := range []struct { readSoft string @@ -352,6 +423,87 @@ func TestRead(t *testing.T) { } } +func TestReadPause(t *testing.T) { + for _, conf := range []struct { + proto string + }{ + {"udp"}, + {"tcp"}, + } { + t.Run(conf.proto, func(t *testing.T) { + p, err := testProgram("") + require.NoError(t, err) + defer p.close() + + time.Sleep(1 * time.Second) + + cnt1, err := newContainer("ffmpeg", "source", []string{ + "-re", + "-stream_loop", "-1", + "-i", "/emptyvideo.ts", + "-c", "copy", + "-f", "rtsp", + "-rtsp_transport", "udp", + "rtsp://" + ownDockerIp + ":8554/teststream", + }) + require.NoError(t, err) + defer cnt1.close() + + time.Sleep(1 * time.Second) + + switch conf.proto { + case "udp": + conn, err := gortsplib.DialRead("rtsp://"+ownDockerIp+":8554/teststream", + gortsplib.StreamProtocolUDP) + require.NoError(t, err) + defer conn.Close() + + for i := 0; i < 2; i++ { + _, err := conn.ReadFrameUDP(0, gortsplib.StreamTypeRtp) + require.NoError(t, err) + } + + _, err = conn.Pause() + require.NoError(t, err) + + time.Sleep(1 * time.Second) + + _, err = conn.Play() + require.NoError(t, err) + + for i := 0; i < 2; i++ { + _, err := conn.ReadFrameUDP(0, gortsplib.StreamTypeRtp) + require.NoError(t, err) + } + + case "tcp": + conn, err := gortsplib.DialRead("rtsp://"+ownDockerIp+":8554/teststream", + gortsplib.StreamProtocolTCP) + require.NoError(t, err) + defer conn.Close() + + for i := 0; i < 2; i++ { + _, _, _, err := conn.ReadFrameTCP() + require.NoError(t, err) + } + + _, err = conn.Pause() + require.NoError(t, err) + + time.Sleep(1 * time.Second) + + _, err = conn.Play() + require.NoError(t, err) + + for i := 0; i < 2; i++ { + _, _, _, err := conn.ReadFrameTCP() + require.NoError(t, err) + } + } + }) + } +} + func TestTCPOnly(t *testing.T) { p, err := testProgram("protocols: [tcp]\n") require.NoError(t, err)