From 6163095a11b2325463875e5accdcdf9bb79aa34d Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Tue, 10 Aug 2021 18:34:10 +0200 Subject: [PATCH] fix crash that happens when sourceOnDemand is true and a source times out --- go.mod | 2 +- go.sum | 4 +- internal/core/hls_remuxer.go | 2 +- internal/core/path.go | 153 +++++++++++++--------------------- internal/core/rtmp_conn.go | 12 +-- internal/core/rtmp_source.go | 12 +-- internal/core/rtsp_conn.go | 2 +- internal/core/rtsp_session.go | 25 ++++-- internal/core/rtsp_source.go | 11 ++- internal/core/stream.go | 87 +++++++++++++++++++ 10 files changed, 185 insertions(+), 125 deletions(-) create mode 100644 internal/core/stream.go diff --git a/go.mod b/go.mod index 11da028b..62baedf9 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.16 require ( github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect - github.com/aler9/gortsplib v0.0.0-20210731192657-45db8582b0b3 + github.com/aler9/gortsplib v0.0.0-20210810153440-c45a1b399530 github.com/asticode/go-astits v1.9.0 github.com/fsnotify/fsnotify v1.4.9 github.com/gin-gonic/gin v1.7.2 diff --git a/go.sum b/go.sum index 0026fe8d..3277ebd6 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= -github.com/aler9/gortsplib v0.0.0-20210731192657-45db8582b0b3 h1:OHLssJ39nrj8ln6xBJz3529c2In8cXhfptpvvw1bwDc= -github.com/aler9/gortsplib v0.0.0-20210731192657-45db8582b0b3/go.mod h1:s5FsbPRxJhU/YedvUKAKHVY+lQEdYsiJpuN2CHb89cI= +github.com/aler9/gortsplib v0.0.0-20210810153440-c45a1b399530 h1:/Lzuu854GPVUzVHW35QyViBQ4EE2dgP30E6VMULcqF4= +github.com/aler9/gortsplib v0.0.0-20210810153440-c45a1b399530/go.mod h1:s5FsbPRxJhU/YedvUKAKHVY+lQEdYsiJpuN2CHb89cI= github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927 h1:95mXJ5fUCYpBRdSOnLAQAdJHHKxxxJrVCiaqDi965YQ= github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc= github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8= diff --git a/internal/core/hls_remuxer.go b/internal/core/hls_remuxer.go index 8b365e5f..0752dcf8 100644 --- a/internal/core/hls_remuxer.go +++ b/internal/core/hls_remuxer.go @@ -250,7 +250,7 @@ func (r *hlsRemuxer) runRemuxer(remuxerCtx context.Context, remuxerReady chan st var aacConfig rtpaac.MPEG4AudioConfig var aacDecoder *rtpaac.Decoder - for i, t := range res.Stream.Tracks() { + for i, t := range res.Stream.tracks() { if t.IsH264() { if videoTrack != nil { return fmt.Errorf("can't read track %d with HLS: too many tracks", i+1) diff --git a/internal/core/path.go b/internal/core/path.go index 465ed385..448dae4f 100644 --- a/internal/core/path.go +++ b/internal/core/path.go @@ -86,9 +86,14 @@ const ( pathOnDemandStateClosing ) +type pathSourceStaticSetReadyRes struct { + Stream *stream + Err error +} + type pathSourceStaticSetReadyReq struct { Tracks gortsplib.Tracks - Res chan struct{} + Res chan pathSourceStaticSetReadyRes } type pathSourceStaticSetNotReadyReq struct { @@ -108,7 +113,7 @@ type pathPublisherRemoveReq struct { type pathDescribeRes struct { Path *path - Stream *gortsplib.ServerStream + Stream *stream Redirect string Err error } @@ -123,7 +128,7 @@ type pathDescribeReq struct { type pathReaderSetupPlayRes struct { Path *path - Stream *gortsplib.ServerStream + Stream *stream Err error } @@ -143,7 +148,6 @@ type pathPublisherAnnounceRes struct { type pathPublisherAnnounceReq struct { Author publisher PathName string - Tracks gortsplib.Tracks IP net.IP ValidateCredentials func(pathUser string, pathPass string) error Res chan pathPublisherAnnounceRes @@ -155,11 +159,13 @@ type pathReaderPlayReq struct { } type pathPublisherRecordRes struct { - Err error + Stream *stream + Err error } type pathPublisherRecordReq struct { Author publisher + Tracks gortsplib.Tracks Res chan pathPublisherRecordRes } @@ -173,38 +179,6 @@ type pathPublisherPauseReq struct { Res chan struct{} } -type pathReadersMap struct { - mutex sync.RWMutex - ma map[reader]struct{} -} - -func newPathReadersMap() *pathReadersMap { - return &pathReadersMap{ - ma: make(map[reader]struct{}), - } -} - -func (m *pathReadersMap) add(r reader) { - m.mutex.Lock() - defer m.mutex.Unlock() - m.ma[r] = struct{}{} -} - -func (m *pathReadersMap) remove(r reader) { - m.mutex.Lock() - defer m.mutex.Unlock() - delete(m.ma, r) -} - -func (m *pathReadersMap) forwardFrame(trackID int, streamType gortsplib.StreamType, payload []byte) { - m.mutex.RLock() - defer m.mutex.RUnlock() - - for c := range m.ma { - c.OnReaderFrame(trackID, streamType, payload) - } -} - type path struct { rtspAddress string readTimeout time.Duration @@ -223,11 +197,10 @@ type path struct { source source sourceReady bool sourceStaticWg sync.WaitGroup - stream *gortsplib.ServerStream readers map[reader]pathReaderState describeRequests []pathDescribeReq setupPlayRequests []pathReaderSetupPlayReq - nonRTSPReaders *pathReadersMap + stream *stream onDemandCmd *externalcmd.Cmd onPublishCmd *externalcmd.Cmd onDemandReadyTimer *time.Timer @@ -279,7 +252,6 @@ func newPath( ctx: ctx, ctxCancel: ctxCancel, readers: make(map[reader]pathReaderState), - nonRTSPReaders: newPathReadersMap(), onDemandReadyTimer: newEmptyTimer(), onDemandCloseTimer: newEmptyTimer(), sourceStaticSetReady: make(chan pathSourceStaticSetReadyReq), @@ -376,13 +348,16 @@ outer: } case req := <-pa.sourceStaticSetReady: - pa.stream = gortsplib.NewServerStream(req.Tracks) - pa.sourceSetReady() - close(req.Res) + pa.sourceSetReady(req.Tracks) + req.Res <- pathSourceStaticSetReadyRes{Stream: pa.stream} case req := <-pa.sourceStaticSetNotReady: if req.Source == pa.source { - pa.sourceSetNotReady() + if pa.isOnDemand() && pa.onDemandState != pathOnDemandStateInitial { + pa.onDemandCloseSource() + } else { + pa.sourceSetNotReady() + } } close(req.Res) @@ -472,10 +447,6 @@ outer: for rp, state := range pa.readers { if state == pathReaderStatePlay { atomic.AddInt64(pa.stats.CountReaders, -1) - - if _, ok := rp.(pathRTSPSession); !ok { - pa.nonRTSPReaders.remove(rp) - } } rp.Close() } @@ -485,6 +456,10 @@ outer: pa.onDemandCmd.Close() } + if pa.stream != nil { + pa.stream.close() + } + if pa.source != nil { if source, ok := pa.source.(sourceStatic); ok { source.Close() @@ -550,7 +525,11 @@ func (pa *path) onDemandCloseSource() { pa.onDemandState = pathOnDemandStateInitial if pa.hasStaticSource() { - pa.staticSourceDelete() + if pa.sourceReady { + pa.sourceSetNotReady() + } + pa.source.(sourceStatic).Close() + pa.source = nil } else { pa.Log(logger.Info, "on demand command stopped") pa.onDemandCmd.Close() @@ -563,8 +542,9 @@ func (pa *path) onDemandCloseSource() { } } -func (pa *path) sourceSetReady() { +func (pa *path) sourceSetReady(tracks gortsplib.Tracks) { pa.sourceReady = true + pa.stream = newStream(tracks) if pa.isOnDemand() { pa.onDemandReadyTimer.Stop() @@ -593,12 +573,6 @@ func (pa *path) sourceSetReady() { } func (pa *path) sourceSetNotReady() { - pa.sourceReady = false - - if pa.isOnDemand() && pa.onDemandState != pathOnDemandStateInitial { - pa.onDemandCloseSource() - } - if pa.onPublishCmd != nil { pa.onPublishCmd.Close() pa.onPublishCmd = nil @@ -608,6 +582,10 @@ func (pa *path) sourceSetNotReady() { pa.doReaderRemove(r) r.Close() } + + pa.sourceReady = false + pa.stream.close() + pa.stream = nil } func (pa *path) staticSourceCreate() { @@ -638,25 +616,12 @@ func (pa *path) staticSourceCreate() { } } -func (pa *path) staticSourceDelete() { - pa.sourceReady = false - - pa.source.(sourceStatic).Close() - pa.source = nil - - pa.stream.Close() - pa.stream = nil -} - func (pa *path) doReaderRemove(r reader) { state := pa.readers[r] if state == pathReaderStatePlay { atomic.AddInt64(pa.stats.CountReaders, -1) - - if _, ok := r.(pathRTSPSession); !ok { - pa.nonRTSPReaders.remove(r) - } + pa.stream.readerRemove(r) } delete(pa.readers, r) @@ -665,12 +630,15 @@ func (pa *path) doReaderRemove(r reader) { func (pa *path) doPublisherRemove() { if pa.sourceReady { atomic.AddInt64(pa.stats.CountPublishers, -1) - pa.sourceSetNotReady() + + if pa.isOnDemand() && pa.onDemandState != pathOnDemandStateInitial { + pa.onDemandCloseSource() + } else { + pa.sourceSetNotReady() + } } pa.source = nil - pa.stream.Close() - pa.stream = nil for r := range pa.readers { pa.doReaderRemove(r) @@ -746,7 +714,6 @@ func (pa *path) onPublisherAnnounce(req pathPublisherAnnounceReq) { } pa.source = req.Author - pa.stream = gortsplib.NewServerStream(req.Tracks) req.Res <- pathPublisherAnnounceRes{Path: pa} } @@ -759,9 +726,9 @@ func (pa *path) onPublisherRecord(req pathPublisherRecordReq) { atomic.AddInt64(pa.stats.CountPublishers, 1) - req.Author.OnPublisherAccepted(len(pa.stream.Tracks())) + req.Author.OnPublisherAccepted(len(req.Tracks)) - pa.sourceSetReady() + pa.sourceSetReady(req.Tracks) if pa.conf.RunOnPublish != "" { _, port, _ := net.SplitHostPort(pa.rtspAddress) @@ -771,13 +738,18 @@ func (pa *path) onPublisherRecord(req pathPublisherRecordReq) { }) } - req.Res <- pathPublisherRecordRes{} + req.Res <- pathPublisherRecordRes{Stream: pa.stream} } func (pa *path) onPublisherPause(req pathPublisherPauseReq) { if req.Author == pa.source && pa.sourceReady { atomic.AddInt64(pa.stats.CountPublishers, -1) - pa.sourceSetNotReady() + + if pa.isOnDemand() && pa.onDemandState != pathOnDemandStateInitial { + pa.onDemandCloseSource() + } else { + pa.sourceSetNotReady() + } } close(req.Res) } @@ -831,9 +803,7 @@ func (pa *path) onReaderPlay(req pathReaderPlayReq) { atomic.AddInt64(pa.stats.CountReaders, 1) pa.readers[req.Author] = pathReaderStatePlay - if _, ok := req.Author.(pathRTSPSession); !ok { - pa.nonRTSPReaders.add(req.Author) - } + pa.stream.readerAdd(req.Author) req.Author.OnReaderAccepted() @@ -844,21 +814,19 @@ func (pa *path) onReaderPause(req pathReaderPauseReq) { if state, ok := pa.readers[req.Author]; ok && state == pathReaderStatePlay { atomic.AddInt64(pa.stats.CountReaders, -1) pa.readers[req.Author] = pathReaderStatePrePlay - - if _, ok := req.Author.(pathRTSPSession); !ok { - pa.nonRTSPReaders.remove(req.Author) - } + pa.stream.readerRemove(req.Author) } close(req.Res) } // OnSourceStaticSetReady is called by a sourceStatic. -func (pa *path) OnSourceStaticSetReady(req pathSourceStaticSetReadyReq) { - req.Res = make(chan struct{}) +func (pa *path) OnSourceStaticSetReady(req pathSourceStaticSetReadyReq) pathSourceStaticSetReadyRes { + req.Res = make(chan pathSourceStaticSetReadyRes) select { case pa.sourceStaticSetReady <- req: - <-req.Res + return <-req.Res case <-pa.ctx.Done(): + return pathSourceStaticSetReadyRes{Err: fmt.Errorf("terminated")} } } @@ -963,15 +931,6 @@ func (pa *path) OnReaderPause(req pathReaderPauseReq) { } } -// OnSourceFrame is called by a source. -func (pa *path) OnSourceFrame(trackID int, streamType gortsplib.StreamType, payload []byte) { - // forward to RTSP readers - pa.stream.WriteFrame(trackID, streamType, payload) - - // forward to non-RTSP readers - pa.nonRTSPReaders.forwardFrame(trackID, streamType, payload) -} - // OnAPIPathsList is called by api. func (pa *path) OnAPIPathsList(req apiPathsListReq2) apiPathsListRes2 { req.Res = make(chan apiPathsListRes2) diff --git a/internal/core/rtmp_conn.go b/internal/core/rtmp_conn.go index 70af1350..d83b2847 100644 --- a/internal/core/rtmp_conn.go +++ b/internal/core/rtmp_conn.go @@ -231,7 +231,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error { var audioClockRate int var aacDecoder *rtpaac.Decoder - for i, t := range res.Stream.Tracks() { + for i, t := range res.Stream.tracks() { if t.IsH264() { if videoTrack != nil { return fmt.Errorf("can't read track %d with RTMP: too many tracks", i+1) @@ -398,7 +398,6 @@ func (c *rtmpConn) runPublish(ctx context.Context) error { res := c.pathManager.OnPublisherAnnounce(pathPublisherAnnounceReq{ Author: c, PathName: pathName, - Tracks: tracks, IP: c.ip(), ValidateCredentials: func(pathUser string, pathPass string) error { return c.validateCredentials(pathUser, pathPass, query) @@ -423,17 +422,20 @@ func (c *rtmpConn) runPublish(ctx context.Context) error { // disable write deadline c.conn.NetConn().SetWriteDeadline(time.Time{}) - rres := c.path.OnPublisherRecord(pathPublisherRecordReq{Author: c}) + rres := c.path.OnPublisherRecord(pathPublisherRecordReq{ + Author: c, + Tracks: tracks, + }) if rres.Err != nil { return rres.Err } - rtcpSenders := rtcpsenderset.New(tracks, c.path.OnSourceFrame) + rtcpSenders := rtcpsenderset.New(tracks, rres.Stream.onFrame) defer rtcpSenders.Close() onFrame := func(trackID int, payload []byte) { rtcpSenders.OnFrame(trackID, gortsplib.StreamTypeRTP, payload) - c.path.OnSourceFrame(trackID, gortsplib.StreamTypeRTP, payload) + rres.Stream.onFrame(trackID, gortsplib.StreamTypeRTP, payload) } for { diff --git a/internal/core/rtmp_source.go b/internal/core/rtmp_source.go index 39d7f888..6c6a921b 100644 --- a/internal/core/rtmp_source.go +++ b/internal/core/rtmp_source.go @@ -24,9 +24,8 @@ const ( type rtmpSourceParent interface { Log(logger.Level, string, ...interface{}) - OnSourceStaticSetReady(req pathSourceStaticSetReadyReq) + OnSourceStaticSetReady(req pathSourceStaticSetReadyReq) pathSourceStaticSetReadyRes OnSourceStaticSetNotReady(req pathSourceStaticSetNotReadyReq) - OnSourceFrame(int, gortsplib.StreamType, []byte) } type rtmpSource struct { @@ -162,20 +161,23 @@ func (s *rtmpSource) runInner() bool { s.log(logger.Info, "ready") - s.parent.OnSourceStaticSetReady(pathSourceStaticSetReadyReq{ + res := s.parent.OnSourceStaticSetReady(pathSourceStaticSetReadyReq{ Tracks: tracks, }) + if res.Err != nil { + return err + } defer func() { s.parent.OnSourceStaticSetNotReady(pathSourceStaticSetNotReadyReq{Source: s}) }() - rtcpSenders := rtcpsenderset.New(tracks, s.parent.OnSourceFrame) + rtcpSenders := rtcpsenderset.New(tracks, res.Stream.onFrame) defer rtcpSenders.Close() onFrame := func(trackID int, payload []byte) { rtcpSenders.OnFrame(trackID, gortsplib.StreamTypeRTP, payload) - s.parent.OnSourceFrame(trackID, gortsplib.StreamTypeRTP, payload) + res.Stream.onFrame(trackID, gortsplib.StreamTypeRTP, payload) } for { diff --git a/internal/core/rtsp_conn.go b/internal/core/rtsp_conn.go index e1381239..77d52089 100644 --- a/internal/core/rtsp_conn.go +++ b/internal/core/rtsp_conn.go @@ -233,5 +233,5 @@ func (c *rtspConn) OnDescribe(ctx *gortsplib.ServerHandlerOnDescribeCtx) (*base. return &base.Response{ StatusCode: base.StatusOK, - }, res.Stream, nil + }, res.Stream.rtspStream, nil } diff --git a/internal/core/rtsp_session.go b/internal/core/rtsp_session.go index 08a89f86..f194f86b 100644 --- a/internal/core/rtsp_session.go +++ b/internal/core/rtsp_session.go @@ -31,9 +31,11 @@ type rtspSession struct { pathManager *pathManager parent rtspSessionParent - path *path - setuppedTracks map[int]*gortsplib.Track // read - onReadCmd *externalcmd.Cmd // read + path *path + setuppedTracks map[int]*gortsplib.Track // read + onReadCmd *externalcmd.Cmd // read + announcedTracks gortsplib.Tracks // publish + stream *stream // publish } func newRTSPSession( @@ -114,7 +116,6 @@ func (s *rtspSession) OnAnnounce(c *rtspConn, ctx *gortsplib.ServerHandlerOnAnno res := s.pathManager.OnPublisherAnnounce(pathPublisherAnnounceReq{ Author: s, PathName: ctx.Path, - Tracks: ctx.Tracks, IP: ctx.Conn.NetConn().RemoteAddr().(*net.TCPAddr).IP, ValidateCredentials: func(pathUser string, pathPass string) error { return c.validateCredentials(pathUser, pathPass, ctx.Path, ctx.Req) @@ -140,6 +141,7 @@ func (s *rtspSession) OnAnnounce(c *rtspConn, ctx *gortsplib.ServerHandlerOnAnno } s.path = res.Path + s.announcedTracks = ctx.Tracks return &base.Response{ StatusCode: base.StatusOK, @@ -204,7 +206,7 @@ func (s *rtspSession) OnSetup(c *rtspConn, ctx *gortsplib.ServerHandlerOnSetupCt s.path = res.Path - if ctx.TrackID >= len(res.Stream.Tracks()) { + if ctx.TrackID >= len(res.Stream.tracks()) { return &base.Response{ StatusCode: base.StatusBadRequest, }, nil, fmt.Errorf("track %d does not exist", ctx.TrackID) @@ -213,11 +215,11 @@ func (s *rtspSession) OnSetup(c *rtspConn, ctx *gortsplib.ServerHandlerOnSetupCt if s.setuppedTracks == nil { s.setuppedTracks = make(map[int]*gortsplib.Track) } - s.setuppedTracks[ctx.TrackID] = res.Stream.Tracks()[ctx.TrackID] + s.setuppedTracks[ctx.TrackID] = res.Stream.tracks()[ctx.TrackID] return &base.Response{ StatusCode: base.StatusOK, - }, res.Stream, nil + }, res.Stream.rtspStream, nil default: // record return &base.Response{ @@ -250,13 +252,18 @@ func (s *rtspSession) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Respo // OnRecord is called by rtspServer. func (s *rtspSession) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*base.Response, error) { - res := s.path.OnPublisherRecord(pathPublisherRecordReq{Author: s}) + res := s.path.OnPublisherRecord(pathPublisherRecordReq{ + Author: s, + Tracks: s.announcedTracks, + }) if res.Err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, }, res.Err } + s.stream = res.Stream + return &base.Response{ StatusCode: base.StatusOK, }, nil @@ -338,5 +345,5 @@ func (s *rtspSession) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) { return } - s.path.OnSourceFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) + s.stream.onFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) } diff --git a/internal/core/rtsp_source.go b/internal/core/rtsp_source.go index 100b2fee..34dc3980 100644 --- a/internal/core/rtsp_source.go +++ b/internal/core/rtsp_source.go @@ -23,9 +23,8 @@ const ( type rtspSourceParent interface { Log(logger.Level, string, ...interface{}) - OnSourceStaticSetReady(req pathSourceStaticSetReadyReq) + OnSourceStaticSetReady(req pathSourceStaticSetReadyReq) pathSourceStaticSetReadyRes OnSourceStaticSetNotReady(req pathSourceStaticSetNotReadyReq) - OnSourceFrame(int, gortsplib.StreamType, []byte) } type rtspSource struct { @@ -181,9 +180,13 @@ func (s *rtspSource) runInner() bool { s.log(logger.Info, "ready") - s.parent.OnSourceStaticSetReady(pathSourceStaticSetReadyReq{ + res := s.parent.OnSourceStaticSetReady(pathSourceStaticSetReadyReq{ Tracks: conn.Tracks(), }) + if res.Err != nil { + s.log(logger.Info, "ERR: %s", err) + return true + } defer func() { s.parent.OnSourceStaticSetNotReady(pathSourceStaticSetNotReadyReq{Source: s}) @@ -192,7 +195,7 @@ func (s *rtspSource) runInner() bool { readErr := make(chan error) go func() { readErr <- conn.ReadFrames(func(trackID int, streamType gortsplib.StreamType, payload []byte) { - s.parent.OnSourceFrame(trackID, streamType, payload) + res.Stream.onFrame(trackID, streamType, payload) }) }() diff --git a/internal/core/stream.go b/internal/core/stream.go new file mode 100644 index 00000000..c9712314 --- /dev/null +++ b/internal/core/stream.go @@ -0,0 +1,87 @@ +package core + +import ( + "sync" + + "github.com/aler9/gortsplib" +) + +type streamNonRTSPReadersMap struct { + mutex sync.RWMutex + ma map[reader]struct{} +} + +func newStreamNonRTSPReadersMap() *streamNonRTSPReadersMap { + return &streamNonRTSPReadersMap{ + ma: make(map[reader]struct{}), + } +} + +func (m *streamNonRTSPReadersMap) close() { + m.mutex.Lock() + defer m.mutex.Unlock() + m.ma = nil +} + +func (m *streamNonRTSPReadersMap) add(r reader) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.ma[r] = struct{}{} +} + +func (m *streamNonRTSPReadersMap) remove(r reader) { + m.mutex.Lock() + defer m.mutex.Unlock() + delete(m.ma, r) +} + +func (m *streamNonRTSPReadersMap) forwardFrame(trackID int, streamType gortsplib.StreamType, payload []byte) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + for c := range m.ma { + c.OnReaderFrame(trackID, streamType, payload) + } +} + +type stream struct { + nonRTSPReaders *streamNonRTSPReadersMap + rtspStream *gortsplib.ServerStream +} + +func newStream(tracks gortsplib.Tracks) *stream { + s := &stream{ + nonRTSPReaders: newStreamNonRTSPReadersMap(), + rtspStream: gortsplib.NewServerStream(tracks), + } + return s +} + +func (s *stream) close() { + s.nonRTSPReaders.close() + s.rtspStream.Close() +} + +func (s *stream) tracks() gortsplib.Tracks { + return s.rtspStream.Tracks() +} + +func (s *stream) readerAdd(r reader) { + if _, ok := r.(pathRTSPSession); !ok { + s.nonRTSPReaders.add(r) + } +} + +func (s *stream) readerRemove(r reader) { + if _, ok := r.(pathRTSPSession); !ok { + s.nonRTSPReaders.remove(r) + } +} + +func (s *stream) onFrame(trackID int, streamType gortsplib.StreamType, payload []byte) { + // forward to RTSP readers + s.rtspStream.WriteFrame(trackID, streamType, payload) + + // forward to non-RTSP readers + s.nonRTSPReaders.forwardFrame(trackID, streamType, payload) +}