diff --git a/go.mod b/go.mod index 30e408b7..ca475276 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.15 require ( github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect - github.com/aler9/gortsplib v0.0.0-20201208105438-07aefbcd5d11 + github.com/aler9/gortsplib v0.0.0-20201212222949-4c942d33fed8 github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.4.9 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 diff --git a/go.sum b/go.sum index 6fc17079..44bba788 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= -github.com/aler9/gortsplib v0.0.0-20201208105438-07aefbcd5d11 h1:as97tV7XyNJurmD1e3iT0AcgxeIwRa+nwMm10gi0vO0= -github.com/aler9/gortsplib v0.0.0-20201208105438-07aefbcd5d11/go.mod h1:8P09VjpiPJFyfkVosyF5/TY82jNwkMN165NS/7sc32I= +github.com/aler9/gortsplib v0.0.0-20201212222949-4c942d33fed8 h1:nDEZtoFBPDPgu9wxujoTEmMXFNTg+d0ATYKSgGHtsgE= +github.com/aler9/gortsplib v0.0.0-20201212222949-4c942d33fed8/go.mod h1:8P09VjpiPJFyfkVosyF5/TY82jNwkMN165NS/7sc32I= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/internal/client/client.go b/internal/client/client.go index 3c1d783b..71b00476 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -30,11 +30,6 @@ const ( sessionID = "12345678" ) -type readReq struct { - req *base.Request - res chan bool -} - type streamTrack struct { rtpPort int rtcpPort int @@ -50,7 +45,6 @@ type state int const ( stateInitial state = iota - stateWaitingDescribe statePrePlay statePlay statePreRecord @@ -61,8 +55,6 @@ func (s state) String() string { switch s { case stateInitial: return "initial" - case stateWaitingDescribe: - return "waitingDescribe" case statePrePlay: return "prePlay" case statePlay: @@ -120,14 +112,17 @@ type Client struct { streamTracks map[int]*streamTrack rtcpReceivers map[int]*rtcpreceiver.RtcpReceiver udpLastFrameTimes []*int64 - describeCSeq base.HeaderValue - describeURL string - tcpWriteMutex sync.Mutex - tcpWriteOk bool + writeFrameEnable bool + writeFrameMutex sync.Mutex + onReadCmd *externalcmd.Cmd + onPublishCmd *externalcmd.Cmd // in describeData chan describeData // from path terminate chan struct{} + + backgroundRecordTerminate chan struct{} + backgroundRecordDone chan struct{} } // New allocates a Client. @@ -191,11 +186,7 @@ func (c *Client) zone() string { return c.conn.NetConn().RemoteAddr().(*net.TCPAddr).Zone } -var errStateTerminate = errors.New("terminate") -var errStateWaitingDescribe = errors.New("wait description") -var errStatePlay = errors.New("play") -var errStateRecord = errors.New("record") -var errStateInitial = errors.New("initial") +var errTerminated = errors.New("terminated") func (c *Client) run() { defer c.wg.Done() @@ -209,15 +200,47 @@ func (c *Client) run() { defer onConnectCmd.Close() } - for { - if !c.runInitial() { - break - } - } + readDone := c.conn.Read(c.onRequest, c.onFrame) - if c.path != nil { - c.path.OnClientRemove(c) - c.path = nil + select { + case err := <-readDone: + c.conn.Close() + if err != io.EOF && err != errTerminated { + c.log(logger.Info, "ERR: %s", err) + } + + switch c.state { + case statePlay: + c.stopPlay() + + case stateRecord: + c.stopRecord() + } + + if c.path != nil { + c.path.OnClientRemove(c) + c.path = nil + } + + c.parent.OnClientClose(c) + <-c.terminate + + case <-c.terminate: + c.conn.Close() + <-readDone + + switch c.state { + case statePlay: + c.stopPlay() + + case stateRecord: + c.stopRecord() + } + + if c.path != nil { + c.path.OnClientRemove(c) + c.path = nil + } } } @@ -248,9 +271,6 @@ func (c *Client) Authenticate(authMethods []headers.AuthMethod, ips []interface{ return errAuthCritical{&base.Response{ StatusCode: base.StatusUnauthorized, - Header: base.Header{ - "CSeq": req.Header["CSeq"], - }, }} } } @@ -280,7 +300,6 @@ func (c *Client) Authenticate(authMethods []headers.AuthMethod, ips []interface{ return errAuthCritical{&base.Response{ StatusCode: base.StatusUnauthorized, Header: base.Header{ - "CSeq": req.Header["CSeq"], "WWW-Authenticate": c.authHelper.GenerateHeader(), }, }} @@ -293,7 +312,6 @@ func (c *Client) Authenticate(authMethods []headers.AuthMethod, ips []interface{ return errAuthNotCritical{&base.Response{ StatusCode: base.StatusUnauthorized, Header: base.Header{ - "CSeq": req.Header["CSeq"], "WWW-Authenticate": c.authHelper.GenerateHeader(), }, }} @@ -315,41 +333,24 @@ func (c *Client) checkState(allowed map[state]struct{}) error { for s := range allowed { allowedList = append(allowedList, s) } + return fmt.Errorf("client must be in state %v, while is in state %v", allowedList, c.state) } -func (c *Client) writeRes(res *base.Response) { - c.log(logger.Debug, "s->c %v", res) - c.conn.WriteResponse(res) -} - -func (c *Client) writeResError(cseq base.HeaderValue, code base.StatusCode, err error) { - c.log(logger.Info, "ERR: %s", err) - - c.writeRes(&base.Response{ - StatusCode: code, - Header: base.Header{ - "CSeq": cseq, - }, - }) -} - -func (c *Client) handleRequest(req *base.Request) error { +func (c *Client) onRequest(req *base.Request) (*base.Response, error) { c.log(logger.Debug, "[c->s] %v", req) + res, err := c.onRequestInner(req) + c.log(logger.Debug, "[s->c] %v", res) + return res, err +} - cseq, ok := req.Header["CSeq"] - if !ok || len(cseq) != 1 { - c.writeResError(nil, base.StatusBadRequest, fmt.Errorf("cseq missing")) - return errStateTerminate - } - +func (c *Client) onRequestInner(req *base.Request) (*base.Response, error) { switch req.Method { case base.Options: - c.writeRes(&base.Response{ + return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "CSeq": cseq, "Public": base.HeaderValue{strings.Join([]string{ string(base.GetParameter), string(base.Describe), @@ -361,34 +362,33 @@ func (c *Client) handleRequest(req *base.Request) error { string(base.Teardown), }, ", ")}, }, - }) - return nil + }, nil // GET_PARAMETER is used like a ping case base.GetParameter: - c.writeRes(&base.Response{ + return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "CSeq": cseq, "Content-Type": base.HeaderValue{"text/parameters"}, }, Content: []byte("\n"), - }) - return nil + }, nil case base.Describe: err := c.checkState(map[state]struct{}{ stateInitial: {}, }) if err != nil { - c.writeResError(cseq, base.StatusBadRequest, err) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err } basePath, ok := req.URL.BasePath() if !ok { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unable to find base path (%s)", req.URL)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unable to find base path (%s)", req.URL) } c.describeData = make(chan describeData) @@ -397,87 +397,123 @@ func (c *Client) handleRequest(req *base.Request) error { if err != nil { switch terr := err.(type) { case errAuthNotCritical: - close(c.describeData) - c.writeRes(terr.Response) - return nil + return terr.Response, nil case errAuthCritical: - close(c.describeData) - c.writeRes(terr.Response) - return errStateTerminate + return terr.Response, errTerminated default: - c.writeResError(cseq, base.StatusBadRequest, err) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err } } c.path = path - c.state = stateWaitingDescribe - c.describeCSeq = cseq - c.describeURL = req.URL.String() - return errStateWaitingDescribe + select { + case res := <-c.describeData: + c.path.OnClientRemove(c) + c.path = nil + + if res.err != nil { + c.log(logger.Info, "no one is publishing to path '%s'", basePath) + return &base.Response{ + StatusCode: base.StatusNotFound, + }, nil + } + + if res.redirect != "" { + return &base.Response{ + StatusCode: base.StatusMovedPermanently, + Header: base.Header{ + "Location": base.HeaderValue{res.redirect}, + }, + }, nil + } + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Base": base.HeaderValue{req.URL.String() + "/"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Content: res.sdp, + }, nil + + case <-c.terminate: + ch := c.describeData + go func() { + for range ch { + } + }() + + c.path.OnClientRemove(c) + c.path = nil + + close(c.describeData) + + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, errTerminated + } case base.Announce: err := c.checkState(map[state]struct{}{ stateInitial: {}, }) if err != nil { - c.writeResError(cseq, base.StatusBadRequest, err) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err } basePath, ok := req.URL.BasePath() if !ok { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unable to find base path (%s)", req.URL)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unable to find base path (%s)", req.URL) } ct, ok := req.Header["Content-Type"] if !ok || len(ct) != 1 { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("Content-Type header missing")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("Content-Type header missing") } if ct[0] != "application/sdp" { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unsupported Content-Type '%s'", ct)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unsupported Content-Type '%s'", ct) } tracks, err := gortsplib.ReadTracks(req.Content) if err != nil { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("invalid SDP: %s", err)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid SDP: %s", err) } if len(tracks) == 0 { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("no tracks defined")) - return errStateTerminate - } - - for trackID, t := range tracks { - _, err := t.ClockRate() - if err != nil { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unable to get clock rate of track %d", trackID)) - return errStateTerminate - } + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("no tracks defined") } path, err := c.parent.OnClientAnnounce(c, basePath, tracks, req) if err != nil { switch terr := err.(type) { case errAuthNotCritical: - c.writeRes(terr.Response) - return nil + return terr.Response, nil case errAuthCritical: - c.writeRes(terr.Response) - return errStateTerminate + return terr.Response, errTerminated default: - c.writeResError(cseq, base.StatusBadRequest, err) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err } } @@ -489,93 +525,99 @@ func (c *Client) handleRequest(req *base.Request) error { c.path = path c.state = statePreRecord - c.writeRes(&base.Response{ + return &base.Response{ StatusCode: base.StatusOK, - Header: base.Header{ - "CSeq": cseq, - }, - }) - return nil + }, nil case base.Setup: th, err := headers.ReadTransport(req.Header["Transport"]) if err != nil { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header: %s", err)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("transport header: %s", err) } if th.Delivery != nil && *th.Delivery == base.StreamDeliveryMulticast { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("multicast is not supported")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("multicast is not supported") } basePath, controlPath, ok := req.URL.BasePathControlAttr() if !ok { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unable to find control attribute (%s)", req.URL)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unable to find control attribute (%s)", req.URL) } switch c.state { // play case stateInitial, statePrePlay: if th.Mode != nil && *th.Mode != headers.TransportModePlay { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header must contain mode=play or not contain a mode")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("transport header must contain mode=play or not contain a mode") } if c.path != nil && basePath != c.path.Name() { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath) } if !strings.HasPrefix(controlPath, "trackID=") { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("invalid control attribute (%s)", controlPath)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid control attribute (%s)", controlPath) } tmp, err := strconv.ParseInt(controlPath[len("trackID="):], 10, 64) if err != nil || tmp < 0 { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("invalid track id (%s)", controlPath)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid track id (%s)", controlPath) } trackID := int(tmp) if _, ok := c.streamTracks[trackID]; ok { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("track %d has already been setup", trackID)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("track %d has already been setup", trackID) } // play with UDP if th.Protocol == gortsplib.StreamProtocolUDP { if _, ok := c.protocols[gortsplib.StreamProtocolUDP]; !ok { - c.writeResError(cseq, base.StatusUnsupportedTransport, fmt.Errorf("UDP streaming is disabled")) - return nil + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil } if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolUDP { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("can't receive tracks with different protocols")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("can't receive tracks with different protocols") } if th.ClientPorts == nil { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header does not have valid client ports (%v)", req.Header["Transport"])) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("transport header does not have valid client ports (%v)", req.Header["Transport"]) } path, err := c.parent.OnClientSetupPlay(c, basePath, trackID, req) if err != nil { switch terr := err.(type) { case errAuthNotCritical: - c.writeRes(terr.Response) - return nil + return terr.Response, nil case errAuthCritical: - c.writeRes(terr.Response) - return errStateTerminate + return terr.Response, errTerminated default: - c.writeResError(cseq, base.StatusBadRequest, err) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err } } @@ -598,42 +640,42 @@ func (c *Client) handleRequest(req *base.Request) error { ServerPorts: &[2]int{c.serverUDPRtp.Port(), c.serverUDPRtcp.Port()}, } - c.writeRes(&base.Response{ + return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "CSeq": cseq, "Transport": th.Write(), "Session": base.HeaderValue{sessionID}, }, - }) - return nil + }, nil } // play with TCP + if _, ok := c.protocols[gortsplib.StreamProtocolTCP]; !ok { - c.writeResError(cseq, base.StatusUnsupportedTransport, fmt.Errorf("TCP streaming is disabled")) - return nil + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil } if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolTCP { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("can't receive tracks with different protocols")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("can't receive tracks with different protocols") } path, err := c.parent.OnClientSetupPlay(c, basePath, trackID, req) if err != nil { switch terr := err.(type) { case errAuthNotCritical: - c.writeRes(terr.Response) - return nil + return terr.Response, nil case errAuthCritical: - c.writeRes(terr.Response) - return errStateTerminate + return terr.Response, errTerminated default: - c.writeResError(cseq, base.StatusBadRequest, err) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err } } @@ -653,49 +695,53 @@ func (c *Client) handleRequest(req *base.Request) error { InterleavedIds: &interleavedIds, } - c.writeRes(&base.Response{ + return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "CSeq": cseq, "Transport": th.Write(), "Session": base.HeaderValue{sessionID}, }, - }) - return nil + }, nil // record case statePreRecord: if th.Mode == nil || *th.Mode != headers.TransportModeRecord { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header does not contain mode=record")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("transport header does not contain mode=record") } // after ANNOUNCE, c.path is already set if basePath != c.path.Name() { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath) } // record with UDP if th.Protocol == gortsplib.StreamProtocolUDP { if _, ok := c.protocols[gortsplib.StreamProtocolUDP]; !ok { - c.writeResError(cseq, base.StatusUnsupportedTransport, fmt.Errorf("UDP streaming is disabled")) - return nil + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil } if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolUDP { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("can't publish tracks with different protocols")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("can't publish tracks with different protocols") } if th.ClientPorts == nil { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header does not have valid client ports (%s)", req.Header["Transport"])) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("transport header does not have valid client ports (%s)", req.Header["Transport"]) } if len(c.streamTracks) >= c.path.SourceTrackCount() { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("all the tracks have already been setup")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("all the tracks have already been setup") } c.streamProtocol = gortsplib.StreamProtocolUDP @@ -715,43 +761,46 @@ func (c *Client) handleRequest(req *base.Request) error { ServerPorts: &[2]int{c.serverUDPRtp.Port(), c.serverUDPRtcp.Port()}, } - c.writeRes(&base.Response{ + return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "CSeq": cseq, "Transport": th.Write(), "Session": base.HeaderValue{sessionID}, }, - }) - return nil + }, nil } // record with TCP if _, ok := c.protocols[gortsplib.StreamProtocolTCP]; !ok { - c.writeResError(cseq, base.StatusUnsupportedTransport, fmt.Errorf("TCP streaming is disabled")) - return nil + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil } if len(c.streamTracks) > 0 && c.streamProtocol != gortsplib.StreamProtocolTCP { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("can't publish tracks with different protocols")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("can't publish tracks with different protocols") } interleavedIds := [2]int{len(c.streamTracks) * 2, 1 + len(c.streamTracks)*2} if th.InterleavedIds == nil { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header does not contain the interleaved field")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("transport header does not contain the interleaved field") } if (*th.InterleavedIds)[0] != interleavedIds[0] || (*th.InterleavedIds)[1] != interleavedIds[1] { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("wrong interleaved ids, expected %v, got %v", interleavedIds, *th.InterleavedIds)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("wrong interleaved ids, expected %v, got %v", interleavedIds, *th.InterleavedIds) } if len(c.streamTracks) >= c.path.SourceTrackCount() { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("all the tracks have already been setup")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("all the tracks have already been setup") } c.streamProtocol = gortsplib.StreamProtocolTCP @@ -766,19 +815,18 @@ func (c *Client) handleRequest(req *base.Request) error { InterleavedIds: &interleavedIds, } - c.writeRes(&base.Response{ + return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "CSeq": cseq, "Transport": ht.Write(), "Session": base.HeaderValue{sessionID}, }, - }) - return nil + }, nil default: - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("client is in state '%s'", c.state)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("client is in state '%s'", c.state) } case base.Play: @@ -788,83 +836,84 @@ func (c *Client) handleRequest(req *base.Request) error { statePlay: {}, }) if err != nil { - c.writeResError(cseq, base.StatusBadRequest, err) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err } if c.state == statePrePlay { basePath, ok := req.URL.BasePath() if !ok { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unable to find base path (%s)", req.URL)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unable to find base path (%s)", req.URL) } // path can end with a slash, remove it basePath = strings.TrimSuffix(basePath, "/") if basePath != c.path.Name() { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath) } if len(c.streamTracks) == 0 { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("no tracks have been setup")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("no tracks have been setup") } } - // write response before setting state - // otherwise, in case of TCP connections, RTP packets could be sent - // before the response - c.writeRes(&base.Response{ + c.startPlay() + + return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "CSeq": cseq, "Session": base.HeaderValue{sessionID}, }, - }) - - if c.state == statePrePlay { - return errStatePlay - } - return nil + }, nil case base.Record: err := c.checkState(map[state]struct{}{ statePreRecord: {}, }) if err != nil { - c.writeResError(cseq, base.StatusBadRequest, err) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err } basePath, ok := req.URL.BasePath() if !ok { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unable to find base path (%s)", req.URL)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unable to find base path (%s)", req.URL) } // path can end with a slash, remove it basePath = strings.TrimSuffix(basePath, "/") if basePath != c.path.Name() { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath) } if len(c.streamTracks) != c.path.SourceTrackCount() { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("not all tracks have been setup")) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("not all tracks have been setup") } - c.writeRes(&base.Response{ + c.startRecord() + + return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "CSeq": cseq, "Session": base.HeaderValue{sessionID}, }, - }) - return errStateRecord + }, nil case base.Pause: err := c.checkState(map[state]struct{}{ @@ -874,141 +923,56 @@ func (c *Client) handleRequest(req *base.Request) error { stateRecord: {}, }) if err != nil { - c.writeResError(cseq, base.StatusBadRequest, err) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err } - c.writeRes(&base.Response{ + switch c.state { + case statePlay: + c.stopPlay() + c.state = statePrePlay + + case stateRecord: + c.stopRecord() + c.state = statePreRecord + } + + return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ - "CSeq": cseq, "Session": base.HeaderValue{sessionID}, }, - }) - - if c.state == statePlay || c.state == stateRecord { - return errStateInitial - } - return nil + }, nil case base.Teardown: - // close connection silently - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusOK, + }, errTerminated default: - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("unhandled method '%s'", req.Method)) - return errStateTerminate + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unhandled method '%s'", req.Method) } } -func (c *Client) runInitial() bool { - readerDone := make(chan error) - go func() { - for { - req, err := c.conn.ReadRequest() - if err != nil { - readerDone <- err - return - } - - err = c.handleRequest(req) - if err != nil { - readerDone <- err - return - } - } - }() - - select { - case err := <-readerDone: - switch err { - case errStateWaitingDescribe: - return c.runWaitingDescribe() - - case errStatePlay: - return c.runPlay() - - case errStateRecord: - return c.runRecord() - - default: - c.conn.Close() - if err != io.EOF && err != errStateTerminate { - c.log(logger.Info, "ERR: %s", err) - } - - c.parent.OnClientClose(c) - <-c.terminate - return false +func (c *Client) onFrame(trackID int, streamType gortsplib.StreamType, content []byte) { + if c.state == stateRecord { + if trackID >= len(c.streamTracks) { + return } - case <-c.terminate: - c.conn.Close() - <-readerDone - return false + c.rtcpReceivers[trackID].ProcessFrame(time.Now(), streamType, content) + c.path.OnFrame(trackID, streamType, content) } } -func (c *Client) runWaitingDescribe() bool { - select { - case res := <-c.describeData: - c.path.OnClientRemove(c) - c.path = nil - - close(c.describeData) - - c.state = stateInitial - - if res.err != nil { - c.writeResError(c.describeCSeq, base.StatusNotFound, res.err) - return true - } - - if res.redirect != "" { - c.writeRes(&base.Response{ - StatusCode: base.StatusMovedPermanently, - Header: base.Header{ - "CSeq": c.describeCSeq, - "Location": base.HeaderValue{res.redirect}, - }, - }) - return true - } - - c.writeRes(&base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "CSeq": c.describeCSeq, - "Content-Base": base.HeaderValue{c.describeURL + "/"}, - "Content-Type": base.HeaderValue{"application/sdp"}, - }, - Content: res.sdp, - }) - return true - - case <-c.terminate: - ch := c.describeData - go func() { - for range ch { - } - }() - - c.path.OnClientRemove(c) - c.path = nil - - close(c.describeData) - - c.conn.Close() - return false - } -} - -func (c *Client) runPlay() bool { +func (c *Client) startPlay() { if c.streamProtocol == gortsplib.StreamProtocolTCP { - c.tcpWriteOk = true + c.writeFrameEnable = true } - // start sending frames only after replying to the PLAY request c.state = statePlay c.path.OnClientPlay(c) @@ -1020,178 +984,38 @@ func (c *Client) runPlay() bool { }(), c.streamProtocol) if c.path.Conf().RunOnRead != "" { - onReadCmd := externalcmd.New(c.path.Conf().RunOnRead, c.path.Conf().RunOnReadRestart, externalcmd.Environment{ + c.onReadCmd = externalcmd.New(c.path.Conf().RunOnRead, c.path.Conf().RunOnReadRestart, externalcmd.Environment{ Path: c.path.Name(), Port: strconv.FormatInt(int64(c.rtspPort), 10), }) - defer onReadCmd.Close() } - if c.streamProtocol == gortsplib.StreamProtocolUDP { - return c.runPlayUDP() - } - return c.runPlayTCP() -} + if c.streamProtocol == gortsplib.StreamProtocolTCP { + c.writeFrameMutex.Lock() + c.writeFrameEnable = true + c.writeFrameMutex.Unlock() -func (c *Client) runPlayUDP() bool { - readerRequest := make(chan readReq) - defer close(readerRequest) - - readerDone := make(chan error) - go func() { - for { - req, err := c.conn.ReadRequest() - if err != nil { - readerDone <- err - return - } - - okc := make(chan bool) - readerRequest <- readReq{req, okc} - ok := <-okc - if !ok { - readerDone <- nil - return - } - } - }() - - onError := func(err error) bool { - if err == errStateInitial { - c.state = statePrePlay - c.path.OnClientPause(c) - return true - } - - c.conn.Close() - if err != io.EOF && err != errStateTerminate { - c.log(logger.Info, "ERR: %s", err) - } - - c.path.OnClientRemove(c) - c.path = nil - - c.parent.OnClientClose(c) - <-c.terminate - return false - } - - for { - select { - case req := <-readerRequest: - err := c.handleRequest(req.req) - if err != nil { - req.res <- false - <-readerDone - return onError(err) - } - req.res <- true - - case err := <-readerDone: - return onError(err) - - case <-c.terminate: - go func() { - for req := range readerRequest { - req.res <- false - } - }() - - c.path.OnClientRemove(c) - c.path = nil - - c.conn.Close() - <-readerDone - return false - } + c.conn.EnableReadFrames(true) + c.conn.EnableReadTimeout(false) } } -func (c *Client) runPlayTCP() bool { - readerRequest := make(chan readReq) - defer close(readerRequest) +func (c *Client) stopPlay() { + if c.streamProtocol == gortsplib.StreamProtocolTCP { + c.conn.EnableReadFrames(false) + c.conn.EnableReadTimeout(false) - readerDone := make(chan error) - go func() { - for { - recv, err := c.conn.ReadFrameTCPOrRequest(false) - if err != nil { - readerDone <- err - return - } - - switch recvt := recv.(type) { - case *base.InterleavedFrame: - // rtcp feedback is handled by gortsplib - - case *base.Request: - okc := make(chan bool) - readerRequest <- readReq{recvt, okc} - ok := <-okc - if !ok { - readerDone <- nil - return - } - } - } - }() - - onError := func(err error) bool { - if err == errStateInitial { - c.state = statePrePlay - c.path.OnClientPause(c) - return true - } - - c.conn.Close() - if err != io.EOF && err != errStateTerminate { - c.log(logger.Info, "ERR: %s", err) - } - - c.path.OnClientRemove(c) - c.path = nil - - c.parent.OnClientClose(c) - <-c.terminate - return false + c.writeFrameMutex.Lock() + c.writeFrameEnable = false + c.writeFrameMutex.Unlock() } - for { - select { - case req := <-readerRequest: - c.tcpWriteMutex.Lock() - err := c.handleRequest(req.req) - if err != nil { - c.tcpWriteOk = false - c.tcpWriteMutex.Unlock() - req.res <- false - <-readerDone - return onError(err) - } - c.tcpWriteMutex.Unlock() - req.res <- true - - case err := <-readerDone: - return onError(err) - - case <-c.terminate: - go func() { - for req := range readerRequest { - req.res <- false - } - }() - - c.path.OnClientRemove(c) - c.path = nil - - c.conn.Close() - <-readerDone - return false - } + if c.path.Conf().RunOnRead != "" { + c.onReadCmd.Close() } } -func (c *Client) runRecord() bool { +func (c *Client) startRecord() { c.state = stateRecord c.path.OnClientRecord(c) @@ -1235,41 +1059,46 @@ func (c *Client) runRecord() bool { } if c.path.Conf().RunOnPublish != "" { - onPublishCmd := externalcmd.New(c.path.Conf().RunOnPublish, c.path.Conf().RunOnPublishRestart, externalcmd.Environment{ + c.onPublishCmd = externalcmd.New(c.path.Conf().RunOnPublish, c.path.Conf().RunOnPublishRestart, externalcmd.Environment{ Path: c.path.Name(), Port: strconv.FormatInt(int64(c.rtspPort), 10), }) - defer onPublishCmd.Close() } + c.backgroundRecordTerminate = make(chan struct{}) + c.backgroundRecordDone = make(chan struct{}) + if c.streamProtocol == gortsplib.StreamProtocolUDP { - return c.runRecordUDP() + go c.backgroundRecordUDP() + } else { + c.conn.EnableReadFrames(true) + c.conn.EnableReadTimeout(true) + go c.backgroundRecordTCP() } - return c.runRecordTCP() } -func (c *Client) runRecordUDP() bool { - readerRequest := make(chan readReq) - defer close(readerRequest) +func (c *Client) stopRecord() { + close(c.backgroundRecordTerminate) + <-c.backgroundRecordDone - readerDone := make(chan error) - go func() { - for { - req, err := c.conn.ReadRequest() - if err != nil { - readerDone <- err - return - } - - okc := make(chan bool) - readerRequest <- readReq{req, okc} - ok := <-okc - if !ok { - readerDone <- nil - return - } + if c.streamProtocol == gortsplib.StreamProtocolUDP { + for _, track := range c.streamTracks { + c.serverUDPRtp.RemovePublisher(c.ip(), track.rtpPort, c) + c.serverUDPRtcp.RemovePublisher(c.ip(), track.rtcpPort, c) } - }() + + } else { + c.conn.EnableReadFrames(false) + c.conn.EnableReadTimeout(false) + } + + if c.path.Conf().RunOnPublish != "" { + c.onPublishCmd.Close() + } +} + +func (c *Client) backgroundRecordUDP() { + defer close(c.backgroundRecordDone) checkStreamTicker := time.NewTicker(checkStreamInterval) defer checkStreamTicker.Stop() @@ -1277,50 +1106,8 @@ func (c *Client) runRecordUDP() bool { receiverReportTicker := time.NewTicker(receiverReportInterval) defer receiverReportTicker.Stop() - onError := func(err error) bool { - if err == errStateInitial { - for _, track := range c.streamTracks { - c.serverUDPRtp.RemovePublisher(c.ip(), track.rtpPort, c) - c.serverUDPRtcp.RemovePublisher(c.ip(), track.rtcpPort, c) - } - - c.state = statePreRecord - c.path.OnClientPause(c) - return true - } - - c.conn.Close() - if err != io.EOF && err != errStateTerminate { - c.log(logger.Info, "ERR: %s", err) - } - - for _, track := range c.streamTracks { - c.serverUDPRtp.RemovePublisher(c.ip(), track.rtpPort, c) - c.serverUDPRtcp.RemovePublisher(c.ip(), track.rtcpPort, c) - } - - c.path.OnClientRemove(c) - c.path = nil - - c.parent.OnClientClose(c) - <-c.terminate - return false - } - for { select { - case req := <-readerRequest: - err := c.handleRequest(req.req) - if err != nil { - req.res <- false - <-readerDone - return onError(err) - } - req.res <- true - - case err := <-readerDone: - return onError(err) - case <-checkStreamTicker.C: now := time.Now() @@ -1328,27 +1115,9 @@ func (c *Client) runRecordUDP() bool { last := time.Unix(atomic.LoadInt64(lastUnix), 0) if now.Sub(last) >= c.readTimeout { - go func() { - for req := range readerRequest { - req.res <- false - } - }() - - c.log(logger.Info, "ERR: no packets received recently (maybe there's a firewall/NAT in between)") + c.log(logger.Info, "ERR: no UDP packets received recently (maybe there's a firewall/NAT in between)") c.conn.Close() - <-readerDone - - for _, track := range c.streamTracks { - c.serverUDPRtp.RemovePublisher(c.ip(), track.rtpPort, c) - c.serverUDPRtcp.RemovePublisher(c.ip(), track.rtcpPort, c) - } - - c.path.OnClientRemove(c) - c.path = nil - - c.parent.OnClientClose(c) - <-c.terminate - return false + return } } @@ -1363,120 +1132,29 @@ func (c *Client) runRecordUDP() bool { }) } - case <-c.terminate: - go func() { - for req := range readerRequest { - req.res <- false - } - }() - - c.conn.Close() - <-readerDone - - for _, track := range c.streamTracks { - c.serverUDPRtp.RemovePublisher(c.ip(), track.rtpPort, c) - c.serverUDPRtcp.RemovePublisher(c.ip(), track.rtcpPort, c) - } - - c.path.OnClientRemove(c) - c.path = nil - return false + case <-c.backgroundRecordTerminate: + return } } } -func (c *Client) runRecordTCP() bool { - readerRequest := make(chan readReq) - defer close(readerRequest) - - readerDone := make(chan error) - go func() { - for { - recv, err := c.conn.ReadFrameTCPOrRequest(true) - if err != nil { - readerDone <- err - return - } - - switch recvt := recv.(type) { - case *base.InterleavedFrame: - if recvt.TrackID >= len(c.streamTracks) { - readerDone <- fmt.Errorf("invalid track id '%d'", recvt.TrackID) - return - } - - c.rtcpReceivers[recvt.TrackID].ProcessFrame(time.Now(), recvt.StreamType, recvt.Content) - c.path.OnFrame(recvt.TrackID, recvt.StreamType, recvt.Content) - - case *base.Request: - okc := make(chan bool) - readerRequest <- readReq{recvt, okc} - ok := <-okc - if !ok { - readerDone <- nil - return - } - } - } - }() +func (c *Client) backgroundRecordTCP() { + defer close(c.backgroundRecordDone) receiverReportTicker := time.NewTicker(receiverReportInterval) defer receiverReportTicker.Stop() - onError := func(err error) bool { - if err == errStateInitial { - c.state = statePreRecord - c.path.OnClientPause(c) - return true - } - - c.conn.Close() - if err != io.EOF && err != errStateTerminate { - c.log(logger.Info, "ERR: %s", err) - } - - c.path.OnClientRemove(c) - c.path = nil - - c.parent.OnClientClose(c) - <-c.terminate - return false - } - for { select { - case req := <-readerRequest: - err := c.handleRequest(req.req) - if err != nil { - req.res <- false - <-readerDone - return onError(err) - } - req.res <- true - - case err := <-readerDone: - return onError(err) - case <-receiverReportTicker.C: now := time.Now() for trackID := range c.streamTracks { r := c.rtcpReceivers[trackID].Report(now) - c.conn.WriteFrameTCP(trackID, gortsplib.StreamTypeRtcp, r) + c.conn.WriteFrame(trackID, gortsplib.StreamTypeRtcp, r) } - case <-c.terminate: - go func() { - for req := range readerRequest { - req.res <- false - } - }() - - c.conn.Close() - <-readerDone - - c.path.OnClientRemove(c) - c.path = nil - return false + case <-c.backgroundRecordTerminate: + return } } } @@ -1513,11 +1191,11 @@ func (c *Client) OnReaderFrame(trackID int, streamType base.StreamType, buf []by } } else { - c.tcpWriteMutex.Lock() - if c.tcpWriteOk { - c.conn.WriteFrameTCP(trackID, streamType, buf) + c.writeFrameMutex.Lock() + if c.writeFrameEnable { + c.conn.WriteFrame(trackID, streamType, buf) } - c.tcpWriteMutex.Unlock() + c.writeFrameMutex.Unlock() } } diff --git a/internal/servertcp/server.go b/internal/servertcp/server.go index 6d3030bc..8e851192 100644 --- a/internal/servertcp/server.go +++ b/internal/servertcp/server.go @@ -37,7 +37,7 @@ func New(port int, ReadBufferCount: 1, } - srv, err := conf.Serve(":"+strconv.FormatInt(int64(port), 10), nil) + srv, err := conf.Serve(":" + strconv.FormatInt(int64(port), 10)) if err != nil { return nil, err } diff --git a/internal/sourcertsp/source.go b/internal/sourcertsp/source.go index 74ddb717..98b552a9 100644 --- a/internal/sourcertsp/source.go +++ b/internal/sourcertsp/source.go @@ -149,7 +149,7 @@ func (s *Source) runInner() bool { s.parent.OnSourceSetReady(tracks) defer s.parent.OnSourceSetNotReady() - readerDone := conn.OnFrame(func(trackID int, streamType gortsplib.StreamType, content []byte) { + done := conn.ReadFrames(func(trackID int, streamType gortsplib.StreamType, content []byte) { s.parent.OnFrame(trackID, streamType, content) }) @@ -157,10 +157,10 @@ func (s *Source) runInner() bool { select { case <-s.terminate: conn.Close() - <-readerDone + <-done return false - case err := <-readerDone: + case err := <-done: conn.Close() s.log(logger.Info, "ERR: %s", err) return true diff --git a/rtsp-simple-server.yml b/rtsp-simple-server.yml index 864282e5..c52ac5c6 100644 --- a/rtsp-simple-server.yml +++ b/rtsp-simple-server.yml @@ -51,7 +51,7 @@ paths: # if the source is an RTSP url, this is the protocol that will be used to # pull the stream. available options are "automatic", "udp", "tcp". - # the tcp protocol can help to overcome the error "no packets received recently". + # the tcp protocol can help to overcome the error "no UDP packets received recently". sourceProtocol: automatic # if the source is an RTSP or RTMP url, it will be pulled only when at least