diff --git a/main.go b/main.go index 797c2d9d..fa63cbe3 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,9 @@ import ( "strings" "time" + "github.com/aler9/gortsplib" "gopkg.in/alecthomas/kingpin.v2" + "gortc.io/sdp" ) var Version = "v0.0.0" @@ -38,10 +40,10 @@ func parseIpCidrList(in string) ([]interface{}, error) { return ret, nil } -type trackFlow int +type trackFlowType int const ( - _TRACK_FLOW_RTP trackFlow = iota + _TRACK_FLOW_RTP trackFlowType = iota _TRACK_FLOW_RTCP ) @@ -64,6 +66,110 @@ func (s streamProtocol) String() string { return "tcp" } +type programEvent interface { + isProgramEvent() +} + +type programEventClientNew struct { + nconn net.Conn +} + +func (programEventClientNew) isProgramEvent() {} + +type programEventClientClose struct { + done chan struct{} + client *serverClient +} + +func (programEventClientClose) isProgramEvent() {} + +type programEventClientGetStreamSdp struct { + path string + res chan []byte +} + +func (programEventClientGetStreamSdp) isProgramEvent() {} + +type programEventClientAnnounce struct { + res chan error + client *serverClient + path string + sdpText []byte + sdpParsed *sdp.Message +} + +func (programEventClientAnnounce) isProgramEvent() {} + +type programEventClientSetupPlay struct { + res chan error + client *serverClient + path string + protocol streamProtocol + rtpPort int + rtcpPort int +} + +func (programEventClientSetupPlay) isProgramEvent() {} + +type programEventClientSetupRecord struct { + res chan error + client *serverClient + protocol streamProtocol + rtpPort int + rtcpPort int +} + +func (programEventClientSetupRecord) isProgramEvent() {} + +type programEventClientPlay1 struct { + res chan error + client *serverClient +} + +func (programEventClientPlay1) isProgramEvent() {} + +type programEventClientPlay2 struct { + res chan error + client *serverClient +} + +func (programEventClientPlay2) isProgramEvent() {} + +type programEventClientPause struct { + res chan error + client *serverClient +} + +func (programEventClientPause) isProgramEvent() {} + +type programEventClientRecord struct { + res chan error + client *serverClient +} + +func (programEventClientRecord) isProgramEvent() {} + +type programEventFrameUdp struct { + trackFlowType trackFlowType + addr *net.UDPAddr + buf []byte +} + +func (programEventFrameUdp) isProgramEvent() {} + +type programEventFrameTcp struct { + path string + trackId int + trackFlowType trackFlowType + buf []byte +} + +func (programEventFrameTcp) isProgramEvent() {} + +type programEventTerminate struct{} + +func (programEventTerminate) isProgramEvent() {} + type args struct { version bool protocolsStr string @@ -90,6 +196,11 @@ type program struct { tcpl *serverTcpListener udplRtp *serverUdpListener udplRtcp *serverUdpListener + clients map[*serverClient]struct{} + publishers map[string]*serverClient + + events chan programEvent + done chan struct{} } func newProgram(sargs []string) (*program, error) { @@ -204,6 +315,10 @@ func newProgram(sargs []string) (*program, error) { protocols: protocols, publishIps: publishIps, readIps: readIps, + clients: make(map[*serverClient]struct{}), + publishers: make(map[string]*serverClient), + events: make(chan programEvent), + done: make(chan struct{}), } p.udplRtp, err = newServerUdpListener(p, args.rtpPort, _TRACK_FLOW_RTP) @@ -224,14 +339,243 @@ func newProgram(sargs []string) (*program, error) { go p.udplRtp.run() go p.udplRtcp.run() go p.tcpl.run() + go p.run() return p, nil } -func (p *program) close() { +func (p *program) run() { +outer: + for rawEvt := range p.events { + switch evt := rawEvt.(type) { + case programEventClientNew: + c := newServerClient(p, evt.nconn) + p.clients[c] = struct{}{} + + case programEventClientClose: + // already deleted + if _, ok := p.clients[evt.client]; !ok { + close(evt.done) + continue + } + + delete(p.clients, evt.client) + + if evt.client.path != "" { + if pub, ok := p.publishers[evt.client.path]; ok && pub == evt.client { + delete(p.publishers, evt.client.path) + + // if the publisher has disconnected + // close all other connections that share the same path + for oc := range p.clients { + if oc.path == evt.client.path { + go oc.close() + } + } + } + } + + close(evt.done) + + case programEventClientGetStreamSdp: + pub, ok := p.publishers[evt.path] + if !ok { + evt.res <- nil + continue + } + evt.res <- pub.streamSdpText + + case programEventClientAnnounce: + _, ok := p.publishers[evt.path] + if ok { + evt.res <- fmt.Errorf("another client is already publishing on path '%s'", evt.path) + continue + } + + evt.client.path = evt.path + evt.client.streamSdpText = evt.sdpText + evt.client.streamSdpParsed = evt.sdpParsed + evt.client.state = _CLIENT_STATE_ANNOUNCE + p.publishers[evt.path] = evt.client + evt.res <- nil + + case programEventClientSetupPlay: + pub, ok := p.publishers[evt.path] + if !ok { + evt.res <- fmt.Errorf("no one is streaming on path '%s'", evt.path) + continue + } + + if len(evt.client.streamTracks) >= len(pub.streamSdpParsed.Medias) { + evt.res <- fmt.Errorf("all the tracks have already been setup") + continue + } + + evt.client.path = evt.path + evt.client.streamProtocol = evt.protocol + evt.client.streamTracks = append(evt.client.streamTracks, &track{ + rtpPort: evt.rtpPort, + rtcpPort: evt.rtcpPort, + }) + evt.client.state = _CLIENT_STATE_PRE_PLAY + evt.res <- nil + + case programEventClientSetupRecord: + evt.client.streamProtocol = evt.protocol + evt.client.streamTracks = append(evt.client.streamTracks, &track{ + rtpPort: evt.rtpPort, + rtcpPort: evt.rtcpPort, + }) + evt.client.state = _CLIENT_STATE_PRE_RECORD + evt.res <- nil + + case programEventClientPlay1: + pub, ok := p.publishers[evt.client.path] + if !ok { + evt.res <- fmt.Errorf("no one is streaming on path '%s'", evt.client.path) + continue + } + + if len(evt.client.streamTracks) != len(pub.streamSdpParsed.Medias) { + evt.res <- fmt.Errorf("not all tracks have been setup") + continue + } + + evt.res <- nil + + case programEventClientPlay2: + evt.client.state = _CLIENT_STATE_PLAY + evt.res <- nil + + case programEventClientPause: + evt.client.state = _CLIENT_STATE_PRE_PLAY + evt.res <- nil + + case programEventClientRecord: + evt.client.state = _CLIENT_STATE_RECORD + evt.res <- nil + + case programEventFrameUdp: + // find publisher and track id from ip and port + pub, trackId := func() (*serverClient, int) { + for _, pub := range p.publishers { + if pub.streamProtocol != _STREAM_PROTOCOL_UDP || + pub.state != _CLIENT_STATE_RECORD || + !pub.ip().Equal(evt.addr.IP) { + continue + } + + for i, t := range pub.streamTracks { + if evt.trackFlowType == _TRACK_FLOW_RTP { + if t.rtpPort == evt.addr.Port { + return pub, i + } + } else { + if t.rtcpPort == evt.addr.Port { + return pub, i + } + } + } + } + return nil, -1 + }() + if pub == nil { + continue + } + + pub.udpLastFrameTime = time.Now() + p.forwardTrack(pub.path, trackId, evt.trackFlowType, evt.buf) + + case programEventFrameTcp: + p.forwardTrack(evt.path, evt.trackId, evt.trackFlowType, evt.buf) + + case programEventTerminate: + break outer + } + } + + go func() { + for rawEvt := range p.events { + switch evt := rawEvt.(type) { + case programEventClientClose: + close(evt.done) + + case programEventClientGetStreamSdp: + evt.res <- nil + + case programEventClientAnnounce: + evt.res <- fmt.Errorf("terminated") + + case programEventClientSetupPlay: + evt.res <- fmt.Errorf("terminated") + + case programEventClientSetupRecord: + evt.res <- fmt.Errorf("terminated") + + case programEventClientPlay1: + evt.res <- fmt.Errorf("terminated") + + case programEventClientPlay2: + evt.res <- fmt.Errorf("terminated") + + case programEventClientPause: + evt.res <- fmt.Errorf("terminated") + + case programEventClientRecord: + evt.res <- fmt.Errorf("terminated") + } + } + }() + p.tcpl.close() p.udplRtcp.close() p.udplRtp.close() + + for c := range p.clients { + c.close() + } + + close(p.events) + close(p.done) +} + +func (p *program) close() { + p.events <- programEventTerminate{} + <-p.done +} + +func (p *program) forwardTrack(path string, id int, trackFlowType trackFlowType, frame []byte) { + for c := range p.clients { + if c.path == path && c.state == _CLIENT_STATE_PLAY { + if c.streamProtocol == _STREAM_PROTOCOL_UDP { + if trackFlowType == _TRACK_FLOW_RTP { + p.udplRtp.write <- &udpWrite{ + addr: &net.UDPAddr{ + IP: c.ip(), + Zone: c.zone(), + Port: c.streamTracks[id].rtpPort, + }, + buf: frame, + } + } else { + p.udplRtcp.write <- &udpWrite{ + addr: &net.UDPAddr{ + IP: c.ip(), + Zone: c.zone(), + Port: c.streamTracks[id].rtcpPort, + }, + buf: frame, + } + } + + } else { + c.write <- &gortsplib.InterleavedFrame{ + Channel: trackToInterleavedChannel(id, trackFlowType), + Content: frame, + } + } + } + } } func main() { diff --git a/server-client.go b/server-client.go index 11cef831..e38281c0 100644 --- a/server-client.go +++ b/server-client.go @@ -19,15 +19,15 @@ const ( _UDP_STREAM_DEAD_AFTER = 10 * time.Second ) -func interleavedChannelToTrack(channel uint8) (int, trackFlow) { +func interleavedChannelToTrack(channel uint8) (int, trackFlowType) { if (channel % 2) == 0 { return int(channel / 2), _TRACK_FLOW_RTP } return int((channel - 1) / 2), _TRACK_FLOW_RTCP } -func trackToInterleavedChannel(id int, flow trackFlow) uint8 { - if flow == _TRACK_FLOW_RTP { +func trackToInterleavedChannel(id int, trackFlowType trackFlowType) uint8 { + if trackFlowType == _TRACK_FLOW_RTP { return uint8(id * 2) } return uint8((id * 2) + 1) @@ -80,8 +80,9 @@ type serverClient struct { streamTracks []*track udpLastFrameTime time.Time udpCheckStreamTicker *time.Ticker - write chan *gortsplib.InterleavedFrame - done chan struct{} + + write chan *gortsplib.InterleavedFrame + done chan struct{} } func newServerClient(p *program, nconn net.Conn) *serverClient { @@ -97,39 +98,13 @@ func newServerClient(p *program, nconn net.Conn) *serverClient { done: make(chan struct{}), } - c.p.tcpl.mutex.Lock() - c.p.tcpl.clients[c] = struct{}{} - c.p.tcpl.mutex.Unlock() - go c.run() - return c } -func (c *serverClient) close() error { - // already deleted - if _, ok := c.p.tcpl.clients[c]; !ok { - return nil - } - - delete(c.p.tcpl.clients, c) +func (c *serverClient) close() { c.conn.NetConn().Close() - close(c.write) - - if c.path != "" { - if pub, ok := c.p.tcpl.publishers[c.path]; ok && pub == c { - delete(c.p.tcpl.publishers, c.path) - - // if the publisher has disconnected - // close all other connections that share the same path - for oc := range c.p.tcpl.clients { - if oc.path == c.path { - oc.close() - } - } - } - } - return nil + <-c.done } func (c *serverClient) log(format string, args ...interface{}) { @@ -172,18 +147,12 @@ func (c *serverClient) run() { } } - func() { - c.p.tcpl.mutex.Lock() - defer c.p.tcpl.mutex.Unlock() - c.close() - }() + c.log("disconnected") if c.udpCheckStreamTicker != nil { c.udpCheckStreamTicker.Stop() } - c.log("disconnected") - func() { if c.p.args.postScript != "" { postScript := exec.Command(c.p.args.postScript) @@ -194,6 +163,12 @@ func (c *serverClient) run() { } }() + done := make(chan struct{}) + c.p.events <- programEventClientClose{done, c} + <-done + + close(c.write) + close(c.done) } @@ -202,7 +177,7 @@ func (c *serverClient) writeResError(req *gortsplib.Request, code gortsplib.Stat header := gortsplib.Header{} if cseq, ok := req.Header["CSeq"]; ok && len(cseq) == 1 { - header["CSeq"] = []string{cseq[0]} + header["CSeq"] = cseq } c.conn.WriteResponse(&gortsplib.Response{ @@ -317,7 +292,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { c.conn.WriteResponse(&gortsplib.Response{ StatusCode: gortsplib.StatusOK, Header: gortsplib.Header{ - "CSeq": []string{cseq[0]}, + "CSeq": cseq, "Public": []string{strings.Join([]string{ string(gortsplib.DESCRIBE), string(gortsplib.ANNOUNCE), @@ -346,26 +321,18 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return true } - sdp, err := func() ([]byte, error) { - c.p.tcpl.mutex.RLock() - defer c.p.tcpl.mutex.RUnlock() - - pub, ok := c.p.tcpl.publishers[path] - if !ok { - return nil, fmt.Errorf("no one is streaming on path '%s'", path) - } - - return pub.streamSdpText, nil - }() - if err != nil { - c.writeResError(req, gortsplib.StatusBadRequest, err) + res := make(chan []byte) + c.p.events <- programEventClientGetStreamSdp{path, res} + sdp := <-res + if sdp == nil { + c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("no one is streaming on path '%s'", path)) return false } c.conn.WriteResponse(&gortsplib.Response{ StatusCode: gortsplib.StatusOK, Header: gortsplib.Header{ - "CSeq": []string{cseq[0]}, + "CSeq": cseq, "Content-Base": []string{req.Url.String()}, "Content-Type": []string{"application/sdp"}, }, @@ -404,25 +371,16 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("invalid SDP: %s", err)) return false } - sdpParsed, req.Content = gortsplib.SDPFilter(sdpParsed, req.Content) - err = func() error { - c.p.tcpl.mutex.Lock() - defer c.p.tcpl.mutex.Unlock() + if len(path) == 0 { + c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path can't be empty")) + return false + } - _, ok := c.p.tcpl.publishers[path] - if ok { - return fmt.Errorf("another client is already publishing on path '%s'", path) - } - - c.path = path - c.p.tcpl.publishers[path] = c - c.streamSdpText = req.Content - c.streamSdpParsed = sdpParsed - c.state = _CLIENT_STATE_ANNOUNCE - return nil - }() + res := make(chan error) + c.p.events <- programEventClientAnnounce{res, c, path, req.Content, sdpParsed} + err = <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) return false @@ -431,7 +389,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { c.conn.WriteResponse(&gortsplib.Response{ StatusCode: gortsplib.StatusOK, Header: gortsplib.Header{ - "CSeq": []string{cseq[0]}, + "CSeq": cseq, }, }) return true @@ -488,33 +446,14 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return false } - err := func() error { - c.p.tcpl.mutex.Lock() - defer c.p.tcpl.mutex.Unlock() + if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP { + c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't receive tracks with different protocols")) + return false + } - pub, ok := c.p.tcpl.publishers[path] - if !ok { - return fmt.Errorf("no one is streaming on path '%s'", path) - } - - if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP { - return fmt.Errorf("client wants to read tracks with different protocols") - } - - if len(c.streamTracks) >= len(pub.streamSdpParsed.Medias) { - return fmt.Errorf("all the tracks have already been setup") - } - - c.path = path - c.streamProtocol = _STREAM_PROTOCOL_UDP - c.streamTracks = append(c.streamTracks, &track{ - rtpPort: rtpPort, - rtcpPort: rtcpPort, - }) - - c.state = _CLIENT_STATE_PRE_PLAY - return nil - }() + res := make(chan error) + c.p.events <- programEventClientSetupPlay{res, c, path, _STREAM_PROTOCOL_UDP, rtpPort, rtcpPort} + err = <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) return false @@ -523,7 +462,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { c.conn.WriteResponse(&gortsplib.Response{ StatusCode: gortsplib.StatusOK, Header: gortsplib.Header{ - "CSeq": []string{cseq[0]}, + "CSeq": cseq, "Transport": []string{strings.Join([]string{ "RTP/AVP/UDP", "unicast", @@ -547,33 +486,14 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return false } - err := func() error { - c.p.tcpl.mutex.Lock() - defer c.p.tcpl.mutex.Unlock() + if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP { + c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't receive tracks with different protocols")) + return false + } - pub, ok := c.p.tcpl.publishers[path] - if !ok { - return fmt.Errorf("no one is streaming on path '%s'", path) - } - - if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP { - return fmt.Errorf("client wants to read tracks with different protocols") - } - - if len(c.streamTracks) >= len(pub.streamSdpParsed.Medias) { - return fmt.Errorf("all the tracks have already been setup") - } - - c.path = path - c.streamProtocol = _STREAM_PROTOCOL_TCP - c.streamTracks = append(c.streamTracks, &track{ - rtpPort: 0, - rtcpPort: 0, - }) - - c.state = _CLIENT_STATE_PRE_PLAY - return nil - }() + res := make(chan error) + c.p.events <- programEventClientSetupPlay{res, c, path, _STREAM_PROTOCOL_TCP, 0, 0} + err = <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) return false @@ -584,7 +504,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { c.conn.WriteResponse(&gortsplib.Response{ StatusCode: gortsplib.StatusOK, Header: gortsplib.Header{ - "CSeq": []string{cseq[0]}, + "CSeq": cseq, "Transport": []string{strings.Join([]string{ "RTP/AVP/TCP", "unicast", @@ -607,6 +527,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return false } + // after ANNOUNCE, c.path is already set if path != c.path { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path has changed")) return false @@ -635,27 +556,19 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return false } - err := func() error { - c.p.tcpl.mutex.Lock() - defer c.p.tcpl.mutex.Unlock() + if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP { + c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't publish tracks with different protocols")) + return false + } - if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP { - return fmt.Errorf("client wants to publish tracks with different protocols") - } + if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) { + c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("all the tracks have already been setup")) + return false + } - if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) { - return fmt.Errorf("all the tracks have already been setup") - } - - c.streamProtocol = _STREAM_PROTOCOL_UDP - c.streamTracks = append(c.streamTracks, &track{ - rtpPort: rtpPort, - rtcpPort: rtcpPort, - }) - - c.state = _CLIENT_STATE_PRE_RECORD - return nil - }() + res := make(chan error) + c.p.events <- programEventClientSetupRecord{res, c, _STREAM_PROTOCOL_UDP, rtpPort, rtcpPort} + err := <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) return false @@ -664,7 +577,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { c.conn.WriteResponse(&gortsplib.Response{ StatusCode: gortsplib.StatusOK, Header: gortsplib.Header{ - "CSeq": []string{cseq[0]}, + "CSeq": cseq, "Transport": []string{strings.Join([]string{ "RTP/AVP/UDP", "unicast", @@ -683,38 +596,31 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return false } - var interleaved string - err := func() error { - c.p.tcpl.mutex.Lock() - defer c.p.tcpl.mutex.Unlock() + if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP { + c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't publish tracks with different protocols")) + return false + } - if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP { - return fmt.Errorf("client wants to publish tracks with different protocols") - } + interleaved := th.GetValue("interleaved") + if interleaved == "" { + c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("transport header does not contain the interleaved field")) + return false + } - if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) { - return fmt.Errorf("all the tracks have already been setup") - } + expInterleaved := fmt.Sprintf("%d-%d", 0+len(c.streamTracks)*2, 1+len(c.streamTracks)*2) + if interleaved != expInterleaved { + c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("wrong interleaved value, expected '%s', got '%s'", expInterleaved, interleaved)) + return false + } - interleaved = th.GetValue("interleaved") - if interleaved == "" { - return fmt.Errorf("transport header does not contain interleaved field") - } + if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) { + c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("all the tracks have already been setup")) + return false + } - expInterleaved := fmt.Sprintf("%d-%d", 0+len(c.streamTracks)*2, 1+len(c.streamTracks)*2) - if interleaved != expInterleaved { - return fmt.Errorf("wrong interleaved value, expected '%s', got '%s'", expInterleaved, interleaved) - } - - c.streamProtocol = _STREAM_PROTOCOL_TCP - c.streamTracks = append(c.streamTracks, &track{ - rtpPort: 0, - rtcpPort: 0, - }) - - c.state = _CLIENT_STATE_PRE_RECORD - return nil - }() + res := make(chan error) + c.p.events <- programEventClientSetupRecord{res, c, _STREAM_PROTOCOL_TCP, 0, 0} + err := <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) return false @@ -723,7 +629,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { c.conn.WriteResponse(&gortsplib.Response{ StatusCode: gortsplib.StatusOK, Header: gortsplib.Header{ - "CSeq": []string{cseq[0]}, + "CSeq": cseq, "Transport": []string{strings.Join([]string{ "RTP/AVP/TCP", "unicast", @@ -756,33 +662,22 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return false } - err := func() error { - c.p.tcpl.mutex.Lock() - defer c.p.tcpl.mutex.Unlock() - - pub, ok := c.p.tcpl.publishers[c.path] - if !ok { - return fmt.Errorf("no one is streaming on path '%s'", c.path) - } - - if len(c.streamTracks) != len(pub.streamSdpParsed.Medias) { - return fmt.Errorf("not all tracks have been setup") - } - - return nil - }() + // check publisher existence + res := make(chan error) + c.p.events <- programEventClientPlay1{res, c} + err := <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) return false } - // first write response, then set state - // otherwise, in case of TCP connections, RTP packets could be written + // write response before setting state + // otherwise, in case of TCP connections, RTP packets could be sent // before the response c.conn.WriteResponse(&gortsplib.Response{ StatusCode: gortsplib.StatusOK, Header: gortsplib.Header{ - "CSeq": []string{cseq[0]}, + "CSeq": cseq, "Session": []string{"12345678"}, }, }) @@ -794,9 +689,10 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return "tracks" }(), c.streamProtocol) - c.p.tcpl.mutex.Lock() - c.state = _CLIENT_STATE_PLAY - c.p.tcpl.mutex.Unlock() + // set state + res = make(chan error) + c.p.events <- programEventClientPlay2{res, c} + <-res // when protocol is TCP, the RTSP connection becomes a RTP connection if c.streamProtocol == _STREAM_PROTOCOL_TCP { @@ -836,14 +732,14 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { c.log("paused") - c.p.tcpl.mutex.Lock() - c.state = _CLIENT_STATE_PRE_PLAY - c.p.tcpl.mutex.Unlock() + res := make(chan error) + c.p.events <- programEventClientPause{res, c} + <-res c.conn.WriteResponse(&gortsplib.Response{ StatusCode: gortsplib.StatusOK, Header: gortsplib.Header{ - "CSeq": []string{cseq[0]}, + "CSeq": cseq, "Session": []string{"12345678"}, }, }) @@ -861,25 +757,15 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return false } - err := func() error { - c.p.tcpl.mutex.Lock() - defer c.p.tcpl.mutex.Unlock() - - if len(c.streamTracks) != len(c.streamSdpParsed.Medias) { - return fmt.Errorf("not all tracks have been setup") - } - - return nil - }() - if err != nil { - c.writeResError(req, gortsplib.StatusBadRequest, err) + if len(c.streamTracks) != len(c.streamSdpParsed.Medias) { + c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("not all tracks have been setup")) return false } c.conn.WriteResponse(&gortsplib.Response{ StatusCode: gortsplib.StatusOK, Header: gortsplib.Header{ - "CSeq": []string{cseq[0]}, + "CSeq": cseq, "Session": []string{"12345678"}, }, }) @@ -891,13 +777,13 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return "tracks" }(), c.streamProtocol) + res := make(chan error) + c.p.events <- programEventClientRecord{res, c} + <-res + // when protocol is TCP, the RTSP connection becomes a RTP connection // receive RTP data and parse it if c.streamProtocol == _STREAM_PROTOCOL_TCP { - c.p.tcpl.mutex.Lock() - c.state = _CLIENT_STATE_RECORD - c.p.tcpl.mutex.Unlock() - for { frame, err := c.conn.ReadInterleavedFrame() if err != nil { @@ -907,37 +793,27 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return false } - trackId, trackFlow := interleavedChannelToTrack(frame.Channel) + trackId, trackFlowType := interleavedChannelToTrack(frame.Channel) if trackId >= len(c.streamTracks) { c.log("ERR: invalid track id '%d'", trackId) return false } - c.p.tcpl.mutex.RLock() - c.p.tcpl.forwardTrack(c.path, trackId, trackFlow, frame.Content) - c.p.tcpl.mutex.RUnlock() + c.p.events <- programEventFrameTcp{ + c.path, + trackId, + trackFlowType, + frame.Content, + } } } else { - c.p.tcpl.mutex.Lock() - c.state = _CLIENT_STATE_RECORD c.udpLastFrameTime = time.Now() c.udpCheckStreamTicker = time.NewTicker(_UDP_CHECK_STREAM_INTERVAL) - c.p.tcpl.mutex.Unlock() go func() { for range c.udpCheckStreamTicker.C { - ok := func() bool { - c.p.tcpl.mutex.Lock() - defer c.p.tcpl.mutex.Unlock() - - if time.Since(c.udpLastFrameTime) >= _UDP_STREAM_DEAD_AFTER { - return false - } - - return true - }() - if !ok { + if time.Since(c.udpLastFrameTime) >= _UDP_STREAM_DEAD_AFTER { c.log("ERR: stream is dead") c.conn.NetConn().Close() break diff --git a/server-tcpl.go b/server-tcpl.go index 623a19d8..2bafef6a 100644 --- a/server-tcpl.go +++ b/server-tcpl.go @@ -3,18 +3,13 @@ package main import ( "log" "net" - "sync" - - "github.com/aler9/gortsplib" ) type serverTcpListener struct { - p *program - nconn *net.TCPListener - mutex sync.RWMutex - clients map[*serverClient]struct{} - publishers map[string]*serverClient - done chan struct{} + p *program + nconn *net.TCPListener + + done chan struct{} } func newServerTcpListener(p *program) (*serverTcpListener, error) { @@ -26,11 +21,9 @@ func newServerTcpListener(p *program) (*serverTcpListener, error) { } l := &serverTcpListener{ - p: p, - nconn: nconn, - clients: make(map[*serverClient]struct{}), - publishers: make(map[string]*serverClient), - done: make(chan struct{}), + p: p, + nconn: nconn, + done: make(chan struct{}), } l.log("opened on :%d", p.args.rtspPort) @@ -48,21 +41,7 @@ func (l *serverTcpListener) run() { break } - newServerClient(l.p, nconn) - } - - // close clients - var doneChans []chan struct{} - func() { - l.mutex.Lock() - defer l.mutex.Unlock() - for c := range l.clients { - c.close() - doneChans = append(doneChans, c.done) - } - }() - for _, c := range doneChans { - <-c + l.p.events <- programEventClientNew{nconn} } close(l.done) @@ -72,37 +51,3 @@ func (l *serverTcpListener) close() { l.nconn.Close() <-l.done } - -func (l *serverTcpListener) forwardTrack(path string, id int, flow trackFlow, frame []byte) { - for c := range l.clients { - if c.path == path && c.state == _CLIENT_STATE_PLAY { - if c.streamProtocol == _STREAM_PROTOCOL_UDP { - if flow == _TRACK_FLOW_RTP { - l.p.udplRtp.write <- &udpWrite{ - addr: &net.UDPAddr{ - IP: c.ip(), - Zone: c.zone(), - Port: c.streamTracks[id].rtpPort, - }, - buf: frame, - } - } else { - l.p.udplRtcp.write <- &udpWrite{ - addr: &net.UDPAddr{ - IP: c.ip(), - Zone: c.zone(), - Port: c.streamTracks[id].rtcpPort, - }, - buf: frame, - } - } - - } else { - c.write <- &gortsplib.InterleavedFrame{ - Channel: trackToInterleavedChannel(id, flow), - Content: frame, - } - } - } - } -} diff --git a/server-udpl.go b/server-udpl.go index 12a58885..59c6408a 100644 --- a/server-udpl.go +++ b/server-udpl.go @@ -12,14 +12,15 @@ type udpWrite struct { } type serverUdpListener struct { - p *program - nconn *net.UDPConn - flow trackFlow + p *program + nconn *net.UDPConn + trackFlowType trackFlowType + write chan *udpWrite done chan struct{} } -func newServerUdpListener(p *program, port int, flow trackFlow) (*serverUdpListener, error) { +func newServerUdpListener(p *program, port int, trackFlowType trackFlowType) (*serverUdpListener, error) { nconn, err := net.ListenUDP("udp", &net.UDPAddr{ Port: port, }) @@ -28,11 +29,11 @@ func newServerUdpListener(p *program, port int, flow trackFlow) (*serverUdpListe } l := &serverUdpListener{ - p: p, - nconn: nconn, - flow: flow, - write: make(chan *udpWrite), - done: make(chan struct{}), + p: p, + nconn: nconn, + trackFlowType: trackFlowType, + write: make(chan *udpWrite), + done: make(chan struct{}), } l.log("opened on :%d", port) @@ -41,7 +42,7 @@ func newServerUdpListener(p *program, port int, flow trackFlow) (*serverUdpListe func (l *serverUdpListener) log(format string, args ...interface{}) { var label string - if l.flow == _TRACK_FLOW_RTP { + if l.trackFlowType == _TRACK_FLOW_RTP { label = "RTP" } else { label = "RTCP" @@ -67,40 +68,11 @@ func (l *serverUdpListener) run() { break } - func() { - l.p.tcpl.mutex.Lock() - defer l.p.tcpl.mutex.Unlock() - - // find publisher and track id from ip and port - pub, trackId := func() (*serverClient, int) { - for _, pub := range l.p.tcpl.publishers { - if pub.streamProtocol != _STREAM_PROTOCOL_UDP || - pub.state != _CLIENT_STATE_RECORD || - !pub.ip().Equal(addr.IP) { - continue - } - - for i, t := range pub.streamTracks { - if l.flow == _TRACK_FLOW_RTP { - if t.rtpPort == addr.Port { - return pub, i - } - } else { - if t.rtcpPort == addr.Port { - return pub, i - } - } - } - } - return nil, -1 - }() - if pub == nil { - return - } - - pub.udpLastFrameTime = time.Now() - l.p.tcpl.forwardTrack(pub.path, trackId, l.flow, buf[:n]) - }() + l.p.events <- programEventFrameUdp{ + l.trackFlowType, + addr, + buf[:n], + } } close(l.write)