From a14246d7769ac3a1b321cd540138417adc12546b Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Tue, 16 May 2023 15:59:37 +0200 Subject: [PATCH] webrtc: support publishing with WHIP and reading with WHEP (#1800) --- README.md | 38 +- apidocs/openapi.yaml | 10 +- internal/core/api.go | 113 +-- internal/core/api_test.go | 15 +- internal/core/core.go | 36 +- internal/core/hls_http_server.go | 148 ++++ internal/core/hls_manager.go | 314 ++++++++ ...hls_server_test.go => hls_manager_test.go} | 4 +- internal/core/hls_muxer.go | 26 +- internal/core/hls_server.go | 398 ---------- internal/core/hls_source_test.go | 16 +- internal/core/http_requestpool.go | 25 - internal/core/metrics.go | 50 +- internal/core/metrics_test.go | 12 +- internal/core/path_manager.go | 26 +- internal/core/rtmp_server.go | 53 +- internal/core/rtsp_server.go | 27 +- internal/core/webrtc_candidate_reader.go | 73 -- internal/core/webrtc_conn.go | 694 ------------------ internal/core/webrtc_http_server.go | 364 +++++++++ internal/core/webrtc_incoming_track.go | 1 + internal/core/webrtc_manager.go | 508 +++++++++++++ internal/core/webrtc_manager_test.go | 353 +++++++++ internal/core/webrtc_outgoing_track.go | 88 ++- internal/core/webrtc_pc.go | 7 + internal/core/webrtc_publish_index.html | 286 +++++--- internal/core/webrtc_read_index.html | 241 ++++-- internal/core/webrtc_server.go | 522 ------------- internal/core/webrtc_server_test.go | 235 ------ internal/core/webrtc_session.go | 592 +++++++++++++++ 30 files changed, 2937 insertions(+), 2338 deletions(-) create mode 100644 internal/core/hls_http_server.go create mode 100644 internal/core/hls_manager.go rename internal/core/{hls_server_test.go => hls_manager_test.go} (98%) delete mode 100644 internal/core/hls_server.go delete mode 100644 internal/core/http_requestpool.go delete mode 100644 internal/core/webrtc_candidate_reader.go delete mode 100644 internal/core/webrtc_conn.go create mode 100644 internal/core/webrtc_http_server.go create mode 100644 internal/core/webrtc_manager.go create mode 100644 internal/core/webrtc_manager_test.go delete mode 100644 internal/core/webrtc_server.go delete mode 100644 internal/core/webrtc_server_test.go create mode 100644 internal/core/webrtc_session.go diff --git a/README.md b/README.md index abbd442e..353f60ba 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Live streams can be published to the server with: |RTMP servers and cameras|RTMP, RTMPS, Enhanced RTMP|H264|MPEG-4 Audio (AAC), MPEG-2 Audio (MP3)| |HLS servers and cameras|Low-Latency HLS, MP4-based HLS, legacy HLS|H265, H264|Opus, MPEG-4 Audio (AAC)| |UDP/MPEG-TS streams|Unicast, broadcast, multicast|H265, H264|Opus, MPEG-4 Audio (AAC)| -|WebRTC||AV1, VP9, VP8, H264|Opus, G722, G711| +|WebRTC|WHIP|AV1, VP9, VP8, H264|Opus, G722, G711| |Raspberry Pi Cameras||H264|| And can be read from the server with: @@ -26,7 +26,7 @@ And can be read from the server with: |RTSP|UDP, UDP-Multicast, TCP, RTSPS|AV1, VP9, VP8, H265, H264, MPEG-4 Video (H263, Xvid), MPEG-2 Video, M-JPEG and any RTP-compatible codec|Opus, MPEG-4 Audio (AAC), MPEG-2 Audio (MP3), G722, G711, LPCM and any RTP-compatible codec| |RTMP|RTMP, RTMPS, Enhanced RTMP|H264|MPEG-4 Audio (AAC), MPEG-2 Audio (MP3)| |HLS|Low-Latency HLS, MP4-based HLS, legacy HLS|H265, H264|Opus, MPEG-4 Audio (AAC)| -|WebRTC||AV1, VP9, VP8, H264|Opus, G722, G711| +|WebRTC|WHEP|AV1, VP9, VP8, H264|Opus, G722, G711| Features: @@ -546,9 +546,9 @@ rtmp_conns_bytes_received{id="[id]",state="[state]"} 1234 rtmp_conns_bytes_sent{id="[id]",state="[state]"} 187 # metrics of every WebRTC connection -webrtc_conns{id="[id]"} 1 -webrtc_conns_bytes_received{id="[id]",state="[state]"} 1234 -webrtc_conns_bytes_sent{id="[id]",state="[state]"} 187 +webrtc_sessions{id="[id]"} 1 +webrtc_sessions_bytes_received{id="[id]",state="[state]"} 1234 +webrtc_sessions_bytes_sent{id="[id]",state="[state]"} 187 ``` ### pprof @@ -1209,12 +1209,26 @@ For more advanced options, you can create and serve a custom web page by startin ## Standards -* [RTSP/RTP/RTCP standards](https://github.com/bluenviron/gortsplib#standards) -* [HLS standards](https://github.com/bluenviron/gohlslib#standards) -* [Codec standards](https://github.com/bluenviron/mediacommon#standards) -* [RTMP](https://rtmp.veriskope.com/pdf/rtmp_specification_1.0.pdf) -* [Enhanced RTMP](https://raw.githubusercontent.com/veovera/enhanced-rtmp/main/enhanced-rtmp-v1.pdf) -* [Golang project layout](https://github.com/golang-standards/project-layout) +* RTSP + * [RTSP/RTP/RTCP standards](https://github.com/bluenviron/gortsplib#standards) + +* HLS + * [HLS standards](https://github.com/bluenviron/gohlslib#standards) + +* RTMP + * [RTMP](https://rtmp.veriskope.com/pdf/rtmp_specification_1.0.pdf) + * [Enhanced RTMP](https://raw.githubusercontent.com/veovera/enhanced-rtmp/main/enhanced-rtmp-v1.pdf) + +* WebRTC + * [WebRTC: Real-Time Communication in Browsers](https://www.w3.org/TR/webrtc/) + * [WebRTC Ingestion Protocol (WHIP)](https://datatracker.ietf.org/doc/draft-ietf-wish-whip/) + * [WebRTC HTTP Egress Protocol (WHEP)](https://datatracker.ietf.org/doc/draft-murillo-whep/) + +* Video and audio codecs + * [Codec standards](https://github.com/bluenviron/mediacommon#standards) + +* Other + * [Golang project layout](https://github.com/golang-standards/project-layout) ## Links @@ -1222,10 +1236,10 @@ Related projects * [gortsplib (RTSP library used internally)](https://github.com/bluenviron/gortsplib) * [gohlslib (HLS library used internally)](https://github.com/bluenviron/gohlslib) +* [pion/webrtc (WebRTC library used internally)](https://github.com/pion/webrtc) * [pion/sdp (SDP library used internally)](https://github.com/pion/sdp) * [pion/rtp (RTP library used internally)](https://github.com/pion/rtp) * [pion/rtcp (RTCP library used internally)](https://github.com/pion/rtcp) -* [pion/webrtc (WebRTC library used internally)](https://github.com/pion/webrtc) * [notedit/rtmp (RTMP library used internally)](https://github.com/notedit/rtmp) * [go-astits (MPEG-TS library used internally)](https://github.com/asticode/go-astits) * [go-mp4 (MP4 library used internally)](https://github.com/abema/go-mp4) diff --git a/apidocs/openapi.yaml b/apidocs/openapi.yaml index b6e3e39d..84421c73 100644 --- a/apidocs/openapi.yaml +++ b/apidocs/openapi.yaml @@ -348,7 +348,7 @@ components: - rtspsSession - redirect - udpSource - - webRTCConn + - webRTCSession id: type: string @@ -807,9 +807,9 @@ paths: '500': description: internal server error. - /v1/webrtcconns/list: + /v1/webrtcsessions/list: get: - operationId: webrtcConnsList + operationId: webrtcSessionsList summary: returns all WebRTC connections. description: '' responses: @@ -824,9 +824,9 @@ paths: '500': description: internal server error. - /v1/webrtcconns/kick/{id}: + /v1/webrtcsessions/kick/{id}: post: - operationId: webrtcConnsKick + operationId: webrtcSessionsKick summary: kicks out a WebRTC connection from the server. description: '' parameters: diff --git a/internal/core/api.go b/internal/core/api.go index c363618a..6babf981 100644 --- a/internal/core/api.go +++ b/internal/core/api.go @@ -11,6 +11,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/google/uuid" "github.com/aler9/mediamtx/internal/conf" "github.com/aler9/mediamtx/internal/logger" @@ -82,19 +83,19 @@ type apiPathManager interface { apiPathsList() pathAPIPathsListRes } -type apiHLSServer interface { - apiMuxersList() hlsServerAPIMuxersListRes +type apiHLSManager interface { + apiMuxersList() hlsManagerAPIMuxersListRes } type apiRTSPServer interface { apiConnsList() rtspServerAPIConnsListRes apiSessionsList() rtspServerAPISessionsListRes - apiSessionsKick(string) rtspServerAPISessionsKickRes + apiSessionsKick(uuid.UUID) rtspServerAPISessionsKickRes } type apiRTMPServer interface { apiConnsList() rtmpServerAPIConnsListRes - apiConnsKick(id string) rtmpServerAPIConnsKickRes + apiConnsKick(uuid.UUID) rtmpServerAPIConnsKickRes } type apiParent interface { @@ -102,21 +103,21 @@ type apiParent interface { apiConfigSet(conf *conf.Conf) } -type apiWebRTCServer interface { - apiConnsList() webRTCServerAPIConnsListRes - apiConnsKick(id string) webRTCServerAPIConnsKickRes +type apiWebRTCManager interface { + apiSessionsList() webRTCManagerAPISessionsListRes + apiSessionsKick(uuid.UUID) webRTCManagerAPISessionsKickRes } type api struct { - conf *conf.Conf - pathManager apiPathManager - rtspServer apiRTSPServer - rtspsServer apiRTSPServer - rtmpServer apiRTMPServer - rtmpsServer apiRTMPServer - hlsServer apiHLSServer - webRTCServer apiWebRTCServer - parent apiParent + conf *conf.Conf + pathManager apiPathManager + rtspServer apiRTSPServer + rtspsServer apiRTSPServer + rtmpServer apiRTMPServer + rtmpsServer apiRTMPServer + hlsManager apiHLSManager + webRTCManager apiWebRTCManager + parent apiParent ln net.Listener httpServer *http.Server @@ -132,8 +133,8 @@ func newAPI( rtspsServer apiRTSPServer, rtmpServer apiRTMPServer, rtmpsServer apiRTMPServer, - hlsServer apiHLSServer, - webRTCServer apiWebRTCServer, + hlsManager apiHLSManager, + webRTCManager apiWebRTCManager, parent apiParent, ) (*api, error) { ln, err := net.Listen(restrictNetwork("tcp", address)) @@ -142,16 +143,16 @@ func newAPI( } a := &api{ - conf: conf, - pathManager: pathManager, - rtspServer: rtspServer, - rtspsServer: rtspsServer, - rtmpServer: rtmpServer, - rtmpsServer: rtmpsServer, - hlsServer: hlsServer, - webRTCServer: webRTCServer, - parent: parent, - ln: ln, + conf: conf, + pathManager: pathManager, + rtspServer: rtspServer, + rtspsServer: rtspsServer, + rtmpServer: rtmpServer, + rtmpsServer: rtmpsServer, + hlsManager: hlsManager, + webRTCManager: webRTCManager, + parent: parent, + ln: ln, } router := gin.New() @@ -167,7 +168,7 @@ func newAPI( group.POST("/v1/config/paths/edit/*name", a.onConfigPathsEdit) group.POST("/v1/config/paths/remove/*name", a.onConfigPathsDelete) - if !interfaceIsEmpty(a.hlsServer) { + if !interfaceIsEmpty(a.hlsManager) { group.GET("/v1/hlsmuxers/list", a.onHLSMuxersList) } @@ -195,9 +196,9 @@ func newAPI( group.POST("/v1/rtmpsconns/kick/:id", a.onRTMPSConnsKick) } - if !interfaceIsEmpty(a.webRTCServer) { - group.GET("/v1/webrtcconns/list", a.onWebRTCConnsList) - group.POST("/v1/webrtcconns/kick/:id", a.onWebRTCConnsKick) + if !interfaceIsEmpty(a.webRTCManager) { + group.GET("/v1/webrtcsessions/list", a.onWebRTCSessionsList) + group.POST("/v1/webrtcsessions/kick/:id", a.onWebRTCSessionsKick) } a.httpServer = &http.Server{ @@ -412,9 +413,13 @@ func (a *api) onRTSPSessionsList(ctx *gin.Context) { } func (a *api) onRTSPSessionsKick(ctx *gin.Context) { - id := ctx.Param("id") + uuid, err := uuid.Parse(ctx.Param("id")) + if err != nil { + ctx.AbortWithStatus(http.StatusBadRequest) + return + } - res := a.rtspServer.apiSessionsKick(id) + res := a.rtspServer.apiSessionsKick(uuid) if res.err != nil { return } @@ -443,9 +448,13 @@ func (a *api) onRTSPSSessionsList(ctx *gin.Context) { } func (a *api) onRTSPSSessionsKick(ctx *gin.Context) { - id := ctx.Param("id") + uuid, err := uuid.Parse(ctx.Param("id")) + if err != nil { + ctx.AbortWithStatus(http.StatusBadRequest) + return + } - res := a.rtspsServer.apiSessionsKick(id) + res := a.rtspsServer.apiSessionsKick(uuid) if res.err != nil { return } @@ -464,9 +473,13 @@ func (a *api) onRTMPConnsList(ctx *gin.Context) { } func (a *api) onRTMPConnsKick(ctx *gin.Context) { - id := ctx.Param("id") + uuid, err := uuid.Parse(ctx.Param("id")) + if err != nil { + ctx.AbortWithStatus(http.StatusBadRequest) + return + } - res := a.rtmpServer.apiConnsKick(id) + res := a.rtmpServer.apiConnsKick(uuid) if res.err != nil { return } @@ -485,9 +498,13 @@ func (a *api) onRTMPSConnsList(ctx *gin.Context) { } func (a *api) onRTMPSConnsKick(ctx *gin.Context) { - id := ctx.Param("id") + uuid, err := uuid.Parse(ctx.Param("id")) + if err != nil { + ctx.AbortWithStatus(http.StatusBadRequest) + return + } - res := a.rtmpsServer.apiConnsKick(id) + res := a.rtmpsServer.apiConnsKick(uuid) if res.err != nil { return } @@ -496,7 +513,7 @@ func (a *api) onRTMPSConnsKick(ctx *gin.Context) { } func (a *api) onHLSMuxersList(ctx *gin.Context) { - res := a.hlsServer.apiMuxersList() + res := a.hlsManager.apiMuxersList() if res.err != nil { ctx.AbortWithStatus(http.StatusInternalServerError) return @@ -505,8 +522,8 @@ func (a *api) onHLSMuxersList(ctx *gin.Context) { ctx.JSON(http.StatusOK, res.data) } -func (a *api) onWebRTCConnsList(ctx *gin.Context) { - res := a.webRTCServer.apiConnsList() +func (a *api) onWebRTCSessionsList(ctx *gin.Context) { + res := a.webRTCManager.apiSessionsList() if res.err != nil { ctx.AbortWithStatus(http.StatusInternalServerError) return @@ -515,10 +532,14 @@ func (a *api) onWebRTCConnsList(ctx *gin.Context) { ctx.JSON(http.StatusOK, res.data) } -func (a *api) onWebRTCConnsKick(ctx *gin.Context) { - id := ctx.Param("id") +func (a *api) onWebRTCSessionsKick(ctx *gin.Context) { + uuid, err := uuid.Parse(ctx.Param("id")) + if err != nil { + ctx.AbortWithStatus(http.StatusBadRequest) + return + } - res := a.webRTCServer.apiConnsKick(id) + res := a.webRTCManager.apiSessionsKick(uuid) if res.err != nil { return } diff --git a/internal/core/api_test.go b/internal/core/api_test.go index cf3b5aac..d4c79999 100644 --- a/internal/core/api_test.go +++ b/internal/core/api_test.go @@ -545,8 +545,7 @@ func TestAPIProtocolSpecificList(t *testing.T) { require.NoError(t, err) defer source.Close() - c, err := newWebRTCTestClient("ws://localhost:8889/mypath/ws") - require.NoError(t, err) + c := newWebRTCTestClient(t, "http://localhost:8889/mypath/whep", false) defer c.close() time.Sleep(500 * time.Millisecond) @@ -563,7 +562,7 @@ func TestAPIProtocolSpecificList(t *testing.T) { Payload: []byte{0x01, 0x02, 0x03, 0x04}, }) - <-c.track + <-c.incomingTrack } switch ca { @@ -639,7 +638,7 @@ func TestAPIProtocolSpecificList(t *testing.T) { var out struct { Items map[string]item `json:"items"` } - err = httpRequest(http.MethodGet, "http://localhost:9997/v1/webrtcconns/list", nil, &out) + err = httpRequest(http.MethodGet, "http://localhost:9997/v1/webrtcsessions/list", nil, &out) require.NoError(t, err) var firstID string @@ -667,6 +666,7 @@ func TestAPIKick(t *testing.T) { "rtsp", "rtsps", "rtmp", + "webrtc", } { t.Run(ca, func(t *testing.T) { conf := "api: yes\n" @@ -720,6 +720,10 @@ func TestAPIKick(t *testing.T) { err = conn.WriteTracks(testFormatH264, nil) require.NoError(t, err) + + case "webrtc": + c := newWebRTCTestClient(t, "http://localhost:8889/mypath/whip", true) + defer c.close() } var pa string @@ -732,6 +736,9 @@ func TestAPIKick(t *testing.T) { case "rtmp": pa = "rtmpconns" + + case "webrtc": + pa = "webrtcsessions" } var out1 struct { diff --git a/internal/core/core.go b/internal/core/core.go index 1a7a4472..d7df2f3c 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -38,8 +38,8 @@ type Core struct { rtspsServer *rtspServer rtmpServer *rtmpServer rtmpsServer *rtmpServer - hlsServer *hlsServer - webRTCServer *webRTCServer + hlsManager *hlsManager + webRTCManager *webRTCManager api *api confWatcher *confwatcher.ConfWatcher @@ -385,8 +385,8 @@ func (p *Core) createResources(initial bool) error { } if !p.conf.HLSDisable { - if p.hlsServer == nil { - p.hlsServer, err = newHLSServer( + if p.hlsManager == nil { + p.hlsManager, err = newHLSManager( p.ctx, p.conf.HLSAddress, p.conf.HLSEncryption, @@ -415,8 +415,8 @@ func (p *Core) createResources(initial bool) error { } if !p.conf.WebRTCDisable { - if p.webRTCServer == nil { - p.webRTCServer, err = newWebRTCServer( + if p.webRTCManager == nil { + p.webRTCManager, err = newWebRTCManager( p.ctx, p.conf.WebRTCAddress, p.conf.WebRTCEncryption, @@ -451,8 +451,8 @@ func (p *Core) createResources(initial bool) error { p.rtspsServer, p.rtmpServer, p.rtmpsServer, - p.hlsServer, - p.webRTCServer, + p.hlsManager, + p.webRTCManager, p, ) if err != nil { @@ -565,7 +565,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) { closeMetrics || closePathManager - closeHLSServer := newConf == nil || + closeHLSManager := newConf == nil || newConf.HLSDisable != p.conf.HLSDisable || newConf.HLSAddress != p.conf.HLSAddress || newConf.HLSEncryption != p.conf.HLSEncryption || @@ -586,7 +586,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) { closePathManager || closeMetrics - closeWebRTCServer := newConf == nil || + closeWebRTCManager := newConf == nil || newConf.WebRTCDisable != p.conf.WebRTCDisable || newConf.WebRTCAddress != p.conf.WebRTCAddress || newConf.WebRTCEncryption != p.conf.WebRTCEncryption || @@ -611,8 +611,8 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) { closeRTSPServer || closeRTSPSServer || closeRTMPServer || - closeHLSServer || - closeWebRTCServer + closeHLSManager || + closeWebRTCManager if newConf == nil && p.confWatcher != nil { p.confWatcher.Close() @@ -643,14 +643,14 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) { p.pathManager = nil } - if closeWebRTCServer && p.webRTCServer != nil { - p.webRTCServer.close() - p.webRTCServer = nil + if closeWebRTCManager && p.webRTCManager != nil { + p.webRTCManager.close() + p.webRTCManager = nil } - if closeHLSServer && p.hlsServer != nil { - p.hlsServer.close() - p.hlsServer = nil + if closeHLSManager && p.hlsManager != nil { + p.hlsManager.close() + p.hlsManager = nil } if closeRTMPSServer && p.rtmpsServer != nil { diff --git a/internal/core/hls_http_server.go b/internal/core/hls_http_server.go new file mode 100644 index 00000000..d4b571c0 --- /dev/null +++ b/internal/core/hls_http_server.go @@ -0,0 +1,148 @@ +package core + +import ( + "context" + "crypto/tls" + "log" + "net" + "net/http" + gopath "path" + "strings" + "time" + + "github.com/gin-gonic/gin" + + "github.com/aler9/mediamtx/internal/conf" + "github.com/aler9/mediamtx/internal/logger" +) + +type hlsHTTPServerParent interface { + logger.Writer + handleRequest(req hlsMuxerHandleRequestReq) +} + +type hlsHTTPServer struct { + allowOrigin string + parent hlsHTTPServerParent + + ln net.Listener + inner *http.Server +} + +func newHLSHTTPServer( + address string, + encryption bool, + serverKey string, + serverCert string, + allowOrigin string, + trustedProxies conf.IPsOrCIDRs, + readTimeout conf.StringDuration, + parent hlsHTTPServerParent, +) (*hlsHTTPServer, error) { + ln, err := net.Listen(restrictNetwork("tcp", address)) + if err != nil { + return nil, err + } + + var tlsConfig *tls.Config + if encryption { + crt, err := tls.LoadX509KeyPair(serverCert, serverKey) + if err != nil { + ln.Close() + return nil, err + } + + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{crt}, + } + } + + s := &hlsHTTPServer{ + allowOrigin: allowOrigin, + parent: parent, + ln: ln, + } + + router := gin.New() + httpSetTrustedProxies(router, trustedProxies) + + router.NoRoute(httpLoggerMiddleware(s), httpServerHeaderMiddleware, s.onRequest) + + s.inner = &http.Server{ + Handler: router, + TLSConfig: tlsConfig, + ReadHeaderTimeout: time.Duration(readTimeout), + ErrorLog: log.New(&nilWriter{}, "", 0), + } + + if tlsConfig != nil { + go s.inner.ServeTLS(s.ln, "", "") + } else { + go s.inner.Serve(s.ln) + } + + return s, nil +} + +func (s *hlsHTTPServer) Log(level logger.Level, format string, args ...interface{}) { + s.parent.Log(level, format, args...) +} + +func (s *hlsHTTPServer) close() { + s.inner.Shutdown(context.Background()) + s.ln.Close() // in case Shutdown() is called before Serve() +} + +func (s *hlsHTTPServer) onRequest(ctx *gin.Context) { + ctx.Writer.Header().Set("Access-Control-Allow-Origin", s.allowOrigin) + ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + + switch ctx.Request.Method { + case http.MethodGet: + + case http.MethodOptions: + ctx.Writer.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + ctx.Writer.Header().Set("Access-Control-Allow-Headers", ctx.Request.Header.Get("Access-Control-Request-Headers")) + ctx.Writer.WriteHeader(http.StatusOK) + return + + default: + return + } + + // remove leading prefix + pa := ctx.Request.URL.Path[1:] + + switch pa { + case "", "favicon.ico": + return + } + + dir, fname := func() (string, string) { + if strings.HasSuffix(pa, ".m3u8") || + strings.HasSuffix(pa, ".ts") || + strings.HasSuffix(pa, ".mp4") || + strings.HasSuffix(pa, ".mp") { + return gopath.Dir(pa), gopath.Base(pa) + } + return pa, "" + }() + + if fname == "" && !strings.HasSuffix(dir, "/") { + ctx.Writer.Header().Set("Location", "/"+dir+"/") + ctx.Writer.WriteHeader(http.StatusMovedPermanently) + return + } + + if strings.HasSuffix(fname, ".mp") { + fname += "4" + } + + dir = strings.TrimSuffix(dir, "/") + + s.parent.handleRequest(hlsMuxerHandleRequestReq{ + path: dir, + file: fname, + ctx: ctx, + }) +} diff --git a/internal/core/hls_manager.go b/internal/core/hls_manager.go new file mode 100644 index 00000000..921828e5 --- /dev/null +++ b/internal/core/hls_manager.go @@ -0,0 +1,314 @@ +package core + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/aler9/mediamtx/internal/conf" + "github.com/aler9/mediamtx/internal/logger" +) + +type nilWriter struct{} + +func (nilWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +type hlsManagerAPIMuxersListItem struct { + Created time.Time `json:"created"` + LastRequest time.Time `json:"lastRequest"` + BytesSent uint64 `json:"bytesSent"` +} + +type hlsManagerAPIMuxersListData struct { + Items map[string]hlsManagerAPIMuxersListItem `json:"items"` +} + +type hlsManagerAPIMuxersListRes struct { + data *hlsManagerAPIMuxersListData + muxers map[string]*hlsMuxer + err error +} + +type hlsManagerAPIMuxersListReq struct { + res chan hlsManagerAPIMuxersListRes +} + +type hlsManagerAPIMuxersListSubReq struct { + data *hlsManagerAPIMuxersListData + res chan struct{} +} + +type hlsManagerParent interface { + logger.Writer +} + +type hlsManager struct { + externalAuthenticationURL string + alwaysRemux bool + variant conf.HLSVariant + segmentCount int + segmentDuration conf.StringDuration + partDuration conf.StringDuration + segmentMaxSize conf.StringSize + directory string + readBufferCount int + pathManager *pathManager + metrics *metrics + parent hlsManagerParent + + ctx context.Context + ctxCancel func() + wg sync.WaitGroup + httpServer *hlsHTTPServer + muxers map[string]*hlsMuxer + + // in + chPathSourceReady chan *path + chPathSourceNotReady chan *path + chHandleRequest chan hlsMuxerHandleRequestReq + chMuxerClose chan *hlsMuxer + chAPIMuxerList chan hlsManagerAPIMuxersListReq +} + +func newHLSManager( + parentCtx context.Context, + address string, + encryption bool, + serverKey string, + serverCert string, + externalAuthenticationURL string, + alwaysRemux bool, + variant conf.HLSVariant, + segmentCount int, + segmentDuration conf.StringDuration, + partDuration conf.StringDuration, + segmentMaxSize conf.StringSize, + allowOrigin string, + trustedProxies conf.IPsOrCIDRs, + directory string, + readTimeout conf.StringDuration, + readBufferCount int, + pathManager *pathManager, + metrics *metrics, + parent hlsManagerParent, +) (*hlsManager, error) { + ctx, ctxCancel := context.WithCancel(parentCtx) + + m := &hlsManager{ + externalAuthenticationURL: externalAuthenticationURL, + alwaysRemux: alwaysRemux, + variant: variant, + segmentCount: segmentCount, + segmentDuration: segmentDuration, + partDuration: partDuration, + segmentMaxSize: segmentMaxSize, + directory: directory, + readBufferCount: readBufferCount, + pathManager: pathManager, + parent: parent, + metrics: metrics, + ctx: ctx, + ctxCancel: ctxCancel, + muxers: make(map[string]*hlsMuxer), + chPathSourceReady: make(chan *path), + chPathSourceNotReady: make(chan *path), + chHandleRequest: make(chan hlsMuxerHandleRequestReq), + chMuxerClose: make(chan *hlsMuxer), + chAPIMuxerList: make(chan hlsManagerAPIMuxersListReq), + } + + var err error + m.httpServer, err = newHLSHTTPServer( + address, + encryption, + serverKey, + serverCert, + allowOrigin, + trustedProxies, + readTimeout, + m, + ) + if err != nil { + ctxCancel() + return nil, err + } + + m.Log(logger.Info, "listener opened on "+address) + + m.pathManager.hlsManagerSet(m) + + if m.metrics != nil { + m.metrics.hlsManagerSet(m) + } + + m.wg.Add(1) + go m.run() + + return m, nil +} + +// Log is the main logging function. +func (m *hlsManager) Log(level logger.Level, format string, args ...interface{}) { + m.parent.Log(level, "[HLS] "+format, append([]interface{}{}, args...)...) +} + +func (m *hlsManager) close() { + m.Log(logger.Info, "listener is closing") + m.ctxCancel() + m.wg.Wait() +} + +func (m *hlsManager) run() { + defer m.wg.Done() + +outer: + for { + select { + case pa := <-m.chPathSourceReady: + if m.alwaysRemux { + m.createMuxer(pa.name, "") + } + + case pa := <-m.chPathSourceNotReady: + if m.alwaysRemux { + c, ok := m.muxers[pa.name] + if ok { + c.close() + delete(m.muxers, pa.name) + } + } + + case req := <-m.chHandleRequest: + r, ok := m.muxers[req.path] + switch { + case ok: + r.processRequest(&req) + + case m.alwaysRemux: + req.res <- nil + + default: + r := m.createMuxer(req.path, req.ctx.ClientIP()) + r.processRequest(&req) + } + + case c := <-m.chMuxerClose: + if c2, ok := m.muxers[c.PathName()]; !ok || c2 != c { + continue + } + delete(m.muxers, c.PathName()) + + case req := <-m.chAPIMuxerList: + muxers := make(map[string]*hlsMuxer) + + for name, m := range m.muxers { + muxers[name] = m + } + + req.res <- hlsManagerAPIMuxersListRes{ + muxers: muxers, + } + + case <-m.ctx.Done(): + break outer + } + } + + m.ctxCancel() + + m.httpServer.close() + + m.pathManager.hlsManagerSet(nil) + + if m.metrics != nil { + m.metrics.hlsManagerSet(nil) + } +} + +func (m *hlsManager) createMuxer(pathName string, remoteAddr string) *hlsMuxer { + r := newHLSMuxer( + m.ctx, + remoteAddr, + m.externalAuthenticationURL, + m.alwaysRemux, + m.variant, + m.segmentCount, + m.segmentDuration, + m.partDuration, + m.segmentMaxSize, + m.directory, + m.readBufferCount, + &m.wg, + pathName, + m.pathManager, + m) + m.muxers[pathName] = r + return r +} + +// muxerClose is called by hlsMuxer. +func (m *hlsManager) muxerClose(c *hlsMuxer) { + select { + case m.chMuxerClose <- c: + case <-m.ctx.Done(): + } +} + +// pathSourceReady is called by pathManager. +func (m *hlsManager) pathSourceReady(pa *path) { + select { + case m.chPathSourceReady <- pa: + case <-m.ctx.Done(): + } +} + +// pathSourceNotReady is called by pathManager. +func (m *hlsManager) pathSourceNotReady(pa *path) { + select { + case m.chPathSourceNotReady <- pa: + case <-m.ctx.Done(): + } +} + +// apiMuxersList is called by api. +func (m *hlsManager) apiMuxersList() hlsManagerAPIMuxersListRes { + req := hlsManagerAPIMuxersListReq{ + res: make(chan hlsManagerAPIMuxersListRes), + } + + select { + case m.chAPIMuxerList <- req: + res := <-req.res + + res.data = &hlsManagerAPIMuxersListData{ + Items: make(map[string]hlsManagerAPIMuxersListItem), + } + + for _, pa := range res.muxers { + pa.apiMuxersList(hlsManagerAPIMuxersListSubReq{data: res.data}) + } + + return res + + case <-m.ctx.Done(): + return hlsManagerAPIMuxersListRes{err: fmt.Errorf("terminated")} + } +} + +func (m *hlsManager) handleRequest(req hlsMuxerHandleRequestReq) { + req.res = make(chan *hlsMuxer) + + select { + case m.chHandleRequest <- req: + muxer := <-req.res + if muxer != nil { + req.ctx.Request.URL.Path = req.file + muxer.handleRequest(req.ctx) + } + + case <-m.ctx.Done(): + } +} diff --git a/internal/core/hls_server_test.go b/internal/core/hls_manager_test.go similarity index 98% rename from internal/core/hls_server_test.go rename to internal/core/hls_manager_test.go index 5f352b2f..a0dd526f 100644 --- a/internal/core/hls_server_test.go +++ b/internal/core/hls_manager_test.go @@ -100,7 +100,7 @@ func httpPullFile(u string) ([]byte, error) { return io.ReadAll(res.Body) } -func TestHLSServerNotFound(t *testing.T) { +func TestHLSReadNotFound(t *testing.T) { p, ok := newInstance("") require.Equal(t, true, ok) defer p.Close() @@ -114,7 +114,7 @@ func TestHLSServerNotFound(t *testing.T) { require.Equal(t, http.StatusNotFound, res.StatusCode) } -func TestHLSServer(t *testing.T) { +func TestHLSRead(t *testing.T) { p, ok := newInstance("hlsAlwaysRemux: yes\n" + "paths:\n" + " all:\n") diff --git a/internal/core/hls_muxer.go b/internal/core/hls_muxer.go index 91e48d9c..92a0441c 100644 --- a/internal/core/hls_muxer.go +++ b/internal/core/hls_muxer.go @@ -48,11 +48,11 @@ func (w *responseWriterWithCounter) Write(p []byte) (int, error) { return n, err } -type hlsMuxerRequest struct { - path string - file string - clientIP string - res chan *hlsMuxer +type hlsMuxerHandleRequestReq struct { + path string + file string + ctx *gin.Context + res chan *hlsMuxer } type hlsMuxerPathManager interface { @@ -87,12 +87,12 @@ type hlsMuxer struct { ringBuffer *ringbuffer.RingBuffer lastRequestTime *int64 muxer *gohlslib.Muxer - requests []*hlsMuxerRequest + requests []*hlsMuxerHandleRequestReq bytesSent *uint64 // in - chRequest chan *hlsMuxerRequest - chAPIHLSMuxersList chan hlsServerAPIMuxersListSubReq + chRequest chan *hlsMuxerHandleRequestReq + chAPIHLSMuxersList chan hlsManagerAPIMuxersListSubReq } func newHLSMuxer( @@ -137,8 +137,8 @@ func newHLSMuxer( return &v }(), bytesSent: new(uint64), - chRequest: make(chan *hlsMuxerRequest), - chAPIHLSMuxersList: make(chan hlsServerAPIMuxersListSubReq), + chRequest: make(chan *hlsMuxerHandleRequestReq), + chAPIHLSMuxersList: make(chan hlsManagerAPIMuxersListSubReq), } m.Log(logger.Info, "created %s", func() string { @@ -213,7 +213,7 @@ func (m *hlsMuxer) run() { } case req := <-m.chAPIHLSMuxersList: - req.data.Items[m.pathName] = hlsServerAPIMuxersListItem{ + req.data.Items[m.pathName] = hlsManagerAPIMuxersListItem{ Created: m.created, LastRequest: time.Unix(0, atomic.LoadInt64(m.lastRequestTime)), BytesSent: atomic.LoadUint64(m.bytesSent), @@ -592,7 +592,7 @@ func (m *hlsMuxer) handleRequest(ctx *gin.Context) { } // processRequest is called by hlsserver.Server (forwarded from ServeHTTP). -func (m *hlsMuxer) processRequest(req *hlsMuxerRequest) { +func (m *hlsMuxer) processRequest(req *hlsMuxerHandleRequestReq) { select { case m.chRequest <- req: case <-m.ctx.Done(): @@ -601,7 +601,7 @@ func (m *hlsMuxer) processRequest(req *hlsMuxerRequest) { } // apiMuxersList is called by api. -func (m *hlsMuxer) apiMuxersList(req hlsServerAPIMuxersListSubReq) { +func (m *hlsMuxer) apiMuxersList(req hlsManagerAPIMuxersListSubReq) { req.res = make(chan struct{}) select { case m.chAPIHLSMuxersList <- req: diff --git a/internal/core/hls_server.go b/internal/core/hls_server.go deleted file mode 100644 index 12ea5852..00000000 --- a/internal/core/hls_server.go +++ /dev/null @@ -1,398 +0,0 @@ -package core - -import ( - "context" - "crypto/tls" - "fmt" - "log" - "net" - "net/http" - gopath "path" - "strings" - "sync" - "time" - - "github.com/gin-gonic/gin" - - "github.com/aler9/mediamtx/internal/conf" - "github.com/aler9/mediamtx/internal/logger" -) - -type nilWriter struct{} - -func (nilWriter) Write(p []byte) (int, error) { - return len(p), nil -} - -type hlsServerAPIMuxersListItem struct { - Created time.Time `json:"created"` - LastRequest time.Time `json:"lastRequest"` - BytesSent uint64 `json:"bytesSent"` -} - -type hlsServerAPIMuxersListData struct { - Items map[string]hlsServerAPIMuxersListItem `json:"items"` -} - -type hlsServerAPIMuxersListRes struct { - data *hlsServerAPIMuxersListData - muxers map[string]*hlsMuxer - err error -} - -type hlsServerAPIMuxersListReq struct { - res chan hlsServerAPIMuxersListRes -} - -type hlsServerAPIMuxersListSubReq struct { - data *hlsServerAPIMuxersListData - res chan struct{} -} - -type hlsServerParent interface { - logger.Writer -} - -type hlsServer struct { - externalAuthenticationURL string - alwaysRemux bool - variant conf.HLSVariant - segmentCount int - segmentDuration conf.StringDuration - partDuration conf.StringDuration - segmentMaxSize conf.StringSize - allowOrigin string - directory string - readBufferCount int - pathManager *pathManager - metrics *metrics - parent hlsServerParent - - ctx context.Context - ctxCancel func() - wg sync.WaitGroup - ln net.Listener - httpServer *http.Server - muxers map[string]*hlsMuxer - - // in - chPathSourceReady chan *path - chPathSourceNotReady chan *path - request chan *hlsMuxerRequest - chMuxerClose chan *hlsMuxer - chAPIMuxerList chan hlsServerAPIMuxersListReq -} - -func newHLSServer( - parentCtx context.Context, - address string, - encryption bool, - serverKey string, - serverCert string, - externalAuthenticationURL string, - alwaysRemux bool, - variant conf.HLSVariant, - segmentCount int, - segmentDuration conf.StringDuration, - partDuration conf.StringDuration, - segmentMaxSize conf.StringSize, - allowOrigin string, - trustedProxies conf.IPsOrCIDRs, - directory string, - readTimeout conf.StringDuration, - readBufferCount int, - pathManager *pathManager, - metrics *metrics, - parent hlsServerParent, -) (*hlsServer, error) { - ln, err := net.Listen(restrictNetwork("tcp", address)) - if err != nil { - return nil, err - } - - var tlsConfig *tls.Config - if encryption { - crt, err := tls.LoadX509KeyPair(serverCert, serverKey) - if err != nil { - ln.Close() - return nil, err - } - - tlsConfig = &tls.Config{ - Certificates: []tls.Certificate{crt}, - } - } - - ctx, ctxCancel := context.WithCancel(parentCtx) - - s := &hlsServer{ - externalAuthenticationURL: externalAuthenticationURL, - alwaysRemux: alwaysRemux, - variant: variant, - segmentCount: segmentCount, - segmentDuration: segmentDuration, - partDuration: partDuration, - segmentMaxSize: segmentMaxSize, - allowOrigin: allowOrigin, - directory: directory, - readBufferCount: readBufferCount, - pathManager: pathManager, - parent: parent, - metrics: metrics, - ctx: ctx, - ctxCancel: ctxCancel, - ln: ln, - muxers: make(map[string]*hlsMuxer), - chPathSourceReady: make(chan *path), - chPathSourceNotReady: make(chan *path), - request: make(chan *hlsMuxerRequest), - chMuxerClose: make(chan *hlsMuxer), - chAPIMuxerList: make(chan hlsServerAPIMuxersListReq), - } - - router := gin.New() - httpSetTrustedProxies(router, trustedProxies) - - router.NoRoute(httpLoggerMiddleware(s), httpServerHeaderMiddleware, s.onRequest) - - s.httpServer = &http.Server{ - Handler: router, - TLSConfig: tlsConfig, - ReadHeaderTimeout: time.Duration(readTimeout), - ErrorLog: log.New(&nilWriter{}, "", 0), - } - - s.Log(logger.Info, "listener opened on "+address) - - s.pathManager.hlsServerSet(s) - - if s.metrics != nil { - s.metrics.hlsServerSet(s) - } - - s.wg.Add(1) - go s.run() - - return s, nil -} - -// Log is the main logging function. -func (s *hlsServer) Log(level logger.Level, format string, args ...interface{}) { - s.parent.Log(level, "[HLS] "+format, append([]interface{}{}, args...)...) -} - -func (s *hlsServer) close() { - s.Log(logger.Info, "listener is closing") - s.ctxCancel() - s.wg.Wait() -} - -func (s *hlsServer) run() { - defer s.wg.Done() - - if s.httpServer.TLSConfig != nil { - go s.httpServer.ServeTLS(s.ln, "", "") - } else { - go s.httpServer.Serve(s.ln) - } - -outer: - for { - select { - case pa := <-s.chPathSourceReady: - if s.alwaysRemux { - s.createMuxer(pa.name, "") - } - - case pa := <-s.chPathSourceNotReady: - if s.alwaysRemux { - c, ok := s.muxers[pa.name] - if ok { - c.close() - delete(s.muxers, pa.name) - } - } - - case req := <-s.request: - r, ok := s.muxers[req.path] - switch { - case ok: - r.processRequest(req) - - case s.alwaysRemux: - req.res <- nil - - default: - r := s.createMuxer(req.path, req.clientIP) - r.processRequest(req) - } - - case c := <-s.chMuxerClose: - if c2, ok := s.muxers[c.PathName()]; !ok || c2 != c { - continue - } - delete(s.muxers, c.PathName()) - - case req := <-s.chAPIMuxerList: - muxers := make(map[string]*hlsMuxer) - - for name, m := range s.muxers { - muxers[name] = m - } - - req.res <- hlsServerAPIMuxersListRes{ - muxers: muxers, - } - - case <-s.ctx.Done(): - break outer - } - } - - s.ctxCancel() - - s.httpServer.Shutdown(context.Background()) - s.ln.Close() // in case Shutdown() is called before Serve() - - s.pathManager.hlsServerSet(nil) - - if s.metrics != nil { - s.metrics.hlsServerSet(nil) - } -} - -func (s *hlsServer) onRequest(ctx *gin.Context) { - ctx.Writer.Header().Set("Access-Control-Allow-Origin", s.allowOrigin) - ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") - - switch ctx.Request.Method { - case http.MethodGet: - - case http.MethodOptions: - ctx.Writer.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - ctx.Writer.Header().Set("Access-Control-Allow-Headers", ctx.Request.Header.Get("Access-Control-Request-Headers")) - ctx.Writer.WriteHeader(http.StatusOK) - return - - default: - return - } - - // remove leading prefix - pa := ctx.Request.URL.Path[1:] - - switch pa { - case "", "favicon.ico": - return - } - - dir, fname := func() (string, string) { - if strings.HasSuffix(pa, ".m3u8") || - strings.HasSuffix(pa, ".ts") || - strings.HasSuffix(pa, ".mp4") || - strings.HasSuffix(pa, ".mp") { - return gopath.Dir(pa), gopath.Base(pa) - } - return pa, "" - }() - - if fname == "" && !strings.HasSuffix(dir, "/") { - ctx.Writer.Header().Set("Location", "/"+dir+"/") - ctx.Writer.WriteHeader(http.StatusMovedPermanently) - return - } - - if strings.HasSuffix(fname, ".mp") { - fname += "4" - } - - dir = strings.TrimSuffix(dir, "/") - - hreq := &hlsMuxerRequest{ - path: dir, - file: fname, - clientIP: ctx.ClientIP(), - res: make(chan *hlsMuxer), - } - - select { - case s.request <- hreq: - muxer := <-hreq.res - if muxer != nil { - ctx.Request.URL.Path = fname - muxer.handleRequest(ctx) - } - - case <-s.ctx.Done(): - } -} - -func (s *hlsServer) createMuxer(pathName string, remoteAddr string) *hlsMuxer { - r := newHLSMuxer( - s.ctx, - remoteAddr, - s.externalAuthenticationURL, - s.alwaysRemux, - s.variant, - s.segmentCount, - s.segmentDuration, - s.partDuration, - s.segmentMaxSize, - s.directory, - s.readBufferCount, - &s.wg, - pathName, - s.pathManager, - s) - s.muxers[pathName] = r - return r -} - -// muxerClose is called by hlsMuxer. -func (s *hlsServer) muxerClose(c *hlsMuxer) { - select { - case s.chMuxerClose <- c: - case <-s.ctx.Done(): - } -} - -// pathSourceReady is called by pathManager. -func (s *hlsServer) pathSourceReady(pa *path) { - select { - case s.chPathSourceReady <- pa: - case <-s.ctx.Done(): - } -} - -// pathSourceNotReady is called by pathManager. -func (s *hlsServer) pathSourceNotReady(pa *path) { - select { - case s.chPathSourceNotReady <- pa: - case <-s.ctx.Done(): - } -} - -// apiMuxersList is called by api. -func (s *hlsServer) apiMuxersList() hlsServerAPIMuxersListRes { - req := hlsServerAPIMuxersListReq{ - res: make(chan hlsServerAPIMuxersListRes), - } - - select { - case s.chAPIMuxerList <- req: - res := <-req.res - - res.data = &hlsServerAPIMuxersListData{ - Items: make(map[string]hlsServerAPIMuxersListItem), - } - - for _, pa := range res.muxers { - pa.apiMuxersList(hlsServerAPIMuxersListSubReq{data: res.data}) - } - - return res - - case <-s.ctx.Done(): - return hlsServerAPIMuxersListRes{err: fmt.Errorf("terminated")} - } -} diff --git a/internal/core/hls_source_test.go b/internal/core/hls_source_test.go index dbf21648..91f45332 100644 --- a/internal/core/hls_source_test.go +++ b/internal/core/hls_source_test.go @@ -20,19 +20,19 @@ import ( "github.com/stretchr/testify/require" ) -type testHLSServer struct { +type testHLSManager struct { s *http.Server clientConnected chan struct{} } -func newTestHLSServer() (*testHLSServer, error) { +func newTestHLSManager() (*testHLSManager, error) { ln, err := net.Listen("tcp", "localhost:5780") if err != nil { return nil, err } - ts := &testHLSServer{ + ts := &testHLSManager{ clientConnected: make(chan struct{}), } @@ -48,11 +48,11 @@ func newTestHLSServer() (*testHLSServer, error) { return ts, nil } -func (ts *testHLSServer) close() { +func (ts *testHLSManager) close() { ts.s.Shutdown(context.Background()) } -func (ts *testHLSServer) onPlaylist(ctx *gin.Context) { +func (ts *testHLSManager) onPlaylist(ctx *gin.Context) { cnt := `#EXTM3U #EXT-X-VERSION:3 #EXT-X-ALLOW-CACHE:NO @@ -69,7 +69,7 @@ segment2.ts io.Copy(ctx.Writer, bytes.NewReader([]byte(cnt))) } -func (ts *testHLSServer) onSegment1(ctx *gin.Context) { +func (ts *testHLSManager) onSegment1(ctx *gin.Context) { ctx.Writer.Header().Set("Content-Type", `video/MP2T`) mux := astits.NewMuxer(context.Background(), ctx.Writer) @@ -113,7 +113,7 @@ func (ts *testHLSServer) onSegment1(ctx *gin.Context) { }) } -func (ts *testHLSServer) onSegment2(ctx *gin.Context) { +func (ts *testHLSManager) onSegment2(ctx *gin.Context) { <-ts.clientConnected ctx.Writer.Header().Set("Content-Type", `video/MP2T`) @@ -199,7 +199,7 @@ func (ts *testHLSServer) onSegment2(ctx *gin.Context) { } func TestHLSSource(t *testing.T) { - ts, err := newTestHLSServer() + ts, err := newTestHLSManager() require.NoError(t, err) defer ts.close() diff --git a/internal/core/http_requestpool.go b/internal/core/http_requestpool.go deleted file mode 100644 index b4c0375c..00000000 --- a/internal/core/http_requestpool.go +++ /dev/null @@ -1,25 +0,0 @@ -package core - -import ( - "sync" - - "github.com/gin-gonic/gin" -) - -type httpRequestPool struct { - wg sync.WaitGroup -} - -func newHTTPRequestPool() *httpRequestPool { - return &httpRequestPool{} -} - -func (rp *httpRequestPool) mw(ctx *gin.Context) { - rp.wg.Add(1) - ctx.Next() - rp.wg.Done() -} - -func (rp *httpRequestPool) close() { - rp.wg.Wait() -} diff --git a/internal/core/metrics.go b/internal/core/metrics.go index 48b7a6d8..3f9e1b45 100644 --- a/internal/core/metrics.go +++ b/internal/core/metrics.go @@ -27,15 +27,15 @@ type metricsParent interface { type metrics struct { parent metricsParent - ln net.Listener - httpServer *http.Server - mutex sync.Mutex - pathManager apiPathManager - rtspServer apiRTSPServer - rtspsServer apiRTSPServer - rtmpServer apiRTMPServer - hlsServer apiHLSServer - webRTCServer apiWebRTCServer + ln net.Listener + httpServer *http.Server + mutex sync.Mutex + pathManager apiPathManager + rtspServer apiRTSPServer + rtspsServer apiRTSPServer + rtmpServer apiRTMPServer + hlsManager apiHLSManager + webRTCManager apiWebRTCManager } func newMetrics( @@ -104,8 +104,8 @@ func (m *metrics) onMetrics(ctx *gin.Context) { out += metric("paths", "", 0) } - if !interfaceIsEmpty(m.hlsServer) { - res := m.hlsServer.apiMuxersList() + if !interfaceIsEmpty(m.hlsManager) { + res := m.hlsManager.apiMuxersList() if res.err == nil && len(res.data.Items) != 0 { for name, i := range res.data.Items { tags := "{name=\"" + name + "\"}" @@ -202,19 +202,19 @@ func (m *metrics) onMetrics(ctx *gin.Context) { } } - if !interfaceIsEmpty(m.webRTCServer) { - res := m.webRTCServer.apiConnsList() + if !interfaceIsEmpty(m.webRTCManager) { + res := m.webRTCManager.apiSessionsList() if res.err == nil && len(res.data.Items) != 0 { for id, i := range res.data.Items { tags := "{id=\"" + id + "\"}" - out += metric("webrtc_conns", tags, 1) - out += metric("webrtc_conns_bytes_received", tags, int64(i.BytesReceived)) - out += metric("webrtc_conns_bytes_sent", tags, int64(i.BytesSent)) + out += metric("webrtc_sessions", tags, 1) + out += metric("webrtc_sessions_bytes_received", tags, int64(i.BytesReceived)) + out += metric("webrtc_sessions_bytes_sent", tags, int64(i.BytesSent)) } } else { - out += metric("webrtc_conns", "", 0) - out += metric("webrtc_conns_bytes_received", "", 0) - out += metric("webrtc_conns_bytes_sent", "", 0) + out += metric("webrtc_sessions", "", 0) + out += metric("webrtc_sessions_bytes_received", "", 0) + out += metric("webrtc_sessions_bytes_sent", "", 0) } } @@ -229,11 +229,11 @@ func (m *metrics) pathManagerSet(s apiPathManager) { m.pathManager = s } -// hlsServerSet is called by hlsServer. -func (m *metrics) hlsServerSet(s apiHLSServer) { +// hlsManagerSet is called by hlsManager. +func (m *metrics) hlsManagerSet(s apiHLSManager) { m.mutex.Lock() defer m.mutex.Unlock() - m.hlsServer = s + m.hlsManager = s } // rtspServerSet is called by rtspServer (plain). @@ -257,9 +257,9 @@ func (m *metrics) rtmpServerSet(s apiRTMPServer) { m.rtmpServer = s } -// webRTCServerSet is called by webRTCServer. -func (m *metrics) webRTCServerSet(s apiWebRTCServer) { +// webRTCManagerSet is called by webRTCManager. +func (m *metrics) webRTCManagerSet(s apiWebRTCManager) { m.mutex.Lock() defer m.mutex.Unlock() - m.webRTCServer = s + m.webRTCManager = s } diff --git a/internal/core/metrics_test.go b/internal/core/metrics_test.go index c952a518..16deeda6 100644 --- a/internal/core/metrics_test.go +++ b/internal/core/metrics_test.go @@ -57,9 +57,9 @@ rtsps_sessions_bytes_sent 0 rtmp_conns 0 rtmp_conns_bytes_received 0 rtmp_conns_bytes_sent 0 -webrtc_conns 0 -webrtc_conns_bytes_received 0 -webrtc_conns_bytes_sent 0 +webrtc_sessions 0 +webrtc_sessions_bytes_received 0 +webrtc_sessions_bytes_sent 0 `, string(bo)) medi := testMediaH264 @@ -132,9 +132,9 @@ webrtc_conns_bytes_sent 0 `rtmp_conns\{id=".*?",state="publish"\} 1`+"\n"+ `rtmp_conns_bytes_received\{id=".*?",state="publish"\} [0-9]+`+"\n"+ `rtmp_conns_bytes_sent\{id=".*?",state="publish"\} [0-9]+`+"\n"+ - `webrtc_conns 0`+"\n"+ - `webrtc_conns_bytes_received 0`+"\n"+ - `webrtc_conns_bytes_sent 0`+"\n"+ + `webrtc_sessions 0`+"\n"+ + `webrtc_sessions_bytes_received 0`+"\n"+ + `webrtc_sessions_bytes_sent 0`+"\n"+ "$", string(bo)) } diff --git a/internal/core/path_manager.go b/internal/core/path_manager.go index ce06f7bb..b30943a4 100644 --- a/internal/core/path_manager.go +++ b/internal/core/path_manager.go @@ -29,7 +29,7 @@ func pathConfCanBeUpdated(oldPathConf *conf.PathConf, newPathConf *conf.PathConf return newPathConf.Equal(copy) } -type pathManagerHLSServer interface { +type pathManagerHLSManager interface { pathSourceReady(*path) pathSourceNotReady(*path) } @@ -54,7 +54,7 @@ type pathManager struct { ctx context.Context ctxCancel func() wg sync.WaitGroup - hlsServer pathManagerHLSServer + hlsManager pathManagerHLSManager paths map[string]*path pathsByConf map[string]map[*path]struct{} @@ -67,7 +67,7 @@ type pathManager struct { chDescribe chan pathDescribeReq chReaderAdd chan pathReaderAddReq chPublisherAdd chan pathPublisherAddReq - chHLSServerSet chan pathManagerHLSServer + chHLSManagerSet chan pathManagerHLSManager chAPIPathsList chan pathAPIPathsListReq } @@ -111,7 +111,7 @@ func newPathManager( chDescribe: make(chan pathDescribeReq), chReaderAdd: make(chan pathReaderAddReq), chPublisherAdd: make(chan pathPublisherAddReq), - chHLSServerSet: make(chan pathManagerHLSServer), + chHLSManagerSet: make(chan pathManagerHLSManager), chAPIPathsList: make(chan pathAPIPathsListReq), } @@ -193,13 +193,13 @@ outer: pm.removePath(pa) case pa := <-pm.chPathSourceReady: - if pm.hlsServer != nil { - pm.hlsServer.pathSourceReady(pa) + if pm.hlsManager != nil { + pm.hlsManager.pathSourceReady(pa) } case pa := <-pm.chPathSourceNotReady: - if pm.hlsServer != nil { - pm.hlsServer.pathSourceNotReady(pa) + if pm.hlsManager != nil { + pm.hlsManager.pathSourceNotReady(pa) } case req := <-pm.chPathGetPathConf: @@ -282,8 +282,8 @@ outer: req.res <- pathPublisherAnnounceRes{path: pm.paths[req.pathName]} - case s := <-pm.chHLSServerSet: - pm.hlsServer = s + case s := <-pm.chHLSManagerSet: + pm.hlsManager = s case req := <-pm.chAPIPathsList: paths := make(map[string]*path) @@ -473,10 +473,10 @@ func (pm *pathManager) readerAdd(req pathReaderAddReq) pathReaderSetupPlayRes { } } -// hlsServerSet is called by hlsServer. -func (pm *pathManager) hlsServerSet(s pathManagerHLSServer) { +// hlsManagerSet is called by hlsManager. +func (pm *pathManager) hlsManagerSet(s pathManagerHLSManager) { select { - case pm.chHLSServerSet <- s: + case pm.chHLSManagerSet <- s: case <-pm.ctx.Done(): } } diff --git a/internal/core/rtmp_server.go b/internal/core/rtmp_server.go index 5ed95e15..7e71ba56 100644 --- a/internal/core/rtmp_server.go +++ b/internal/core/rtmp_server.go @@ -8,6 +8,8 @@ import ( "sync" "time" + "github.com/google/uuid" + "github.com/aler9/mediamtx/internal/conf" "github.com/aler9/mediamtx/internal/externalcmd" "github.com/aler9/mediamtx/internal/logger" @@ -39,8 +41,8 @@ type rtmpServerAPIConnsKickRes struct { } type rtmpServerAPIConnsKickReq struct { - id string - res chan rtmpServerAPIConnsKickRes + uuid uuid.UUID + res chan rtmpServerAPIConnsKickRes } type rtmpServerParent interface { @@ -67,9 +69,9 @@ type rtmpServer struct { conns map[*rtmpConn]struct{} // in - chConnClose chan *rtmpConn - chAPIConnsList chan rtmpServerAPIConnsListReq - chAPIConnsKick chan rtmpServerAPIConnsKickReq + chConnClose chan *rtmpConn + chAPISessionsList chan rtmpServerAPIConnsListReq + chAPIConnsKick chan rtmpServerAPIConnsKickReq } func newRTMPServer( @@ -125,7 +127,7 @@ func newRTMPServer( ln: ln, conns: make(map[*rtmpConn]struct{}), chConnClose: make(chan *rtmpConn), - chAPIConnsList: make(chan rtmpServerAPIConnsListReq), + chAPISessionsList: make(chan rtmpServerAPIConnsListReq), chAPIConnsKick: make(chan rtmpServerAPIConnsKickReq), } @@ -213,7 +215,7 @@ outer: case c := <-s.chConnClose: delete(s.conns, c) - case req := <-s.chAPIConnsList: + case req := <-s.chAPISessionsList: data := &rtmpServerAPIConnsListData{ Items: make(map[string]rtmpServerAPIConnsListItem), } @@ -240,22 +242,16 @@ outer: req.res <- rtmpServerAPIConnsListRes{data: data} case req := <-s.chAPIConnsKick: - res := func() bool { - for c := range s.conns { - if c.uuid.String() == req.id { - delete(s.conns, c) - c.close() - return true - } - } - return false - }() - if res { - req.res <- rtmpServerAPIConnsKickRes{} - } else { + c := s.findConnByUUID(req.uuid) + if c == nil { req.res <- rtmpServerAPIConnsKickRes{fmt.Errorf("not found")} + continue } + delete(s.conns, c) + c.close() + req.res <- rtmpServerAPIConnsKickRes{} + case <-s.ctx.Done(): break outer } @@ -270,6 +266,15 @@ outer: } } +func (s *rtmpServer) findConnByUUID(uuid uuid.UUID) *rtmpConn { + for c := range s.conns { + if c.uuid == uuid { + return c + } + } + return nil +} + // connClose is called by rtmpConn. func (s *rtmpServer) connClose(c *rtmpConn) { select { @@ -285,7 +290,7 @@ func (s *rtmpServer) apiConnsList() rtmpServerAPIConnsListRes { } select { - case s.chAPIConnsList <- req: + case s.chAPISessionsList <- req: return <-req.res case <-s.ctx.Done(): @@ -294,10 +299,10 @@ func (s *rtmpServer) apiConnsList() rtmpServerAPIConnsListRes { } // apiConnsKick is called by api. -func (s *rtmpServer) apiConnsKick(id string) rtmpServerAPIConnsKickRes { +func (s *rtmpServer) apiConnsKick(uuid uuid.UUID) rtmpServerAPIConnsKickRes { req := rtmpServerAPIConnsKickReq{ - id: id, - res: make(chan rtmpServerAPIConnsKickRes), + uuid: uuid, + res: make(chan rtmpServerAPIConnsKickRes), } select { diff --git a/internal/core/rtsp_server.go b/internal/core/rtsp_server.go index 2b4b3fa5..4daf7671 100644 --- a/internal/core/rtsp_server.go +++ b/internal/core/rtsp_server.go @@ -12,6 +12,7 @@ import ( "github.com/bluenviron/gortsplib/v3/pkg/base" "github.com/bluenviron/gortsplib/v3/pkg/headers" "github.com/bluenviron/gortsplib/v3/pkg/liberrors" + "github.com/google/uuid" "github.com/aler9/mediamtx/internal/conf" "github.com/aler9/mediamtx/internal/externalcmd" @@ -359,6 +360,15 @@ func (s *rtspServer) OnDecodeError(ctx *gortsplib.ServerHandlerOnDecodeErrorCtx) se.onDecodeError(ctx) } +func (s *rtspServer) findSessionByUUID(uuid uuid.UUID) (*gortsplib.ServerSession, *rtspSession) { + for key, sx := range s.sessions { + if sx.uuid == uuid { + return key, sx + } + } + return nil, nil +} + // apiConnsList is called by api and metrics. func (s *rtspServer) apiConnsList() rtspServerAPIConnsListRes { select { @@ -426,7 +436,7 @@ func (s *rtspServer) apiSessionsList() rtspServerAPISessionsListRes { } // apiSessionsKick is called by api. -func (s *rtspServer) apiSessionsKick(id string) rtspServerAPISessionsKickRes { +func (s *rtspServer) apiSessionsKick(uuid uuid.UUID) rtspServerAPISessionsKickRes { select { case <-s.ctx.Done(): return rtspServerAPISessionsKickRes{err: fmt.Errorf("terminated")} @@ -436,14 +446,13 @@ func (s *rtspServer) apiSessionsKick(id string) rtspServerAPISessionsKickRes { s.mutex.RLock() defer s.mutex.RUnlock() - for key, se := range s.sessions { - if se.uuid.String() == id { - se.close() - delete(s.sessions, key) - se.onClose(liberrors.ErrServerTerminated{}) - return rtspServerAPISessionsKickRes{} - } + key, sx := s.findSessionByUUID(uuid) + if sx == nil { + return rtspServerAPISessionsKickRes{err: fmt.Errorf("not found")} } - return rtspServerAPISessionsKickRes{err: fmt.Errorf("not found")} + sx.close() + delete(s.sessions, key) + sx.onClose(liberrors.ErrServerTerminated{}) + return rtspServerAPISessionsKickRes{} } diff --git a/internal/core/webrtc_candidate_reader.go b/internal/core/webrtc_candidate_reader.go deleted file mode 100644 index df54da83..00000000 --- a/internal/core/webrtc_candidate_reader.go +++ /dev/null @@ -1,73 +0,0 @@ -package core - -import ( - "context" - - "github.com/pion/webrtc/v3" - - "github.com/aler9/mediamtx/internal/websocket" -) - -type webRTCCandidateReader struct { - ws *websocket.ServerConn - - ctx context.Context - ctxCancel func() - - stopGathering chan struct{} - readError chan error - remoteCandidate chan *webrtc.ICECandidateInit -} - -func newWebRTCCandidateReader(ws *websocket.ServerConn) *webRTCCandidateReader { - ctx, ctxCancel := context.WithCancel(context.Background()) - - r := &webRTCCandidateReader{ - ws: ws, - ctx: ctx, - ctxCancel: ctxCancel, - stopGathering: make(chan struct{}), - readError: make(chan error), - remoteCandidate: make(chan *webrtc.ICECandidateInit), - } - - go r.run() - - return r -} - -func (r *webRTCCandidateReader) close() { - r.ctxCancel() - // do not wait for ReadJSON() to return - // it is terminated by ws.Close() later -} - -func (r *webRTCCandidateReader) run() { - for { - candidate, err := r.readCandidate() - if err != nil { - select { - case r.readError <- err: - case <-r.ctx.Done(): - } - return - } - - select { - case r.remoteCandidate <- candidate: - case <-r.stopGathering: - case <-r.ctx.Done(): - return - } - } -} - -func (r *webRTCCandidateReader) readCandidate() (*webrtc.ICECandidateInit, error) { - var candidate webrtc.ICECandidateInit - err := r.ws.ReadJSON(&candidate) - if err != nil { - return nil, err - } - - return &candidate, err -} diff --git a/internal/core/webrtc_conn.go b/internal/core/webrtc_conn.go deleted file mode 100644 index a87a830c..00000000 --- a/internal/core/webrtc_conn.go +++ /dev/null @@ -1,694 +0,0 @@ -package core - -import ( - "context" - "crypto/hmac" - "crypto/sha1" - "encoding/base64" - "errors" - "fmt" - "math/rand" - "net" - "strconv" - "strings" - "sync" - "time" - - "github.com/bluenviron/gortsplib/v3/pkg/media" - "github.com/bluenviron/gortsplib/v3/pkg/ringbuffer" - "github.com/google/uuid" - "github.com/pion/ice/v2" - "github.com/pion/sdp/v3" - "github.com/pion/webrtc/v3" - - "github.com/aler9/mediamtx/internal/formatprocessor" - "github.com/aler9/mediamtx/internal/logger" - "github.com/aler9/mediamtx/internal/websocket" -) - -const ( - webrtcHandshakeTimeout = 10 * time.Second - webrtcTrackGatherTimeout = 2 * time.Second - webrtcPayloadMaxSize = 1188 // 1200 - 12 (RTP header) -) - -type trackRecvPair struct { - track *webrtc.TrackRemote - receiver *webrtc.RTPReceiver -} - -func mediasOfOutgoingTracks(tracks []*webRTCOutgoingTrack) media.Medias { - ret := make(media.Medias, len(tracks)) - for i, track := range tracks { - ret[i] = track.media - } - return ret -} - -func mediasOfIncomingTracks(tracks []*webRTCIncomingTrack) media.Medias { - ret := make(media.Medias, len(tracks)) - for i, track := range tracks { - ret[i] = track.media - } - return ret -} - -func insertTias(offer *webrtc.SessionDescription, value uint64) { - var sd sdp.SessionDescription - err := sd.Unmarshal([]byte(offer.SDP)) - if err != nil { - return - } - - for _, media := range sd.MediaDescriptions { - if media.MediaName.Media == "video" { - media.Bandwidth = append(media.Bandwidth, sdp.Bandwidth{ - Type: "TIAS", - Bandwidth: value, - }) - } - } - - enc, err := sd.Marshal() - if err != nil { - return - } - - offer.SDP = string(enc) -} - -type webRTCConnPathManager interface { - publisherAdd(req pathPublisherAddReq) pathPublisherAnnounceRes - readerAdd(req pathReaderAddReq) pathReaderSetupPlayRes -} - -type webRTCConnParent interface { - logger.Writer - connClose(*webRTCConn) -} - -type webRTCConn struct { - readBufferCount int - pathName string - publish bool - ws *websocket.ServerConn - videoCodec string - audioCodec string - videoBitrate string - iceServers []string - wg *sync.WaitGroup - pathManager webRTCConnPathManager - parent webRTCConnParent - iceUDPMux ice.UDPMux - iceTCPMux ice.TCPMux - iceHostNAT1To1IPs []string - - ctx context.Context - ctxCancel func() - uuid uuid.UUID - created time.Time - pc *peerConnection - mutex sync.RWMutex - - closed chan struct{} -} - -func newWebRTCConn( - parentCtx context.Context, - readBufferCount int, - pathName string, - publish bool, - ws *websocket.ServerConn, - videoCodec string, - audioCodec string, - videoBitrate string, - iceServers []string, - wg *sync.WaitGroup, - pathManager webRTCConnPathManager, - parent webRTCConnParent, - iceHostNAT1To1IPs []string, - iceUDPMux ice.UDPMux, - iceTCPMux ice.TCPMux, -) *webRTCConn { - ctx, ctxCancel := context.WithCancel(parentCtx) - - c := &webRTCConn{ - readBufferCount: readBufferCount, - pathName: pathName, - publish: publish, - ws: ws, - iceServers: iceServers, - wg: wg, - videoCodec: videoCodec, - audioCodec: audioCodec, - videoBitrate: videoBitrate, - pathManager: pathManager, - parent: parent, - ctx: ctx, - ctxCancel: ctxCancel, - uuid: uuid.New(), - created: time.Now(), - iceUDPMux: iceUDPMux, - iceTCPMux: iceTCPMux, - iceHostNAT1To1IPs: iceHostNAT1To1IPs, - closed: make(chan struct{}), - } - - c.Log(logger.Info, "opened") - - wg.Add(1) - go c.run() - - return c -} - -func (c *webRTCConn) close() { - c.ctxCancel() -} - -func (c *webRTCConn) wait() { - <-c.closed -} - -func (c *webRTCConn) remoteAddr() net.Addr { - return c.ws.RemoteAddr() -} - -func (c *webRTCConn) safePC() *peerConnection { - c.mutex.RLock() - defer c.mutex.RUnlock() - return c.pc -} - -func (c *webRTCConn) Log(level logger.Level, format string, args ...interface{}) { - c.parent.Log(level, "[conn %v] "+format, append([]interface{}{c.ws.RemoteAddr()}, args...)...) -} - -func (c *webRTCConn) run() { - defer close(c.closed) - defer c.wg.Done() - - innerCtx, innerCtxCancel := context.WithCancel(c.ctx) - runErr := make(chan error) - go func() { - runErr <- c.runInner(innerCtx) - }() - - var err error - select { - case err = <-runErr: - innerCtxCancel() - - case <-c.ctx.Done(): - innerCtxCancel() - <-runErr - err = errors.New("terminated") - } - - c.ctxCancel() - - c.parent.connClose(c) - - c.Log(logger.Info, "closed (%v)", err) -} - -func (c *webRTCConn) runInner(ctx context.Context) error { - if c.publish { - return c.runPublish(ctx) - } - return c.runRead(ctx) -} - -func (c *webRTCConn) runPublish(ctx context.Context) error { - res := c.pathManager.publisherAdd(pathPublisherAddReq{ - author: c, - pathName: c.pathName, - skipAuth: true, - }) - if res.err != nil { - return res.err - } - - defer res.path.publisherRemove(pathPublisherRemoveReq{author: c}) - - err := c.writeICEServers() - if err != nil { - return err - } - - pc, err := newPeerConnection( - c.videoCodec, - c.audioCodec, - c.genICEServers(), - c.iceHostNAT1To1IPs, - c.iceUDPMux, - c.iceTCPMux, - c) - if err != nil { - return err - } - defer pc.close() - - _, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RtpTransceiverInit{ - Direction: webrtc.RTPTransceiverDirectionRecvonly, - }) - if err != nil { - return err - } - - _, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RtpTransceiverInit{ - Direction: webrtc.RTPTransceiverDirectionRecvonly, - }) - if err != nil { - return err - } - - trackRecv := make(chan trackRecvPair) - - pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { - select { - case trackRecv <- trackRecvPair{track, receiver}: - case <-pc.closed: - } - }) - - offer, err := pc.CreateOffer(nil) - if err != nil { - return err - } - - err = pc.SetLocalDescription(offer) - if err != nil { - return err - } - - tmp, err := strconv.ParseUint(c.videoBitrate, 10, 31) - if err != nil { - return err - } - - insertTias(&offer, tmp*1024) - - err = c.writeOffer(&offer) - if err != nil { - return err - } - - answer, err := c.readAnswer() - if err != nil { - return err - } - - err = pc.SetRemoteDescription(*answer) - if err != nil { - return err - } - - cr := newWebRTCCandidateReader(c.ws) - defer cr.close() - - err = c.establishConnection(ctx, pc, cr) - if err != nil { - return err - } - - close(cr.stopGathering) - - tracks, err := c.gatherIncomingTracks(ctx, pc, cr, trackRecv) - if err != nil { - return err - } - medias := mediasOfIncomingTracks(tracks) - - rres := res.path.publisherStart(pathPublisherStartReq{ - author: c, - medias: medias, - generateRTPPackets: false, - }) - if rres.err != nil { - return rres.err - } - - c.Log(logger.Info, "is publishing to path '%s', %s", - res.path.name, - sourceMediaInfo(medias)) - - for _, track := range tracks { - track.start(rres.stream) - } - - select { - case <-pc.disconnected: - return fmt.Errorf("peer connection closed") - - case err := <-cr.readError: - return fmt.Errorf("websocket error: %v", err) - - case <-ctx.Done(): - return fmt.Errorf("terminated") - } -} - -func (c *webRTCConn) runRead(ctx context.Context) error { - res := c.pathManager.readerAdd(pathReaderAddReq{ - author: c, - pathName: c.pathName, - skipAuth: true, - }) - if res.err != nil { - return res.err - } - - defer res.path.readerRemove(pathReaderRemoveReq{author: c}) - - tracks, err := c.gatherOutgoingTracks(res.stream.medias()) - if err != nil { - return err - } - - err = c.writeICEServers() - if err != nil { - return err - } - - offer, err := c.readOffer() - if err != nil { - return err - } - - pc, err := newPeerConnection( - "", - "", - c.genICEServers(), - c.iceHostNAT1To1IPs, - c.iceUDPMux, - c.iceTCPMux, - c) - if err != nil { - return err - } - defer pc.close() - - for _, track := range tracks { - var err error - track.sender, err = pc.AddTrack(track.track) - if err != nil { - return err - } - } - - err = pc.SetRemoteDescription(*offer) - if err != nil { - return err - } - - answer, err := pc.CreateAnswer(nil) - if err != nil { - return err - } - - err = pc.SetLocalDescription(answer) - if err != nil { - return err - } - - err = c.writeAnswer(&answer) - if err != nil { - return err - } - - cr := newWebRTCCandidateReader(c.ws) - defer cr.close() - - err = c.establishConnection(ctx, pc, cr) - if err != nil { - return err - } - - close(cr.stopGathering) - - for _, track := range tracks { - track.start() - } - - ringBuffer, _ := ringbuffer.New(uint64(c.readBufferCount)) - defer ringBuffer.Close() - - writeError := make(chan error) - - for _, track := range tracks { - ctrack := track - res.stream.readerAdd(c, track.media, track.format, func(unit formatprocessor.Unit) { - ringBuffer.Push(func() { - ctrack.cb(unit, ctx, writeError) - }) - }) - } - defer res.stream.readerRemove(c) - - c.Log(logger.Info, "is reading from path '%s', %s", - res.path.name, sourceMediaInfo(mediasOfOutgoingTracks(tracks))) - - go func() { - for { - item, ok := ringBuffer.Pull() - if !ok { - return - } - item.(func())() - } - }() - - select { - case <-pc.disconnected: - return fmt.Errorf("peer connection closed") - - case err := <-cr.readError: - return fmt.Errorf("websocket error: %v", err) - - case err := <-writeError: - return err - - case <-ctx.Done(): - return fmt.Errorf("terminated") - } -} - -func (c *webRTCConn) gatherOutgoingTracks(medias media.Medias) ([]*webRTCOutgoingTrack, error) { - var tracks []*webRTCOutgoingTrack - - videoTrack, err := newWebRTCOutgoingTrackVideo(medias) - if err != nil { - return nil, err - } - - if videoTrack != nil { - tracks = append(tracks, videoTrack) - } - - audioTrack, err := newWebRTCOutgoingTrackAudio(medias) - if err != nil { - return nil, err - } - - if audioTrack != nil { - tracks = append(tracks, audioTrack) - } - - if tracks == nil { - return nil, fmt.Errorf( - "the stream doesn't contain any supported codec, which are currently H264, VP8, VP9, G711, G722, Opus") - } - - return tracks, nil -} - -func (c *webRTCConn) gatherIncomingTracks( - ctx context.Context, - pc *peerConnection, - cr *webRTCCandidateReader, - trackRecv chan trackRecvPair, -) ([]*webRTCIncomingTrack, error) { - var tracks []*webRTCIncomingTrack - - t := time.NewTimer(webrtcTrackGatherTimeout) - defer t.Stop() - - for { - select { - case <-t.C: - return tracks, nil - - case pair := <-trackRecv: - track, err := newWebRTCIncomingTrack(pair.track, pair.receiver, pc.WriteRTCP) - if err != nil { - return nil, err - } - tracks = append(tracks, track) - - if len(tracks) == 2 { - return tracks, nil - } - - case <-pc.disconnected: - return nil, fmt.Errorf("peer connection closed") - - case err := <-cr.readError: - return nil, fmt.Errorf("websocket error: %v", err) - - case <-ctx.Done(): - return nil, fmt.Errorf("terminated") - } - } -} - -func (c *webRTCConn) genICEServers() []webrtc.ICEServer { - ret := make([]webrtc.ICEServer, len(c.iceServers)) - for i, s := range c.iceServers { - parts := strings.Split(s, ":") - if len(parts) == 5 { - if parts[1] == "AUTH_SECRET" { - s := webrtc.ICEServer{ - URLs: []string{parts[0] + ":" + parts[3] + ":" + parts[4]}, - } - - randomUser := func() string { - const charset = "abcdefghijklmnopqrstuvwxyz1234567890" - b := make([]byte, 20) - for i := range b { - b[i] = charset[rand.Intn(len(charset))] - } - return string(b) - }() - - expireDate := time.Now().Add(24 * 3600 * time.Second).Unix() - s.Username = strconv.FormatInt(expireDate, 10) + ":" + randomUser - - h := hmac.New(sha1.New, []byte(parts[2])) - h.Write([]byte(s.Username)) - s.Credential = base64.StdEncoding.EncodeToString(h.Sum(nil)) - - ret[i] = s - } else { - ret[i] = webrtc.ICEServer{ - URLs: []string{parts[0] + ":" + parts[3] + ":" + parts[4]}, - Username: parts[1], - Credential: parts[2], - } - } - } else { - ret[i] = webrtc.ICEServer{ - URLs: []string{s}, - } - } - } - return ret -} - -func (c *webRTCConn) establishConnection( - ctx context.Context, - pc *peerConnection, - cr *webRTCCandidateReader, -) error { - t := time.NewTimer(webrtcHandshakeTimeout) - defer t.Stop() - -outer: - for { - select { - case candidate := <-pc.localCandidateRecv: - c.Log(logger.Debug, "local candidate: %+v", candidate.Candidate) - err := c.ws.WriteJSON(candidate) - if err != nil { - return err - } - - case candidate := <-cr.remoteCandidate: - c.Log(logger.Debug, "remote candidate: %+v", candidate.Candidate) - err := pc.AddICECandidate(*candidate) - if err != nil { - return err - } - - case err := <-cr.readError: - return err - - case <-t.C: - return fmt.Errorf("deadline exceeded") - - case <-pc.connected: - break outer - - case <-ctx.Done(): - return fmt.Errorf("terminated") - } - } - - // Keep WebSocket connection open and use it to notify shutdowns. - // This is because pion/webrtc doesn't write yet a WebRTC shutdown - // message to clients (like a DTLS close alert or a RTCP BYE), - // therefore browsers do not properly detect shutdowns and do not - // attempt to restart the connection immediately. - - c.mutex.Lock() - c.pc = pc - c.mutex.Unlock() - - c.Log(logger.Info, "peer connection established, local candidate: %v, remote candidate: %v", - pc.localCandidate(), pc.remoteCandidate()) - - return nil -} - -func (c *webRTCConn) writeICEServers() error { - return c.ws.WriteJSON(c.genICEServers()) -} - -func (c *webRTCConn) readOffer() (*webrtc.SessionDescription, error) { - var offer webrtc.SessionDescription - err := c.ws.ReadJSON(&offer) - if err != nil { - return nil, err - } - - if offer.Type != webrtc.SDPTypeOffer { - return nil, fmt.Errorf("received SDP is not an offer") - } - - return &offer, nil -} - -func (c *webRTCConn) writeOffer(offer *webrtc.SessionDescription) error { - return c.ws.WriteJSON(offer) -} - -func (c *webRTCConn) readAnswer() (*webrtc.SessionDescription, error) { - var answer webrtc.SessionDescription - err := c.ws.ReadJSON(&answer) - if err != nil { - return nil, err - } - - if answer.Type != webrtc.SDPTypeAnswer { - return nil, fmt.Errorf("received SDP is not an offer") - } - - return &answer, nil -} - -func (c *webRTCConn) writeAnswer(answer *webrtc.SessionDescription) error { - return c.ws.WriteJSON(answer) -} - -// apiSourceDescribe implements sourceStaticImpl. -func (c *webRTCConn) apiSourceDescribe() pathAPISourceOrReader { - return pathAPISourceOrReader{ - Type: "webRTCConn", - ID: c.uuid.String(), - } -} - -// apiReaderDescribe implements reader. -func (c *webRTCConn) apiReaderDescribe() pathAPISourceOrReader { - return c.apiSourceDescribe() -} diff --git a/internal/core/webrtc_http_server.go b/internal/core/webrtc_http_server.go new file mode 100644 index 00000000..9a8939b4 --- /dev/null +++ b/internal/core/webrtc_http_server.go @@ -0,0 +1,364 @@ +package core + +import ( + "context" + "crypto/tls" + _ "embed" + "fmt" + "io" + "log" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v3" + + "github.com/aler9/mediamtx/internal/conf" + "github.com/aler9/mediamtx/internal/logger" +) + +//go:embed webrtc_publish_index.html +var webrtcPublishIndex []byte + +//go:embed webrtc_read_index.html +var webrtcReadIndex []byte + +func unmarshalICEFragment(buf []byte) ([]*webrtc.ICECandidateInit, error) { + buf = append([]byte("v=0\r\no=- 0 0 IN IP4 0.0.0.0\r\ns=-\r\nt=0 0\r\n"), buf...) + + var sdp sdp.SessionDescription + err := sdp.Unmarshal(buf) + if err != nil { + return nil, err + } + + usernameFragment, ok := sdp.Attribute("ice-ufrag") + if !ok { + return nil, fmt.Errorf("ice-ufrag attribute is missing") + } + + var ret []*webrtc.ICECandidateInit + + for _, media := range sdp.MediaDescriptions { + mid, ok := media.Attribute("mid") + if !ok { + return nil, fmt.Errorf("mid attribute is missing") + } + + tmp, err := strconv.ParseUint(mid, 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid mid attribute") + } + midNum := uint16(tmp) + + for _, attr := range media.Attributes { + if attr.Key == "candidate" { + ret = append(ret, &webrtc.ICECandidateInit{ + Candidate: attr.Value, + SDPMid: &mid, + SDPMLineIndex: &midNum, + UsernameFragment: &usernameFragment, + }) + } + } + } + + return ret, nil +} + +func marshalICEFragment(offer *webrtc.SessionDescription, candidates []*webrtc.ICECandidateInit) ([]byte, error) { + var sdp sdp.SessionDescription + err := sdp.Unmarshal([]byte(offer.SDP)) + if err != nil || len(sdp.MediaDescriptions) == 0 { + return nil, err + } + + firstMedia := sdp.MediaDescriptions[0] + iceUfrag, _ := firstMedia.Attribute("ice-ufrag") + icePwd, _ := firstMedia.Attribute("ice-pwd") + + candidatesByMedia := make(map[uint16][]*webrtc.ICECandidateInit) + for _, candidate := range candidates { + mid := *candidate.SDPMLineIndex + candidatesByMedia[mid] = append(candidatesByMedia[mid], candidate) + } + + frag := "a=ice-ufrag:" + iceUfrag + "\r\n" + + "a=ice-pwd:" + icePwd + "\r\n" + + for mid, media := range sdp.MediaDescriptions { + cbm, ok := candidatesByMedia[uint16(mid)] + if ok { + frag += "m=" + media.MediaName.String() + "\r\n" + + "a=mid:" + strconv.FormatUint(uint64(mid), 10) + "\r\n" + + for _, candidate := range cbm { + frag += "a=" + candidate.Candidate + "\r\n" + } + } + } + + return []byte(frag), nil +} + +type webRTCHTTPServerParent interface { + logger.Writer + genICEServers() []webrtc.ICEServer + sessionNew(req webRTCSessionNewReq) webRTCNewSessionRes + sessionAddCandidates(req webRTCSessionAddCandidatesReq) webRTCSessionAddCandidatesRes +} + +type webRTCHTTPServer struct { + allowOrigin string + pathManager *pathManager + parent webRTCHTTPServerParent + + ln net.Listener + inner *http.Server +} + +func newWebRTCHTTPServer( + address string, + encryption bool, + serverKey string, + serverCert string, + allowOrigin string, + trustedProxies conf.IPsOrCIDRs, + readTimeout conf.StringDuration, + pathManager *pathManager, + parent webRTCHTTPServerParent, +) (*webRTCHTTPServer, error) { + ln, err := net.Listen(restrictNetwork("tcp", address)) + if err != nil { + return nil, err + } + + var tlsConfig *tls.Config + if encryption { + crt, err := tls.LoadX509KeyPair(serverCert, serverKey) + if err != nil { + ln.Close() + return nil, err + } + + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{crt}, + } + } + + s := &webRTCHTTPServer{ + allowOrigin: allowOrigin, + pathManager: pathManager, + parent: parent, + ln: ln, + } + + router := gin.New() + httpSetTrustedProxies(router, trustedProxies) + router.NoRoute(httpLoggerMiddleware(s), httpServerHeaderMiddleware, s.onRequest) + + s.inner = &http.Server{ + Handler: router, + TLSConfig: tlsConfig, + ReadHeaderTimeout: time.Duration(readTimeout), + ErrorLog: log.New(&nilWriter{}, "", 0), + } + + if tlsConfig != nil { + go s.inner.ServeTLS(s.ln, "", "") + } else { + go s.inner.Serve(s.ln) + } + + return s, nil +} + +func (s *webRTCHTTPServer) Log(level logger.Level, format string, args ...interface{}) { + s.parent.Log(level, format, args...) +} + +func (s *webRTCHTTPServer) close() { + s.inner.Shutdown(context.Background()) + s.ln.Close() // in case Shutdown() is called before Serve() +} + +func (s *webRTCHTTPServer) onRequest(ctx *gin.Context) { + ctx.Writer.Header().Set("Access-Control-Allow-Origin", s.allowOrigin) + ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + + // remove leading prefix + pa := ctx.Request.URL.Path[1:] + + if !strings.HasSuffix(pa, "/whip") && !strings.HasSuffix(pa, "/whep") { + switch ctx.Request.Method { + case http.MethodGet: + + case http.MethodOptions: + ctx.Writer.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + ctx.Writer.Header().Set("Access-Control-Allow-Headers", ctx.Request.Header.Get("Access-Control-Request-Headers")) + ctx.Writer.WriteHeader(http.StatusOK) + return + + default: + return + } + } + + var dir string + var fname string + var publish bool + + switch { + case pa == "favicon.ico": + return + + case strings.HasSuffix(pa, "/publish"): + dir, fname = pa[:len(pa)-len("/publish")], "publish" + publish = true + + case strings.HasSuffix(pa, "/whip"): + dir, fname = pa[:len(pa)-len("/whip")], "whip" + publish = true + + case strings.HasSuffix(pa, "/whep"): + dir, fname = pa[:len(pa)-len("/whep")], "whep" + publish = false + + default: + dir, fname = pa, "" + publish = false + + if !strings.HasSuffix(dir, "/") { + ctx.Writer.Header().Set("Location", "/"+dir+"/") + ctx.Writer.WriteHeader(http.StatusMovedPermanently) + return + } + } + + dir = strings.TrimSuffix(dir, "/") + if dir == "" { + return + } + + user, pass, hasCredentials := ctx.Request.BasicAuth() + + res := s.pathManager.getPathConf(pathGetPathConfReq{ + name: dir, + publish: publish, + credentials: authCredentials{ + query: ctx.Request.URL.RawQuery, + ip: net.ParseIP(ctx.ClientIP()), + user: user, + pass: pass, + proto: authProtocolWebRTC, + }, + }) + if res.err != nil { + if terr, ok := res.err.(pathErrAuth); ok { + if !hasCredentials { + ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) + ctx.Writer.WriteHeader(http.StatusUnauthorized) + return + } + + s.Log(logger.Info, "authentication error: %v", terr.wrapped) + ctx.Writer.WriteHeader(http.StatusUnauthorized) + return + } + + ctx.Writer.WriteHeader(http.StatusNotFound) + return + } + + switch fname { + case "": + ctx.Writer.Header().Set("Content-Type", "text/html") + ctx.Writer.WriteHeader(http.StatusOK) + ctx.Writer.Write(webrtcReadIndex) + + case "publish": + ctx.Writer.Header().Set("Content-Type", "text/html") + ctx.Writer.WriteHeader(http.StatusOK) + ctx.Writer.Write(webrtcPublishIndex) + + case "whip", "whep": + switch ctx.Request.Method { + case http.MethodOptions: + ctx.Writer.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + ctx.Writer.Header().Set("Access-Control-Allow-Headers", ctx.Request.Header.Get("Access-Control-Request-Headers")) + ctx.Writer.Header()["Link"] = iceServersToLinkHeader(s.parent.genICEServers()) + ctx.Writer.WriteHeader(http.StatusOK) + + case http.MethodPost: + if ctx.Request.Header.Get("Content-Type") != "application/sdp" { + ctx.Writer.WriteHeader(http.StatusBadRequest) + return + } + + offer, err := io.ReadAll(ctx.Request.Body) + if err != nil { + return + } + + res := s.parent.sessionNew(webRTCSessionNewReq{ + pathName: dir, + remoteAddr: ctx.ClientIP(), + offer: offer, + publish: (fname == "whip"), + videoCodec: ctx.Query("video_codec"), + audioCodec: ctx.Query("audio_codec"), + videoBitrate: ctx.Query("video_bitrate"), + }) + if res.err != nil { + ctx.Writer.WriteHeader(http.StatusInternalServerError) + return + } + + ctx.Writer.Header().Set("Content-Type", "application/sdp") + ctx.Writer.Header().Set("E-Tag", res.sx.secret.String()) + ctx.Writer.Header().Set("Accept-Patch", "application/trickle-ice-sdpfrag") + ctx.Writer.Header()["Link"] = iceServersToLinkHeader(s.parent.genICEServers()) + ctx.Writer.WriteHeader(http.StatusCreated) + ctx.Writer.Write(res.answer) + + case http.MethodPatch: + secret, err := uuid.Parse(ctx.Request.Header.Get("If-Match")) + if err != nil { + ctx.Writer.WriteHeader(http.StatusBadRequest) + return + } + + if ctx.Request.Header.Get("Content-Type") != "application/trickle-ice-sdpfrag" { + ctx.Writer.WriteHeader(http.StatusBadRequest) + return + } + + byts, err := io.ReadAll(ctx.Request.Body) + if err != nil { + return + } + + candidates, err := unmarshalICEFragment(byts) + if err != nil { + ctx.Writer.WriteHeader(http.StatusBadRequest) + return + } + + res := s.parent.sessionAddCandidates(webRTCSessionAddCandidatesReq{ + secret: secret, + candidates: candidates, + }) + if res.err != nil { + ctx.Writer.WriteHeader(http.StatusBadRequest) + return + } + + ctx.Writer.WriteHeader(http.StatusNoContent) + } + } +} diff --git a/internal/core/webrtc_incoming_track.go b/internal/core/webrtc_incoming_track.go index 818bcc9f..d994fa87 100644 --- a/internal/core/webrtc_incoming_track.go +++ b/internal/core/webrtc_incoming_track.go @@ -117,6 +117,7 @@ func (t *webRTCIncomingTrack) start(stream *stream) { if t.mediaType == media.TypeVideo { go func() { keyframeTicker := time.NewTicker(keyFrameInterval) + defer keyframeTicker.Stop() for range keyframeTicker.C { err := t.writeRTCP([]rtcp.Packet{ diff --git a/internal/core/webrtc_manager.go b/internal/core/webrtc_manager.go new file mode 100644 index 00000000..d66b4206 --- /dev/null +++ b/internal/core/webrtc_manager.go @@ -0,0 +1,508 @@ +package core + +import ( + "context" + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "math/rand" + "net" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/pion/ice/v2" + "github.com/pion/webrtc/v3" + + "github.com/aler9/mediamtx/internal/conf" + "github.com/aler9/mediamtx/internal/logger" +) + +func iceServersToLinkHeader(iceServers []webrtc.ICEServer) []string { + ret := make([]string, len(iceServers)) + + for i, server := range iceServers { + link := "<" + server.URLs[0] + ">; rel=\"ice-server\"" + if server.Username != "" { + link += "; username=\"" + server.Username + "\"" + + "; credential=\"" + server.Credential.(string) + "\"; credential-type=\"password\"" + } + ret[i] = link + } + + return ret +} + +var reLink = regexp.MustCompile(`^<(.+?)>; rel="ice-server"(; username="(.+?)"` + + `; credential="(.+?)"; credential-type="password")?`) + +func linkHeaderToIceServers(link []string) []webrtc.ICEServer { + var ret []webrtc.ICEServer + + for _, li := range link { + m := reLink.FindStringSubmatch(li) + if m != nil { + s := webrtc.ICEServer{ + URLs: []string{m[1]}, + } + + if m[3] != "" { + s.Username = m[3] + s.Credential = m[4] + s.CredentialType = webrtc.ICECredentialTypePassword + } + + ret = append(ret, s) + } + } + + return ret +} + +type webRTCManagerAPISessionsListItem struct { + Created time.Time `json:"created"` + RemoteAddr string `json:"remoteAddr"` + PeerConnectionEstablished bool `json:"peerConnectionEstablished"` + LocalCandidate string `json:"localCandidate"` + RemoteCandidate string `json:"remoteCandidate"` + State string `json:"state"` + BytesReceived uint64 `json:"bytesReceived"` + BytesSent uint64 `json:"bytesSent"` +} + +type webRTCManagerAPISessionsListData struct { + Items map[string]webRTCManagerAPISessionsListItem `json:"items"` +} + +type webRTCManagerAPISessionsListRes struct { + data *webRTCManagerAPISessionsListData + err error +} + +type webRTCManagerAPISessionsListReq struct { + res chan webRTCManagerAPISessionsListRes +} + +type webRTCManagerAPISessionsKickRes struct { + err error +} + +type webRTCManagerAPISessionsKickReq struct { + uuid uuid.UUID + res chan webRTCManagerAPISessionsKickRes +} + +type webRTCNewSessionRes struct { + sx *webRTCSession + answer []byte + err error +} + +type webRTCSessionNewReq struct { + pathName string + remoteAddr string + offer []byte + publish bool + videoCodec string + audioCodec string + videoBitrate string + res chan webRTCNewSessionRes +} + +type webRTCSessionAddCandidatesRes struct { + sx *webRTCSession + err error +} + +type webRTCSessionAddCandidatesReq struct { + secret uuid.UUID + candidates []*webrtc.ICECandidateInit + res chan webRTCSessionAddCandidatesRes +} + +type webRTCManagerParent interface { + logger.Writer +} + +type webRTCManager struct { + allowOrigin string + trustedProxies conf.IPsOrCIDRs + iceServers []string + readBufferCount int + pathManager *pathManager + metrics *metrics + parent webRTCManagerParent + + ctx context.Context + ctxCancel func() + httpServer *webRTCHTTPServer + udpMuxLn net.PacketConn + tcpMuxLn net.Listener + sessions map[*webRTCSession]struct{} + sessionsBySecret map[uuid.UUID]*webRTCSession + iceHostNAT1To1IPs []string + iceUDPMux ice.UDPMux + iceTCPMux ice.TCPMux + + // in + chSessionNew chan webRTCSessionNewReq + chSessionClose chan *webRTCSession + chSessionAddCandidates chan webRTCSessionAddCandidatesReq + chAPISessionsList chan webRTCManagerAPISessionsListReq + chAPIConnsKick chan webRTCManagerAPISessionsKickReq + + // out + done chan struct{} +} + +func newWebRTCManager( + parentCtx context.Context, + address string, + encryption bool, + serverKey string, + serverCert string, + allowOrigin string, + trustedProxies conf.IPsOrCIDRs, + iceServers []string, + readTimeout conf.StringDuration, + readBufferCount int, + pathManager *pathManager, + metrics *metrics, + parent webRTCManagerParent, + iceHostNAT1To1IPs []string, + iceUDPMuxAddress string, + iceTCPMuxAddress string, +) (*webRTCManager, error) { + ctx, ctxCancel := context.WithCancel(parentCtx) + + m := &webRTCManager{ + allowOrigin: allowOrigin, + trustedProxies: trustedProxies, + iceServers: iceServers, + readBufferCount: readBufferCount, + pathManager: pathManager, + metrics: metrics, + parent: parent, + ctx: ctx, + ctxCancel: ctxCancel, + iceHostNAT1To1IPs: iceHostNAT1To1IPs, + sessions: make(map[*webRTCSession]struct{}), + sessionsBySecret: make(map[uuid.UUID]*webRTCSession), + chSessionNew: make(chan webRTCSessionNewReq), + chSessionClose: make(chan *webRTCSession), + chSessionAddCandidates: make(chan webRTCSessionAddCandidatesReq), + chAPISessionsList: make(chan webRTCManagerAPISessionsListReq), + chAPIConnsKick: make(chan webRTCManagerAPISessionsKickReq), + done: make(chan struct{}), + } + + var err error + m.httpServer, err = newWebRTCHTTPServer( + address, + encryption, + serverKey, + serverCert, + allowOrigin, + trustedProxies, + readTimeout, + pathManager, + m, + ) + if err != nil { + ctxCancel() + return nil, err + } + + if iceUDPMuxAddress != "" { + m.udpMuxLn, err = net.ListenPacket(restrictNetwork("udp", iceUDPMuxAddress)) + if err != nil { + m.httpServer.close() + ctxCancel() + return nil, err + } + m.iceUDPMux = webrtc.NewICEUDPMux(nil, m.udpMuxLn) + } + + if iceTCPMuxAddress != "" { + m.tcpMuxLn, err = net.Listen(restrictNetwork("tcp", iceTCPMuxAddress)) + if err != nil { + m.udpMuxLn.Close() + m.httpServer.close() + ctxCancel() + return nil, err + } + m.iceTCPMux = webrtc.NewICETCPMux(nil, m.tcpMuxLn, 8) + } + + str := "listener opened on " + address + " (HTTP)" + if m.udpMuxLn != nil { + str += ", " + iceUDPMuxAddress + " (ICE/UDP)" + } + if m.tcpMuxLn != nil { + str += ", " + iceTCPMuxAddress + " (ICE/TCP)" + } + m.Log(logger.Info, str) + + if m.metrics != nil { + m.metrics.webRTCManagerSet(m) + } + + go m.run() + + return m, nil +} + +// Log is the main logging function. +func (m *webRTCManager) Log(level logger.Level, format string, args ...interface{}) { + m.parent.Log(level, "[WebRTC] "+format, append([]interface{}{}, args...)...) +} + +func (m *webRTCManager) close() { + m.Log(logger.Info, "listener is closing") + m.ctxCancel() + <-m.done +} + +func (m *webRTCManager) run() { + defer close(m.done) + + var wg sync.WaitGroup + +outer: + for { + select { + case req := <-m.chSessionNew: + sx := newWebRTCSession( + m.ctx, + m.readBufferCount, + req, + &wg, + m.iceHostNAT1To1IPs, + m.iceUDPMux, + m.iceTCPMux, + m.pathManager, + m, + ) + m.sessions[sx] = struct{}{} + m.sessionsBySecret[sx.secret] = sx + req.res <- webRTCNewSessionRes{sx: sx} + + case sx := <-m.chSessionClose: + delete(m.sessions, sx) + delete(m.sessionsBySecret, sx.secret) + + case req := <-m.chSessionAddCandidates: + sx, ok := m.sessionsBySecret[req.secret] + if !ok { + req.res <- webRTCSessionAddCandidatesRes{err: fmt.Errorf("session not found")} + continue + } + + req.res <- webRTCSessionAddCandidatesRes{sx: sx} + + case req := <-m.chAPISessionsList: + data := &webRTCManagerAPISessionsListData{ + Items: make(map[string]webRTCManagerAPISessionsListItem), + } + + for sx := range m.sessions { + peerConnectionEstablished := false + localCandidate := "" + remoteCandidate := "" + bytesReceived := uint64(0) + bytesSent := uint64(0) + + pc := sx.safePC() + if pc != nil { + peerConnectionEstablished = true + localCandidate = pc.localCandidate() + remoteCandidate = pc.remoteCandidate() + bytesReceived = pc.bytesReceived() + bytesSent = pc.bytesSent() + } + + data.Items[sx.uuid.String()] = webRTCManagerAPISessionsListItem{ + Created: sx.created, + RemoteAddr: sx.req.remoteAddr, + PeerConnectionEstablished: peerConnectionEstablished, + LocalCandidate: localCandidate, + RemoteCandidate: remoteCandidate, + State: func() string { + if sx.req.publish { + return "publish" + } + return "read" + }(), + BytesReceived: bytesReceived, + BytesSent: bytesSent, + } + } + + req.res <- webRTCManagerAPISessionsListRes{data: data} + + case req := <-m.chAPIConnsKick: + sx := m.findSessionByUUID(req.uuid) + if sx == nil { + req.res <- webRTCManagerAPISessionsKickRes{fmt.Errorf("not found")} + continue + } + + delete(m.sessions, sx) + delete(m.sessionsBySecret, sx.secret) + sx.close() + req.res <- webRTCManagerAPISessionsKickRes{} + + case <-m.ctx.Done(): + break outer + } + } + + m.ctxCancel() + + wg.Wait() + + m.httpServer.close() + + if m.udpMuxLn != nil { + m.udpMuxLn.Close() + } + + if m.tcpMuxLn != nil { + m.tcpMuxLn.Close() + } +} + +func (m *webRTCManager) findSessionByUUID(uuid uuid.UUID) *webRTCSession { + for sx := range m.sessions { + if sx.uuid == uuid { + return sx + } + } + return nil +} + +func (m *webRTCManager) genICEServers() []webrtc.ICEServer { + ret := make([]webrtc.ICEServer, len(m.iceServers)) + for i, s := range m.iceServers { + parts := strings.Split(s, ":") + if len(parts) == 5 { + if parts[1] == "AUTH_SECRET" { + s := webrtc.ICEServer{ + URLs: []string{parts[0] + ":" + parts[3] + ":" + parts[4]}, + } + + randomUser := func() string { + const charset = "abcdefghijklmnopqrstuvwxyz1234567890" + b := make([]byte, 20) + for i := range b { + b[i] = charset[rand.Intn(len(charset))] + } + return string(b) + }() + + expireDate := time.Now().Add(24 * 3600 * time.Second).Unix() + s.Username = strconv.FormatInt(expireDate, 10) + ":" + randomUser + + h := hmac.New(sha1.New, []byte(parts[2])) + h.Write([]byte(s.Username)) + s.Credential = base64.StdEncoding.EncodeToString(h.Sum(nil)) + + ret[i] = s + } else { + ret[i] = webrtc.ICEServer{ + URLs: []string{parts[0] + ":" + parts[3] + ":" + parts[4]}, + Username: parts[1], + Credential: parts[2], + } + } + } else { + ret[i] = webrtc.ICEServer{ + URLs: []string{s}, + } + } + } + return ret +} + +// sessionNew is called by webRTCHTTPServer. +func (m *webRTCManager) sessionNew(req webRTCSessionNewReq) webRTCNewSessionRes { + req.res = make(chan webRTCNewSessionRes) + + select { + case m.chSessionNew <- req: + res1 := <-req.res + + select { + case res2 := <-req.res: + return res2 + + case <-res1.sx.ctx.Done(): + return webRTCNewSessionRes{err: fmt.Errorf("terminated")} + } + + case <-m.ctx.Done(): + return webRTCNewSessionRes{err: fmt.Errorf("terminated")} + } +} + +// sessionClose is called by webRTCSession. +func (m *webRTCManager) sessionClose(sx *webRTCSession) { + select { + case m.chSessionClose <- sx: + case <-m.ctx.Done(): + } +} + +// sessionAddCandidates is called by webRTCHTTPServer. +func (m *webRTCManager) sessionAddCandidates( + req webRTCSessionAddCandidatesReq, +) webRTCSessionAddCandidatesRes { + req.res = make(chan webRTCSessionAddCandidatesRes) + select { + case m.chSessionAddCandidates <- req: + res1 := <-req.res + if res1.err != nil { + return res1 + } + + return res1.sx.addRemoteCandidates(req) + + case <-m.ctx.Done(): + return webRTCSessionAddCandidatesRes{err: fmt.Errorf("terminated")} + } +} + +// apiSessionsList is called by api. +func (m *webRTCManager) apiSessionsList() webRTCManagerAPISessionsListRes { + req := webRTCManagerAPISessionsListReq{ + res: make(chan webRTCManagerAPISessionsListRes), + } + + select { + case m.chAPISessionsList <- req: + return <-req.res + + case <-m.ctx.Done(): + return webRTCManagerAPISessionsListRes{err: fmt.Errorf("terminated")} + } +} + +// apiSessionsKick is called by api. +func (m *webRTCManager) apiSessionsKick(uuid uuid.UUID) webRTCManagerAPISessionsKickRes { + req := webRTCManagerAPISessionsKickReq{ + uuid: uuid, + res: make(chan webRTCManagerAPISessionsKickRes), + } + + select { + case m.chAPIConnsKick <- req: + return <-req.res + + case <-m.ctx.Done(): + return webRTCManagerAPISessionsKickRes{err: fmt.Errorf("terminated")} + } +} diff --git a/internal/core/webrtc_manager_test.go b/internal/core/webrtc_manager_test.go new file mode 100644 index 00000000..9339f811 --- /dev/null +++ b/internal/core/webrtc_manager_test.go @@ -0,0 +1,353 @@ +package core + +import ( + "bytes" + "encoding/json" + "net/http" + "sync" + "testing" + "time" + + "github.com/bluenviron/gortsplib/v3" + "github.com/bluenviron/gortsplib/v3/pkg/formats" + "github.com/bluenviron/gortsplib/v3/pkg/media" + "github.com/bluenviron/gortsplib/v3/pkg/url" + "github.com/pion/rtp" + "github.com/pion/webrtc/v3" + "github.com/stretchr/testify/require" +) + +func whipGetICEServers(t *testing.T, ur string) []webrtc.ICEServer { + req, err := http.NewRequest("OPTIONS", ur, nil) + require.NoError(t, err) + + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusOK, res.StatusCode) + + link, ok := res.Header["Link"] + require.Equal(t, true, ok) + servers := linkHeaderToIceServers(link) + require.NotEqual(t, 0, len(servers)) + + return servers +} + +func whipPostOffer(t *testing.T, ur string, offer *webrtc.SessionDescription) (*webrtc.SessionDescription, string) { + enc, err := json.Marshal(offer) + require.NoError(t, err) + + req, err := http.NewRequest("POST", ur, bytes.NewReader(enc)) + require.NoError(t, err) + + req.Header.Set("Content-Type", "application/sdp") + + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusCreated, res.StatusCode) + + link, ok := res.Header["Link"] + require.Equal(t, true, ok) + servers := linkHeaderToIceServers(link) + require.NotEqual(t, 0, len(servers)) + + require.Equal(t, "application/sdp", res.Header.Get("Content-Type")) + etag := res.Header.Get("E-Tag") + require.NotEqual(t, 0, len(etag)) + require.Equal(t, "application/trickle-ice-sdpfrag", res.Header.Get("Accept-Patch")) + + var answer webrtc.SessionDescription + err = json.NewDecoder(res.Body).Decode(&answer) + require.NoError(t, err) + + return &answer, etag +} + +func whipPostCandidate(t *testing.T, ur string, offer *webrtc.SessionDescription, + etag string, candidate *webrtc.ICECandidateInit, +) { + frag, err := marshalICEFragment(offer, []*webrtc.ICECandidateInit{candidate}) + require.NoError(t, err) + + req, err := http.NewRequest("PATCH", ur, bytes.NewReader(frag)) + require.NoError(t, err) + + req.Header.Set("Content-Type", "application/trickle-ice-sdpfrag") + req.Header.Set("If-Match", etag) + + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusNoContent, res.StatusCode) +} + +type webRTCTestClient struct { + pc *webrtc.PeerConnection + outgoingTrack1 *webrtc.TrackLocalStaticRTP + outgoingTrack2 *webrtc.TrackLocalStaticRTP + incomingTrack chan *webrtc.TrackRemote + closed chan struct{} +} + +func newWebRTCTestClient(t *testing.T, ur string, publish bool) *webRTCTestClient { + iceServers := whipGetICEServers(t, ur) + + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{ + ICEServers: iceServers, + }) + require.NoError(t, err) + + connected := make(chan struct{}) + closed := make(chan struct{}) + var stateChangeMutex sync.Mutex + + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + stateChangeMutex.Lock() + defer stateChangeMutex.Unlock() + + select { + case <-closed: + return + default: + } + + switch state { + case webrtc.PeerConnectionStateConnected: + close(connected) + + case webrtc.PeerConnectionStateClosed: + close(closed) + } + }) + + var outgoingTrack1 *webrtc.TrackLocalStaticRTP + var outgoingTrack2 *webrtc.TrackLocalStaticRTP + var incomingTrack chan *webrtc.TrackRemote + + if publish { + var err error + outgoingTrack1, err = webrtc.NewTrackLocalStaticRTP( + webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeVP8, + ClockRate: 90000, + }, + "vp8", + webrtcStreamID, + ) + require.NoError(t, err) + + _, err = pc.AddTrack(outgoingTrack1) + require.NoError(t, err) + + outgoingTrack2, err = webrtc.NewTrackLocalStaticRTP( + webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: 48000, + Channels: 2, + }, + "opus", + webrtcStreamID, + ) + require.NoError(t, err) + + _, err = pc.AddTrack(outgoingTrack2) + require.NoError(t, err) + } else { + incomingTrack = make(chan *webrtc.TrackRemote, 1) + pc.OnTrack(func(trak *webrtc.TrackRemote, recv *webrtc.RTPReceiver) { + incomingTrack <- trak + }) + + _, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo) + require.NoError(t, err) + } + + offer, err := pc.CreateOffer(nil) + require.NoError(t, err) + + answer, etag := whipPostOffer(t, ur, &offer) + + // test adding additional candidates, even if it is not mandatory here + gatheringDone := make(chan struct{}) + pc.OnICECandidate(func(i *webrtc.ICECandidate) { + if i != nil { + c := i.ToJSON() + whipPostCandidate(t, ur, &offer, etag, &c) + } else { + close(gatheringDone) + } + }) + + err = pc.SetLocalDescription(offer) + require.NoError(t, err) + + err = pc.SetRemoteDescription(*answer) + require.NoError(t, err) + + <-gatheringDone + <-connected + + if publish { + time.Sleep(200 * time.Millisecond) + + err := outgoingTrack1.WriteRTP(&rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 123, + Timestamp: 45343, + SSRC: 563423, + }, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }) + require.NoError(t, err) + + err = outgoingTrack2.WriteRTP(&rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 1123, + Timestamp: 45343, + SSRC: 563423, + }, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }) + require.NoError(t, err) + + time.Sleep(200 * time.Millisecond) + } + + return &webRTCTestClient{ + pc: pc, + outgoingTrack1: outgoingTrack1, + outgoingTrack2: outgoingTrack2, + incomingTrack: incomingTrack, + closed: closed, + } +} + +func (c *webRTCTestClient) close() { + c.pc.Close() + <-c.closed +} + +func TestWebRTCRead(t *testing.T) { + p, ok := newInstance("paths:\n" + + " all:\n") + require.Equal(t, true, ok) + defer p.Close() + + medi := &media.Media{ + Type: media.TypeVideo, + Formats: []formats.Format{&formats.H264{ + PayloadTyp: 96, + PacketizationMode: 1, + }}, + } + + v := gortsplib.TransportTCP + source := gortsplib.Client{ + Transport: &v, + } + err := source.StartRecording("rtsp://localhost:8554/stream", media.Medias{medi}) + require.NoError(t, err) + defer source.Close() + + c := newWebRTCTestClient(t, "http://localhost:8889/stream/whep", false) + defer c.close() + + time.Sleep(500 * time.Millisecond) + + source.WritePacketRTP(medi, &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 123, + Timestamp: 45343, + SSRC: 563423, + }, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }) + + trak := <-c.incomingTrack + + pkt, _, err := trak.ReadRTP() + require.NoError(t, err) + require.Equal(t, &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 102, + SequenceNumber: pkt.SequenceNumber, + Timestamp: pkt.Timestamp, + SSRC: pkt.SSRC, + CSRC: []uint32{}, + }, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }, pkt) +} + +func TestWebRTCPublish(t *testing.T) { + p, ok := newInstance("paths:\n" + + " all:\n") + require.Equal(t, true, ok) + defer p.Close() + + s := newWebRTCTestClient(t, "http://localhost:8889/stream/whip", true) + defer s.close() + + c := gortsplib.Client{ + OnDecodeError: func(err error) { + panic(err) + }, + } + + u, err := url.Parse("rtsp://127.0.0.1:8554/stream") + require.NoError(t, err) + + err = c.Start(u.Scheme, u.Host) + require.NoError(t, err) + defer c.Close() + + medias, baseURL, _, err := c.Describe(u) + require.NoError(t, err) + + var forma *formats.VP8 + medi := medias.FindFormat(&forma) + + _, err = c.Setup(medi, baseURL, 0, 0) + require.NoError(t, err) + + received := make(chan struct{}) + + c.OnPacketRTP(medi, forma, func(pkt *rtp.Packet) { + require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, pkt.Payload) + close(received) + }) + + _, err = c.Play(nil) + require.NoError(t, err) + + err = s.outgoingTrack1.WriteRTP(&rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 124, + Timestamp: 45343, + SSRC: 563423, + }, + Payload: []byte{0x05, 0x06, 0x07, 0x08}, + }) + require.NoError(t, err) + + <-received +} diff --git a/internal/core/webrtc_outgoing_track.go b/internal/core/webrtc_outgoing_track.go index 3aecf554..1788969b 100644 --- a/internal/core/webrtc_outgoing_track.go +++ b/internal/core/webrtc_outgoing_track.go @@ -12,6 +12,7 @@ import ( "github.com/bluenviron/gortsplib/v3/pkg/formats/rtpvp8" "github.com/bluenviron/gortsplib/v3/pkg/formats/rtpvp9" "github.com/bluenviron/gortsplib/v3/pkg/media" + "github.com/bluenviron/gortsplib/v3/pkg/ringbuffer" "github.com/pion/webrtc/v3" ) @@ -20,7 +21,7 @@ type webRTCOutgoingTrack struct { media *media.Media format formats.Format track *webrtc.TrackLocalStaticRTP - cb func(formatprocessor.Unit, context.Context, chan error) + cb func(formatprocessor.Unit) error } func newWebRTCOutgoingTrackVideo(medias media.Medias) (*webRTCOutgoingTrack, error) { @@ -34,7 +35,7 @@ func newWebRTCOutgoingTrackVideo(medias media.Medias) (*webRTCOutgoingTrack, err ClockRate: 90000, }, "av1", - "rtspss", + webrtcStreamID, ) if err != nil { return nil, err @@ -50,21 +51,23 @@ func newWebRTCOutgoingTrackVideo(medias media.Medias) (*webRTCOutgoingTrack, err media: av1Media, format: av1Format, track: webRTCTrak, - cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) { + cb: func(unit formatprocessor.Unit) error { tunit := unit.(*formatprocessor.UnitAV1) if tunit.OBUs == nil { - return + return nil } packets, err := encoder.Encode(tunit.OBUs, tunit.PTS) if err != nil { - return + return nil } for _, pkt := range packets { webRTCTrak.WriteRTP(pkt) } + + return nil }, }, nil } @@ -79,7 +82,7 @@ func newWebRTCOutgoingTrackVideo(medias media.Medias) (*webRTCOutgoingTrack, err ClockRate: uint32(vp9Format.ClockRate()), }, "vp9", - "rtspss", + webrtcStreamID, ) if err != nil { return nil, err @@ -95,21 +98,23 @@ func newWebRTCOutgoingTrackVideo(medias media.Medias) (*webRTCOutgoingTrack, err media: vp9Media, format: vp9Format, track: webRTCTrak, - cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) { + cb: func(unit formatprocessor.Unit) error { tunit := unit.(*formatprocessor.UnitVP9) if tunit.Frame == nil { - return + return nil } packets, err := encoder.Encode(tunit.Frame, tunit.PTS) if err != nil { - return + return nil } for _, pkt := range packets { webRTCTrak.WriteRTP(pkt) } + + return nil }, }, nil } @@ -124,7 +129,7 @@ func newWebRTCOutgoingTrackVideo(medias media.Medias) (*webRTCOutgoingTrack, err ClockRate: uint32(vp8Format.ClockRate()), }, "vp8", - "rtspss", + webrtcStreamID, ) if err != nil { return nil, err @@ -140,21 +145,23 @@ func newWebRTCOutgoingTrackVideo(medias media.Medias) (*webRTCOutgoingTrack, err media: vp8Media, format: vp8Format, track: webRTCTrak, - cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) { + cb: func(unit formatprocessor.Unit) error { tunit := unit.(*formatprocessor.UnitVP8) if tunit.Frame == nil { - return + return nil } packets, err := encoder.Encode(tunit.Frame, tunit.PTS) if err != nil { - return + return nil } for _, pkt := range packets { webRTCTrak.WriteRTP(pkt) } + + return nil }, }, nil } @@ -169,7 +176,7 @@ func newWebRTCOutgoingTrackVideo(medias media.Medias) (*webRTCOutgoingTrack, err ClockRate: uint32(h264Format.ClockRate()), }, "h264", - "rtspss", + webrtcStreamID, ) if err != nil { return nil, err @@ -188,11 +195,11 @@ func newWebRTCOutgoingTrackVideo(medias media.Medias) (*webRTCOutgoingTrack, err media: h264Media, format: h264Format, track: webRTCTrak, - cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) { + cb: func(unit formatprocessor.Unit) error { tunit := unit.(*formatprocessor.UnitH264) if tunit.AU == nil { - return + return nil } if !firstNALUReceived { @@ -200,23 +207,21 @@ func newWebRTCOutgoingTrackVideo(medias media.Medias) (*webRTCOutgoingTrack, err lastPTS = tunit.PTS } else { if tunit.PTS < lastPTS { - select { - case writeError <- fmt.Errorf("WebRTC doesn't support H264 streams with B-frames"): - case <-ctx.Done(): - } - return + return fmt.Errorf("WebRTC doesn't support H264 streams with B-frames") } lastPTS = tunit.PTS } packets, err := encoder.Encode(tunit.AU, tunit.PTS) if err != nil { - return + return nil } for _, pkt := range packets { webRTCTrak.WriteRTP(pkt) } + + return nil }, }, nil } @@ -233,9 +238,10 @@ func newWebRTCOutgoingTrackAudio(medias media.Medias) (*webRTCOutgoingTrack, err webrtc.RTPCodecCapability{ MimeType: webrtc.MimeTypeOpus, ClockRate: uint32(opusFormat.ClockRate()), + Channels: 2, }, "opus", - "rtspss", + webrtcStreamID, ) if err != nil { return nil, err @@ -245,10 +251,12 @@ func newWebRTCOutgoingTrackAudio(medias media.Medias) (*webRTCOutgoingTrack, err media: opusMedia, format: opusFormat, track: webRTCTrak, - cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) { + cb: func(unit formatprocessor.Unit) error { for _, pkt := range unit.GetRTPPackets() { webRTCTrak.WriteRTP(pkt) } + + return nil }, }, nil } @@ -263,7 +271,7 @@ func newWebRTCOutgoingTrackAudio(medias media.Medias) (*webRTCOutgoingTrack, err ClockRate: uint32(g722Format.ClockRate()), }, "g722", - "rtspss", + webrtcStreamID, ) if err != nil { return nil, err @@ -273,10 +281,12 @@ func newWebRTCOutgoingTrackAudio(medias media.Medias) (*webRTCOutgoingTrack, err media: g722Media, format: g722Format, track: webRTCTrak, - cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) { + cb: func(unit formatprocessor.Unit) error { for _, pkt := range unit.GetRTPPackets() { webRTCTrak.WriteRTP(pkt) } + + return nil }, }, nil } @@ -298,7 +308,7 @@ func newWebRTCOutgoingTrackAudio(medias media.Medias) (*webRTCOutgoingTrack, err ClockRate: uint32(g711Format.ClockRate()), }, "g711", - "rtspss", + webrtcStreamID, ) if err != nil { return nil, err @@ -308,10 +318,12 @@ func newWebRTCOutgoingTrackAudio(medias media.Medias) (*webRTCOutgoingTrack, err media: g711Media, format: g711Format, track: webRTCTrak, - cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) { + cb: func(unit formatprocessor.Unit) error { for _, pkt := range unit.GetRTPPackets() { webRTCTrak.WriteRTP(pkt) } + + return nil }, }, nil } @@ -319,7 +331,13 @@ func newWebRTCOutgoingTrackAudio(medias media.Medias) (*webRTCOutgoingTrack, err return nil, nil } -func (t *webRTCOutgoingTrack) start() { +func (t *webRTCOutgoingTrack) start( + ctx context.Context, + r reader, + stream *stream, + ringBuffer *ringbuffer.RingBuffer, + writeError chan error, +) { // read incoming RTCP packets to make interceptors work go func() { buf := make([]byte, 1500) @@ -330,4 +348,16 @@ func (t *webRTCOutgoingTrack) start() { } } }() + + stream.readerAdd(r, t.media, t.format, func(unit formatprocessor.Unit) { + ringBuffer.Push(func() { + err := t.cb(unit) + if err != nil { + select { + case writeError <- err: + case <-ctx.Done(): + } + } + }) + }) } diff --git a/internal/core/webrtc_pc.go b/internal/core/webrtc_pc.go index eb8bcb2b..b626be25 100644 --- a/internal/core/webrtc_pc.go +++ b/internal/core/webrtc_pc.go @@ -18,6 +18,7 @@ type peerConnection struct { connected chan struct{} disconnected chan struct{} closed chan struct{} + gatheringDone chan struct{} } func newPeerConnection( @@ -222,6 +223,7 @@ func newPeerConnection( connected: make(chan struct{}), disconnected: make(chan struct{}), closed: make(chan struct{}), + gatheringDone: make(chan struct{}), } pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { @@ -238,6 +240,9 @@ func newPeerConnection( switch state { case webrtc.PeerConnectionStateConnected: + log.Log(logger.Info, "peer connection established, local candidate: %v, remote candidate: %v", + co.localCandidate(), co.remoteCandidate()) + close(co.connected) case webrtc.PeerConnectionStateDisconnected: @@ -256,6 +261,8 @@ func newPeerConnection( case <-co.connected: case <-co.closed: } + } else { + close(co.gatheringDone) } }) diff --git a/internal/core/webrtc_publish_index.html b/internal/core/webrtc_publish_index.html index 53952f80..74921000 100644 --- a/internal/core/webrtc_publish_index.html +++ b/internal/core/webrtc_publish_index.html @@ -110,56 +110,208 @@ const setState = (newState) => { const restartPause = 2000; +const linkToIceServers = (links) => ( + links.split(', ').map((link) => { + const m = link.match(/^<(.+?)>; rel="ice-server"(; username="(.*?)"; credential="(.*?)"; credential-type="password")?/i); + const ret = { + urls: [m[1]], + }; + + if (m[3] !== undefined) { + ret.username = m[3]; + ret.credential = m[4]; + ret.credentialType = "password"; + } + + return ret; + }) +); + +const parseOffer = (offer) => { + const ret = { + iceUfrag: '', + icePwd: '', + medias: [], + }; + + for (const line of offer.split('\r\n')) { + if (line.startsWith('m=')) { + ret.medias.push(line.slice('m='.length)); + } else if (ret.iceUfrag === '' && line.startsWith('a=ice-ufrag:')) { + ret.iceUfrag = line.slice('a=ice-ufrag:'.length); + } else if (ret.icePwd === '' && line.startsWith('a=ice-pwd:')) { + ret.icePwd = line.slice('a=ice-pwd:'.length); + } + } + + return ret; +}; + +const generateSdpFragment = (offerData, candidates) => { + const candidatesByMedia = {}; + for (const candidate of candidates) { + const mid = candidate.sdpMLineIndex; + if (candidatesByMedia[mid] === undefined) { + candidatesByMedia[mid] = []; + } + candidatesByMedia[mid].push(candidate); + } + + let frag = 'a=ice-ufrag:' + offerData.iceUfrag + '\r\n' + + 'a=ice-pwd:' + offerData.icePwd + '\r\n'; + + let mid = 0; + + for (const media of offerData.medias) { + if (candidatesByMedia[mid] !== undefined) { + frag += 'm=' + media + '\r\n' + + 'a=mid:' + mid + '\r\n'; + + for (const candidate of candidatesByMedia[mid]) { + frag += 'a=' + candidate.candidate + '\r\n'; + } + } + mid++; + } + + return frag; +} + class Transmitter { constructor(stream) { this.stream = stream; - this.terminated = false; - this.ws = null; this.pc = null; this.restartTimeout = null; + this.eTag = ''; + this.queuedCandidates = []; this.start(); } - start = () => { - console.log("connecting"); + start() { + console.log("requesting ICE servers"); - const videoCodec = document.getElementById('video_codec').value; - const audioCodec = document.getElementById('audio_codec').value; - const videoBitrate = document.getElementById('video_bitrate').value; + fetch('whip', { + method: 'OPTIONS', + }) + .then((res) => this.onIceServers(res)) + .catch((err) => { + console.log('error: ' + err); + this.scheduleRestart(); + }); + } - const u = window.location.href.replace(/^http/, "ws") + '/ws' + - '?video_codec=' + videoCodec + - '&audio_codec=' + audioCodec + - '&video_bitrate=' + videoBitrate; + onIceServers(res) { + this.pc = new RTCPeerConnection({ + iceServers: linkToIceServers(res.headers.get('Link')), + }); - this.ws = new WebSocket(u); + this.pc.onicecandidate = (evt) => this.onLocalCandidate(evt); + this.pc.oniceconnectionstatechange = () => this.onConnectionState(); - this.ws.onerror = () => { - console.log("ws error"); - if (this.ws === null) { - return; - } - this.ws.close(); - this.ws = null; - }; + this.stream.getTracks().forEach((track) => { + this.pc.addTrack(track, this.stream); + }); - this.ws.onclose = () => { - console.log("ws closed"); - this.ws = null; - this.scheduleRestart(); - }; + this.pc.createOffer() + .then((desc) => { + this.offerData = parseOffer(desc.sdp); + this.pc.setLocalDescription(desc); - this.ws.onmessage = this.onIceServers; - }; + console.log("sending offer"); - scheduleRestart = () => { - if (this.terminated) { + const videoCodec = document.getElementById('video_codec').value; + const audioCodec = document.getElementById('audio_codec').value; + const videoBitrate = document.getElementById('video_bitrate').value; + + let params = '?video_codec=' + videoCodec + + '&audio_codec=' + audioCodec + + '&video_bitrate=' + videoBitrate; + + fetch('whip' + params, { + method: 'POST', + headers: { + 'Content-Type': 'application/sdp', + }, + body: JSON.stringify(desc), + }) + .then((res) => { + if (res.status !== 201) { + throw new Error('bad status code'); + } + this.eTag = res.headers.get('E-Tag'); + return res.json(); + }) + .then((answer) => this.onRemoteDescription(answer)) + .catch((err) => { + console.log('error: ' + err); + this.scheduleRestart(); + }); + }); + } + + onConnectionState() { + if (this.restartTimeout !== null) { return; } - if (this.ws !== null) { - this.ws.close(); - this.ws = null; + console.log("peer connection state:", this.pc.iceConnectionState); + + switch (this.pc.iceConnectionState) { + case "disconnected": + this.scheduleRestart(); + } + } + + onRemoteDescription(answer) { + if (this.restartTimeout !== null) { + return; + } + + this.pc.setRemoteDescription(new RTCSessionDescription(answer)); + + if (this.queuedCandidates.length !== 0) { + this.sendLocalCandidates(this.queuedCandidates); + this.queuedCandidates = []; + } + } + + onLocalCandidate(evt) { + if (this.restartTimeout !== null) { + return; + } + + if (evt.candidate !== null) { + if (this.eTag === '') { + this.queuedCandidates.push(evt.candidate); + } else { + this.sendLocalCandidates([evt.candidate]) + } + } + } + + sendLocalCandidates(candidates) { + fetch('whip', { + method: 'PATCH', + headers: { + 'Content-Type': 'application/trickle-ice-sdpfrag', + 'If-Match': this.eTag, + }, + body: generateSdpFragment(this.offerData, candidates), + }) + .then((res) => { + if (res.status !== 204) { + throw new Error('bad status code'); + } + }) + .catch((err) => { + console.log('error: ' + err); + this.scheduleRestart(); + }); + } + + scheduleRestart() { + if (this.restartTimeout !== null) { + return; } if (this.pc !== null) { @@ -171,74 +323,10 @@ class Transmitter { this.restartTimeout = null; this.start(); }, restartPause); - }; - onIceServers = (msg) => { - if (this.ws === null) { - return; - } - - this.pc = new RTCPeerConnection({ - iceServers: JSON.parse(msg.data), - }); - - this.ws.onmessage = this.onOffer; - }; - - onOffer = (msg) => { - if (this.ws === null || this.pc === null) { - return; - } - - this.stream.getTracks().forEach((track) => { - this.pc.addTrack(track, this.stream); - }); - - this.ws.onmessage = (msg) => { - if (this.pc === null) { - return; - } - this.pc.addIceCandidate(JSON.parse(msg.data)); - }; - - this.pc.onicecandidate = (evt) => { - if (this.ws === null) { - return; - } - - if (evt.candidate !== null) { - if (evt.candidate.candidate !== "") { - this.ws.send(JSON.stringify(evt.candidate)); - } - } - }; - - this.pc.oniceconnectionstatechange = () => { - if (this.pc === null) { - return; - } - - console.log("peer connection state:", this.pc.iceConnectionState); - - switch (this.pc.iceConnectionState) { - case "failed": - case "disconnected": - this.scheduleRestart(); - } - }; - - this.pc.setRemoteDescription(new RTCSessionDescription(JSON.parse(msg.data))); - - this.pc.createAnswer() - .then((desc) => { - if (this.ws === null || this.pc === null) { - return; - } - - this.pc.setLocalDescription(desc); - this.ws.send(JSON.stringify(desc)); - }); - }; + this.eTag = ''; + this.queuedCandidates = []; + } } const onTransmit = (stream) => { diff --git a/internal/core/webrtc_read_index.html b/internal/core/webrtc_read_index.html index ca715aab..3e3a8b6f 100644 --- a/internal/core/webrtc_read_index.html +++ b/internal/core/webrtc_read_index.html @@ -25,124 +25,206 @@ html, body { const restartPause = 2000; -class Receiver { +const linkToIceServers = (links) => ( + links.split(', ').map((link) => { + const m = link.match(/^<(.+?)>; rel="ice-server"(; username="(.+?)"; credential="(.+?)"; credential-type="password")?/i); + const ret = { + urls: [m[1]], + }; + + if (m[3] !== undefined) { + ret.username = m[3]; + ret.credential = m[4]; + ret.credentialType = "password"; + } + + return ret; + }) +); + +const parseOffer = (offer) => { + const ret = { + iceUfrag: '', + icePwd: '', + medias: [], + }; + + for (const line of offer.split('\r\n')) { + if (line.startsWith('m=')) { + ret.medias.push(line.slice('m='.length)); + } else if (ret.iceUfrag === '' && line.startsWith('a=ice-ufrag:')) { + ret.iceUfrag = line.slice('a=ice-ufrag:'.length); + } else if (ret.icePwd === '' && line.startsWith('a=ice-pwd:')) { + ret.icePwd = line.slice('a=ice-pwd:'.length); + } + } + + return ret; +}; + +const generateSdpFragment = (offerData, candidates) => { + const candidatesByMedia = {}; + for (const candidate of candidates) { + const mid = candidate.sdpMLineIndex; + if (candidatesByMedia[mid] === undefined) { + candidatesByMedia[mid] = []; + } + candidatesByMedia[mid].push(candidate); + } + + let frag = 'a=ice-ufrag:' + offerData.iceUfrag + '\r\n' + + 'a=ice-pwd:' + offerData.icePwd + '\r\n'; + + let mid = 0; + + for (const media of offerData.medias) { + if (candidatesByMedia[mid] !== undefined) { + frag += 'm=' + media + '\r\n' + + 'a=mid:' + mid + '\r\n'; + + for (const candidate of candidatesByMedia[mid]) { + frag += 'a=' + candidate.candidate + '\r\n'; + } + } + mid++; + } + + return frag; +} + +class WHEPClient { constructor() { - this.terminated = false; - this.ws = null; this.pc = null; this.restartTimeout = null; + this.eTag = ''; + this.queuedCandidates = []; this.start(); } start() { - console.log("connecting"); + console.log("requesting ICE servers"); - this.ws = new WebSocket(window.location.href.replace(/^http/, "ws") + 'ws'); - - this.ws.onerror = () => { - console.log("ws error"); - if (this.ws === null) { - return; - } - this.ws.close(); - this.ws = null; - }; - - this.ws.onclose = () => { - console.log("ws closed"); - this.ws = null; - this.scheduleRestart(); - }; - - this.ws.onmessage = (msg) => this.onIceServers(msg); + fetch('whep', { + method: 'OPTIONS', + }) + .then((res) => this.onIceServers(res)) + .catch((err) => { + console.log('error: ' + err); + this.scheduleRestart(); + }); } - onIceServers(msg) { - if (this.ws === null) { - return; - } - + onIceServers(res) { this.pc = new RTCPeerConnection({ - iceServers: JSON.parse(msg.data), + iceServers: linkToIceServers(res.headers.get('Link')), }); - this.ws.onmessage = (msg) => this.onRemoteDescription(msg); - this.pc.onicecandidate = (evt) => this.onIceCandidate(evt); - - this.pc.oniceconnectionstatechange = () => { - if (this.pc === null) { - return; - } - - console.log("peer connection state:", this.pc.iceConnectionState); - - switch (this.pc.iceConnectionState) { - case "disconnected": - this.scheduleRestart(); - } - }; - - this.pc.ontrack = (evt) => { - console.log("new track " + evt.track.kind); - document.getElementById("video").srcObject = evt.streams[0]; - }; - const direction = "sendrecv"; this.pc.addTransceiver("video", { direction }); this.pc.addTransceiver("audio", { direction }); + this.pc.onicecandidate = (evt) => this.onLocalCandidate(evt); + this.pc.oniceconnectionstatechange = () => this.onConnectionState(); + + this.pc.ontrack = (evt) => { + console.log("new track:", evt.track.kind); + document.getElementById("video").srcObject = evt.streams[0]; + }; + this.pc.createOffer() .then((desc) => { - if (this.pc === null || this.ws === null) { - return; - } - + this.offerData = parseOffer(desc.sdp); this.pc.setLocalDescription(desc); console.log("sending offer"); - this.ws.send(JSON.stringify(desc)); + + fetch('whep', { + method: 'POST', + headers: { + 'Content-Type': 'application/sdp', + }, + body: JSON.stringify(desc), + }) + .then((res) => { + if (res.status !== 201) { + throw new Error('bad status code'); + } + this.eTag = res.headers.get('E-Tag'); + return res.json(); + }) + .then((answer) => this.onRemoteDescription(answer)) + .catch((err) => { + console.log('error: ' + err); + this.scheduleRestart(); + }); }); } - onRemoteDescription(msg) { - if (this.pc === null || this.ws === null) { - return; - } + onConnectionState() { + if (this.restartTimeout !== null) { + return; + } - this.pc.setRemoteDescription(new RTCSessionDescription(JSON.parse(msg.data))); - this.ws.onmessage = (msg) => this.onRemoteCandidate(msg); + console.log("peer connection state:", this.pc.iceConnectionState); + + switch (this.pc.iceConnectionState) { + case "disconnected": + this.scheduleRestart(); + } + } + + onRemoteDescription(answer) { + if (this.restartTimeout !== null) { + return; + } + + this.pc.setRemoteDescription(new RTCSessionDescription(answer)); + + if (this.queuedCandidates.length !== 0) { + this.sendLocalCandidates(this.queuedCandidates); + this.queuedCandidates = []; + } } - onIceCandidate(evt) { - if (this.ws === null) { + onLocalCandidate(evt) { + if (this.restartTimeout !== null) { return; } if (evt.candidate !== null) { - if (evt.candidate.candidate !== "") { - this.ws.send(JSON.stringify(evt.candidate)); + if (this.eTag === '') { + this.queuedCandidates.push(evt.candidate); + } else { + this.sendLocalCandidates([evt.candidate]) } } } - onRemoteCandidate(msg) { - if (this.pc === null) { - return; - } - - this.pc.addIceCandidate(JSON.parse(msg.data)); - } + sendLocalCandidates(candidates) { + fetch('whep', { + method: 'PATCH', + headers: { + 'Content-Type': 'application/trickle-ice-sdpfrag', + 'If-Match': this.eTag, + }, + body: generateSdpFragment(this.offerData, candidates), + }) + .then((res) => { + if (res.status !== 204) { + throw new Error('bad status code'); + } + }) + .catch((err) => { + console.log('error: ' + err); + this.scheduleRestart(); + }); + } scheduleRestart() { - if (this.terminated) { + if (this.restartTimeout !== null) { return; } - if (this.ws !== null) { - this.ws.close(); - this.ws = null; - } - if (this.pc !== null) { this.pc.close(); this.pc = null; @@ -152,10 +234,13 @@ class Receiver { this.restartTimeout = null; this.start(); }, restartPause); + + this.eTag = ''; + this.queuedCandidates = []; } } -window.addEventListener('DOMContentLoaded', () => new Receiver()); +window.addEventListener('DOMContentLoaded', () => new WHEPClient()); diff --git a/internal/core/webrtc_server.go b/internal/core/webrtc_server.go deleted file mode 100644 index 44263f1b..00000000 --- a/internal/core/webrtc_server.go +++ /dev/null @@ -1,522 +0,0 @@ -package core - -import ( - "context" - "crypto/tls" - _ "embed" - "fmt" - "log" - "net" - "net/http" - "strings" - "sync" - "time" - - "github.com/gin-gonic/gin" - "github.com/pion/ice/v2" - "github.com/pion/webrtc/v3" - - "github.com/aler9/mediamtx/internal/conf" - "github.com/aler9/mediamtx/internal/logger" - "github.com/aler9/mediamtx/internal/websocket" -) - -//go:embed webrtc_publish_index.html -var webrtcPublishIndex []byte - -//go:embed webrtc_read_index.html -var webrtcReadIndex []byte - -type webRTCServerAPIConnsListItem struct { - Created time.Time `json:"created"` - RemoteAddr string `json:"remoteAddr"` - PeerConnectionEstablished bool `json:"peerConnectionEstablished"` - LocalCandidate string `json:"localCandidate"` - RemoteCandidate string `json:"remoteCandidate"` - State string `json:"state"` - BytesReceived uint64 `json:"bytesReceived"` - BytesSent uint64 `json:"bytesSent"` -} - -type webRTCServerAPIConnsListData struct { - Items map[string]webRTCServerAPIConnsListItem `json:"items"` -} - -type webRTCServerAPIConnsListRes struct { - data *webRTCServerAPIConnsListData - err error -} - -type webRTCServerAPIConnsListReq struct { - res chan webRTCServerAPIConnsListRes -} - -type webRTCServerAPIConnsKickRes struct { - err error -} - -type webRTCServerAPIConnsKickReq struct { - id string - res chan webRTCServerAPIConnsKickRes -} - -type webRTCConnNewReq struct { - pathName string - publish bool - wsconn *websocket.ServerConn - res chan *webRTCConn - videoCodec string - audioCodec string - videoBitrate string -} - -type webRTCServerParent interface { - logger.Writer -} - -type webRTCServer struct { - allowOrigin string - trustedProxies conf.IPsOrCIDRs - iceServers []string - readBufferCount int - pathManager *pathManager - metrics *metrics - parent webRTCServerParent - - ctx context.Context - ctxCancel func() - ln net.Listener - requestPool *httpRequestPool - httpServer *http.Server - udpMuxLn net.PacketConn - tcpMuxLn net.Listener - conns map[*webRTCConn]struct{} - iceHostNAT1To1IPs []string - iceUDPMux ice.UDPMux - iceTCPMux ice.TCPMux - - // in - connNew chan webRTCConnNewReq - chConnClose chan *webRTCConn - chAPIConnsList chan webRTCServerAPIConnsListReq - chAPIConnsKick chan webRTCServerAPIConnsKickReq - - // out - done chan struct{} -} - -func newWebRTCServer( - parentCtx context.Context, - address string, - encryption bool, - serverKey string, - serverCert string, - allowOrigin string, - trustedProxies conf.IPsOrCIDRs, - iceServers []string, - readTimeout conf.StringDuration, - readBufferCount int, - pathManager *pathManager, - metrics *metrics, - parent webRTCServerParent, - iceHostNAT1To1IPs []string, - iceUDPMuxAddress string, - iceTCPMuxAddress string, -) (*webRTCServer, error) { - ln, err := net.Listen(restrictNetwork("tcp", address)) - if err != nil { - return nil, err - } - - var tlsConfig *tls.Config - if encryption { - crt, err := tls.LoadX509KeyPair(serverCert, serverKey) - if err != nil { - ln.Close() - return nil, err - } - - tlsConfig = &tls.Config{ - Certificates: []tls.Certificate{crt}, - } - } - - var iceUDPMux ice.UDPMux - var udpMuxLn net.PacketConn - if iceUDPMuxAddress != "" { - udpMuxLn, err = net.ListenPacket(restrictNetwork("udp", iceUDPMuxAddress)) - if err != nil { - return nil, err - } - iceUDPMux = webrtc.NewICEUDPMux(nil, udpMuxLn) - } - - var iceTCPMux ice.TCPMux - var tcpMuxLn net.Listener - if iceTCPMuxAddress != "" { - tcpMuxLn, err = net.Listen(restrictNetwork("tcp", iceTCPMuxAddress)) - if err != nil { - return nil, err - } - iceTCPMux = webrtc.NewICETCPMux(nil, tcpMuxLn, 8) - } - - ctx, ctxCancel := context.WithCancel(parentCtx) - - s := &webRTCServer{ - allowOrigin: allowOrigin, - trustedProxies: trustedProxies, - iceServers: iceServers, - readBufferCount: readBufferCount, - pathManager: pathManager, - metrics: metrics, - parent: parent, - ctx: ctx, - ctxCancel: ctxCancel, - ln: ln, - udpMuxLn: udpMuxLn, - tcpMuxLn: tcpMuxLn, - iceUDPMux: iceUDPMux, - iceTCPMux: iceTCPMux, - iceHostNAT1To1IPs: iceHostNAT1To1IPs, - conns: make(map[*webRTCConn]struct{}), - connNew: make(chan webRTCConnNewReq), - chConnClose: make(chan *webRTCConn), - chAPIConnsList: make(chan webRTCServerAPIConnsListReq), - chAPIConnsKick: make(chan webRTCServerAPIConnsKickReq), - done: make(chan struct{}), - } - - s.requestPool = newHTTPRequestPool() - - router := gin.New() - httpSetTrustedProxies(router, trustedProxies) - - router.NoRoute(s.requestPool.mw, httpLoggerMiddleware(s), httpServerHeaderMiddleware, s.onRequest) - - s.httpServer = &http.Server{ - Handler: router, - TLSConfig: tlsConfig, - ReadHeaderTimeout: time.Duration(readTimeout), - ErrorLog: log.New(&nilWriter{}, "", 0), - } - - str := "listener opened on " + address + " (HTTP)" - if udpMuxLn != nil { - str += ", " + iceUDPMuxAddress + " (ICE/UDP)" - } - if tcpMuxLn != nil { - str += ", " + iceTCPMuxAddress + " (ICE/TCP)" - } - s.Log(logger.Info, str) - - if s.metrics != nil { - s.metrics.webRTCServerSet(s) - } - - go s.run() - - return s, nil -} - -// Log is the main logging function. -func (s *webRTCServer) Log(level logger.Level, format string, args ...interface{}) { - s.parent.Log(level, "[WebRTC] "+format, append([]interface{}{}, args...)...) -} - -func (s *webRTCServer) close() { - s.Log(logger.Info, "listener is closing") - s.ctxCancel() - <-s.done -} - -func (s *webRTCServer) run() { - defer close(s.done) - - if s.httpServer.TLSConfig != nil { - go s.httpServer.ServeTLS(s.ln, "", "") - } else { - go s.httpServer.Serve(s.ln) - } - - var wg sync.WaitGroup - -outer: - for { - select { - case req := <-s.connNew: - c := newWebRTCConn( - s.ctx, - s.readBufferCount, - req.pathName, - req.publish, - req.wsconn, - req.videoCodec, - req.audioCodec, - req.videoBitrate, - s.iceServers, - &wg, - s.pathManager, - s, - s.iceHostNAT1To1IPs, - s.iceUDPMux, - s.iceTCPMux, - ) - s.conns[c] = struct{}{} - req.res <- c - - case conn := <-s.chConnClose: - delete(s.conns, conn) - - case req := <-s.chAPIConnsList: - data := &webRTCServerAPIConnsListData{ - Items: make(map[string]webRTCServerAPIConnsListItem), - } - - for c := range s.conns { - peerConnectionEstablished := false - localCandidate := "" - remoteCandidate := "" - bytesReceived := uint64(0) - bytesSent := uint64(0) - - pc := c.safePC() - if pc != nil { - peerConnectionEstablished = true - localCandidate = pc.localCandidate() - remoteCandidate = pc.remoteCandidate() - bytesReceived = pc.bytesReceived() - bytesSent = pc.bytesSent() - } - - data.Items[c.uuid.String()] = webRTCServerAPIConnsListItem{ - Created: c.created, - RemoteAddr: c.remoteAddr().String(), - PeerConnectionEstablished: peerConnectionEstablished, - LocalCandidate: localCandidate, - RemoteCandidate: remoteCandidate, - State: func() string { - if c.publish { - return "publish" - } - return "read" - }(), - BytesReceived: bytesReceived, - BytesSent: bytesSent, - } - } - - req.res <- webRTCServerAPIConnsListRes{data: data} - - case req := <-s.chAPIConnsKick: - res := func() bool { - for c := range s.conns { - if c.uuid.String() == req.id { - delete(s.conns, c) - c.close() - return true - } - } - return false - }() - if res { - req.res <- webRTCServerAPIConnsKickRes{} - } else { - req.res <- webRTCServerAPIConnsKickRes{fmt.Errorf("not found")} - } - - case <-s.ctx.Done(): - break outer - } - } - - s.ctxCancel() - - s.httpServer.Shutdown(context.Background()) - s.ln.Close() // in case Shutdown() is called before Serve() - - s.requestPool.close() - wg.Wait() - - if s.udpMuxLn != nil { - s.udpMuxLn.Close() - } - - if s.tcpMuxLn != nil { - s.tcpMuxLn.Close() - } -} - -func (s *webRTCServer) onRequest(ctx *gin.Context) { - ctx.Writer.Header().Set("Access-Control-Allow-Origin", s.allowOrigin) - ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") - - switch ctx.Request.Method { - case http.MethodGet: - - case http.MethodOptions: - ctx.Writer.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - ctx.Writer.Header().Set("Access-Control-Allow-Headers", ctx.Request.Header.Get("Access-Control-Request-Headers")) - ctx.Writer.WriteHeader(http.StatusOK) - return - - default: - return - } - - // remove leading prefix - pa := ctx.Request.URL.Path[1:] - - var dir string - var fname string - var publish bool - - switch { - case strings.HasSuffix(pa, "/publish/ws"): - dir = pa[:len(pa)-len("/publish/ws")] - fname = "publish/ws" - publish = true - - case strings.HasSuffix(pa, "/publish"): - dir = pa[:len(pa)-len("/publish")] - fname = "publish" - publish = true - - case strings.HasSuffix(pa, "/ws"): - dir = pa[:len(pa)-len("/ws")] - fname = "ws" - publish = false - - case pa == "favicon.ico": - return - - default: - dir = pa - fname = "" - publish = false - - if !strings.HasSuffix(dir, "/") { - ctx.Writer.Header().Set("Location", "/"+dir+"/") - ctx.Writer.WriteHeader(http.StatusMovedPermanently) - return - } - } - - dir = strings.TrimSuffix(dir, "/") - if dir == "" { - return - } - - user, pass, hasCredentials := ctx.Request.BasicAuth() - - res := s.pathManager.getPathConf(pathGetPathConfReq{ - name: dir, - publish: publish, - credentials: authCredentials{ - query: ctx.Request.URL.RawQuery, - ip: net.ParseIP(ctx.ClientIP()), - user: user, - pass: pass, - proto: authProtocolWebRTC, - }, - }) - if res.err != nil { - if terr, ok := res.err.(pathErrAuth); ok { - if !hasCredentials { - ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`) - ctx.Writer.WriteHeader(http.StatusUnauthorized) - return - } - - s.Log(logger.Info, "authentication error: %v", terr.wrapped) - ctx.Writer.WriteHeader(http.StatusUnauthorized) - return - } - - ctx.Writer.WriteHeader(http.StatusNotFound) - return - } - - switch fname { - case "": - ctx.Writer.Header().Set("Content-Type", "text/html") - ctx.Writer.WriteHeader(http.StatusOK) - ctx.Writer.Write(webrtcReadIndex) - - case "publish": - ctx.Writer.Header().Set("Content-Type", "text/html") - ctx.Writer.WriteHeader(http.StatusOK) - ctx.Writer.Write(webrtcPublishIndex) - - case "ws", "publish/ws": - wsconn, err := websocket.NewServerConn(ctx.Writer, ctx.Request) - if err != nil { - return - } - defer wsconn.Close() - - c := s.newConn(webRTCConnNewReq{ - pathName: dir, - publish: (fname == "publish/ws"), - wsconn: wsconn, - videoCodec: ctx.Query("video_codec"), - audioCodec: ctx.Query("audio_codec"), - videoBitrate: ctx.Query("video_bitrate"), - }) - if c == nil { - return - } - - c.wait() - } -} - -func (s *webRTCServer) newConn(req webRTCConnNewReq) *webRTCConn { - req.res = make(chan *webRTCConn) - - select { - case s.connNew <- req: - return <-req.res - case <-s.ctx.Done(): - return nil - } -} - -// connClose is called by webRTCConn. -func (s *webRTCServer) connClose(c *webRTCConn) { - select { - case s.chConnClose <- c: - case <-s.ctx.Done(): - } -} - -// apiConnsList is called by api. -func (s *webRTCServer) apiConnsList() webRTCServerAPIConnsListRes { - req := webRTCServerAPIConnsListReq{ - res: make(chan webRTCServerAPIConnsListRes), - } - - select { - case s.chAPIConnsList <- req: - return <-req.res - - case <-s.ctx.Done(): - return webRTCServerAPIConnsListRes{err: fmt.Errorf("terminated")} - } -} - -// apiConnsKick is called by api. -func (s *webRTCServer) apiConnsKick(id string) webRTCServerAPIConnsKickRes { - req := webRTCServerAPIConnsKickReq{ - id: id, - res: make(chan webRTCServerAPIConnsKickRes), - } - - select { - case s.chAPIConnsKick <- req: - return <-req.res - - case <-s.ctx.Done(): - return webRTCServerAPIConnsKickRes{err: fmt.Errorf("terminated")} - } -} diff --git a/internal/core/webrtc_server_test.go b/internal/core/webrtc_server_test.go deleted file mode 100644 index 131267db..00000000 --- a/internal/core/webrtc_server_test.go +++ /dev/null @@ -1,235 +0,0 @@ -package core - -import ( - "encoding/json" - "sync" - "testing" - "time" - - "github.com/bluenviron/gortsplib/v3" - "github.com/bluenviron/gortsplib/v3/pkg/formats" - "github.com/bluenviron/gortsplib/v3/pkg/media" - "github.com/gorilla/websocket" - "github.com/pion/rtp" - "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/require" -) - -type webRTCTestClient struct { - wc *websocket.Conn - pc *webrtc.PeerConnection - track chan *webrtc.TrackRemote - closed chan struct{} -} - -func newWebRTCTestClient(addr string) (*webRTCTestClient, error) { - wc, res, err := websocket.DefaultDialer.Dial(addr, nil) - if err != nil { - return nil, err - } - defer res.Body.Close() - - _, msg, err := wc.ReadMessage() - if err != nil { - wc.Close() - return nil, err - } - - var iceServers []webrtc.ICEServer - err = json.Unmarshal(msg, &iceServers) - if err != nil { - wc.Close() - return nil, err - } - - pc, err := webrtc.NewPeerConnection(webrtc.Configuration{ - ICEServers: iceServers, - }) - if err != nil { - wc.Close() - return nil, err - } - - pc.OnICECandidate(func(i *webrtc.ICECandidate) { - if i != nil { - enc, _ := json.Marshal(i.ToJSON()) - wc.WriteMessage(websocket.TextMessage, enc) - } - }) - - connected := make(chan struct{}) - closed := make(chan struct{}) - var stateChangeMutex sync.Mutex - - pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { - stateChangeMutex.Lock() - defer stateChangeMutex.Unlock() - - select { - case <-closed: - return - default: - } - - switch state { - case webrtc.PeerConnectionStateConnected: - close(connected) - - case webrtc.PeerConnectionStateClosed: - close(closed) - } - }) - - track := make(chan *webrtc.TrackRemote, 1) - - pc.OnTrack(func(trak *webrtc.TrackRemote, recv *webrtc.RTPReceiver) { - track <- trak - }) - - _, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo) - if err != nil { - wc.Close() - pc.Close() - return nil, err - } - - localOffer, err := pc.CreateOffer(nil) - if err != nil { - wc.Close() - pc.Close() - return nil, err - } - - enc, err := json.Marshal(localOffer) - if err != nil { - wc.Close() - pc.Close() - return nil, err - } - - err = wc.WriteMessage(websocket.TextMessage, enc) - if err != nil { - wc.Close() - pc.Close() - return nil, err - } - - err = pc.SetLocalDescription(localOffer) - if err != nil { - wc.Close() - pc.Close() - return nil, err - } - - _, msg, err = wc.ReadMessage() - if err != nil { - wc.Close() - pc.Close() - return nil, err - } - - var remoteOffer webrtc.SessionDescription - err = json.Unmarshal(msg, &remoteOffer) - if err != nil { - wc.Close() - pc.Close() - return nil, err - } - - err = pc.SetRemoteDescription(remoteOffer) - if err != nil { - wc.Close() - pc.Close() - return nil, err - } - - go func() { - for { - _, msg, err := wc.ReadMessage() - if err != nil { - return - } - - var candidate webrtc.ICECandidateInit - err = json.Unmarshal(msg, &candidate) - if err != nil { - return - } - - pc.AddICECandidate(candidate) - } - }() - - <-connected - - return &webRTCTestClient{ - wc: wc, - pc: pc, - track: track, - closed: closed, - }, nil -} - -func (c *webRTCTestClient) close() { - c.pc.Close() - c.wc.Close() - <-c.closed -} - -func TestWebRTCServer(t *testing.T) { - p, ok := newInstance("paths:\n" + - " all:\n") - require.Equal(t, true, ok) - defer p.Close() - - medi := &media.Media{ - Type: media.TypeVideo, - Formats: []formats.Format{&formats.H264{ - PayloadTyp: 96, - PacketizationMode: 1, - }}, - } - - v := gortsplib.TransportTCP - source := gortsplib.Client{ - Transport: &v, - } - err := source.StartRecording("rtsp://localhost:8554/stream", media.Medias{medi}) - require.NoError(t, err) - defer source.Close() - - c, err := newWebRTCTestClient("ws://localhost:8889/stream/ws") - require.NoError(t, err) - defer c.close() - - time.Sleep(500 * time.Millisecond) - - source.WritePacketRTP(medi, &rtp.Packet{ - Header: rtp.Header{ - Version: 2, - Marker: true, - PayloadType: 96, - SequenceNumber: 123, - Timestamp: 45343, - SSRC: 563423, - }, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, - }) - - trak := <-c.track - - pkt, _, err := trak.ReadRTP() - require.NoError(t, err) - require.Equal(t, &rtp.Packet{ - Header: rtp.Header{ - Version: 2, - Marker: true, - PayloadType: 102, - SequenceNumber: pkt.SequenceNumber, - Timestamp: pkt.Timestamp, - SSRC: pkt.SSRC, - CSRC: []uint32{}, - }, - Payload: []byte{0x01, 0x02, 0x03, 0x04}, - }, pkt) -} diff --git a/internal/core/webrtc_session.go b/internal/core/webrtc_session.go new file mode 100644 index 00000000..6b152bd9 --- /dev/null +++ b/internal/core/webrtc_session.go @@ -0,0 +1,592 @@ +package core + +import ( + "context" + "encoding/hex" + "encoding/json" + "fmt" + "strconv" + "sync" + "time" + + "github.com/bluenviron/gortsplib/v3/pkg/media" + "github.com/bluenviron/gortsplib/v3/pkg/ringbuffer" + "github.com/google/uuid" + "github.com/pion/ice/v2" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v3" + + "github.com/aler9/mediamtx/internal/logger" +) + +const ( + webrtcHandshakeTimeout = 10 * time.Second + webrtcTrackGatherTimeout = 2 * time.Second + webrtcPayloadMaxSize = 1188 // 1200 - 12 (RTP header) + webrtcStreamID = "mediamtx" +) + +type trackRecvPair struct { + track *webrtc.TrackRemote + receiver *webrtc.RTPReceiver +} + +func mediasOfOutgoingTracks(tracks []*webRTCOutgoingTrack) media.Medias { + ret := make(media.Medias, len(tracks)) + for i, track := range tracks { + ret[i] = track.media + } + return ret +} + +func mediasOfIncomingTracks(tracks []*webRTCIncomingTrack) media.Medias { + ret := make(media.Medias, len(tracks)) + for i, track := range tracks { + ret[i] = track.media + } + return ret +} + +func insertTias(offer *webrtc.SessionDescription, value uint64) { + var sd sdp.SessionDescription + err := sd.Unmarshal([]byte(offer.SDP)) + if err != nil { + return + } + + for _, media := range sd.MediaDescriptions { + if media.MediaName.Media == "video" { + media.Bandwidth = append(media.Bandwidth, sdp.Bandwidth{ + Type: "TIAS", + Bandwidth: value, + }) + } + } + + enc, err := sd.Marshal() + if err != nil { + return + } + + offer.SDP = string(enc) +} + +func gatherOutgoingTracks(medias media.Medias) ([]*webRTCOutgoingTrack, error) { + var tracks []*webRTCOutgoingTrack + + videoTrack, err := newWebRTCOutgoingTrackVideo(medias) + if err != nil { + return nil, err + } + + if videoTrack != nil { + tracks = append(tracks, videoTrack) + } + + audioTrack, err := newWebRTCOutgoingTrackAudio(medias) + if err != nil { + return nil, err + } + + if audioTrack != nil { + tracks = append(tracks, audioTrack) + } + + if tracks == nil { + return nil, fmt.Errorf( + "the stream doesn't contain any supported codec, which are currently H264, VP8, VP9, G711, G722, Opus") + } + + return tracks, nil +} + +func gatherIncomingTracks( + ctx context.Context, + pc *peerConnection, + trackRecv chan trackRecvPair, +) ([]*webRTCIncomingTrack, error) { + var tracks []*webRTCIncomingTrack + + t := time.NewTimer(webrtcTrackGatherTimeout) + defer t.Stop() + + for { + select { + case <-t.C: + return tracks, nil + + case pair := <-trackRecv: + track, err := newWebRTCIncomingTrack(pair.track, pair.receiver, pc.WriteRTCP) + if err != nil { + return nil, err + } + tracks = append(tracks, track) + + if len(tracks) == 2 { + return tracks, nil + } + + case <-pc.disconnected: + return nil, fmt.Errorf("peer connection closed") + + case <-ctx.Done(): + return nil, fmt.Errorf("terminated") + } + } +} + +type webRTCSessionPathManager interface { + publisherAdd(req pathPublisherAddReq) pathPublisherAnnounceRes + readerAdd(req pathReaderAddReq) pathReaderSetupPlayRes +} + +type webRTCSession struct { + readBufferCount int + req webRTCSessionNewReq + wg *sync.WaitGroup + iceHostNAT1To1IPs []string + iceUDPMux ice.UDPMux + iceTCPMux ice.TCPMux + pathManager webRTCSessionPathManager + parent *webRTCManager + + ctx context.Context + ctxCancel func() + created time.Time + uuid uuid.UUID + secret uuid.UUID + answerSent bool + pcMutex sync.RWMutex + pc *peerConnection + + chAddRemoteCandidates chan webRTCSessionAddCandidatesReq +} + +func newWebRTCSession( + parentCtx context.Context, + readBufferCount int, + req webRTCSessionNewReq, + wg *sync.WaitGroup, + iceHostNAT1To1IPs []string, + iceUDPMux ice.UDPMux, + iceTCPMux ice.TCPMux, + pathManager webRTCSessionPathManager, + parent *webRTCManager, +) *webRTCSession { + ctx, ctxCancel := context.WithCancel(parentCtx) + + s := &webRTCSession{ + readBufferCount: readBufferCount, + req: req, + wg: wg, + iceHostNAT1To1IPs: iceHostNAT1To1IPs, + iceUDPMux: iceUDPMux, + iceTCPMux: iceTCPMux, + parent: parent, + pathManager: pathManager, + ctx: ctx, + ctxCancel: ctxCancel, + created: time.Now(), + uuid: uuid.New(), + secret: uuid.New(), + chAddRemoteCandidates: make(chan webRTCSessionAddCandidatesReq), + } + + s.Log(logger.Info, "created by %s", req.remoteAddr) + + wg.Add(1) + go s.run() + + return s +} + +func (s *webRTCSession) Log(level logger.Level, format string, args ...interface{}) { + id := hex.EncodeToString(s.uuid[:4]) + s.parent.Log(level, "[session %v] "+format, append([]interface{}{id}, args...)...) +} + +func (s *webRTCSession) close() { + s.ctxCancel() +} + +func (s *webRTCSession) safePC() *peerConnection { + s.pcMutex.RLock() + defer s.pcMutex.RUnlock() + return s.pc +} + +func (s *webRTCSession) run() { + defer s.wg.Done() + + err := s.runInner() + + if !s.answerSent { + select { + case s.req.res <- webRTCNewSessionRes{ + err: err, + }: + case <-s.ctx.Done(): + } + } + + s.parent.sessionClose(s) + + s.Log(logger.Info, "closed (%v)", err) +} + +func (s *webRTCSession) runInner() error { + if s.req.publish { + return s.runPublish() + } + return s.runRead() +} + +func (s *webRTCSession) runPublish() error { + res := s.pathManager.publisherAdd(pathPublisherAddReq{ + author: s, + pathName: s.req.pathName, + skipAuth: true, + }) + if res.err != nil { + return res.err + } + + defer res.path.publisherRemove(pathPublisherRemoveReq{author: s}) + + offer, err := s.decodeOffer() + if err != nil { + return err + } + + pc, err := newPeerConnection( + s.req.videoCodec, + s.req.audioCodec, + s.parent.genICEServers(), + s.iceHostNAT1To1IPs, + s.iceUDPMux, + s.iceTCPMux, + s) + if err != nil { + return err + } + defer pc.close() + + _, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RtpTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionRecvonly, + }) + if err != nil { + return err + } + + _, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RtpTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionRecvonly, + }) + if err != nil { + return err + } + + trackRecv := make(chan trackRecvPair) + + pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + select { + case trackRecv <- trackRecvPair{track, receiver}: + case <-pc.closed: + } + }) + + err = pc.SetRemoteDescription(*offer) + if err != nil { + return err + } + + answer, err := pc.CreateAnswer(nil) + if err != nil { + return err + } + + err = pc.SetLocalDescription(answer) + if err != nil { + return err + } + + if s.req.videoBitrate != "" { + tmp, err := strconv.ParseUint(s.req.videoBitrate, 10, 31) + if err != nil { + return err + } + + insertTias(&answer, tmp*1024) + } + + err = s.waitGatheringDone(pc) + if err != nil { + return err + } + + err = s.writeAnswer(pc.LocalDescription()) + if err != nil { + return err + } + + go s.readRemoteCandidates(pc) + + err = s.waitUntilConnected(pc) + if err != nil { + return err + } + + tracks, err := gatherIncomingTracks(s.ctx, pc, trackRecv) + if err != nil { + return err + } + medias := mediasOfIncomingTracks(tracks) + + rres := res.path.publisherStart(pathPublisherStartReq{ + author: s, + medias: medias, + generateRTPPackets: false, + }) + if rres.err != nil { + return rres.err + } + + s.Log(logger.Info, "is publishing to path '%s', %s", + res.path.name, + sourceMediaInfo(medias)) + + for _, track := range tracks { + track.start(rres.stream) + } + + select { + case <-pc.disconnected: + return fmt.Errorf("peer connection closed") + + case <-s.ctx.Done(): + return fmt.Errorf("terminated") + } +} + +func (s *webRTCSession) runRead() error { + res := s.pathManager.readerAdd(pathReaderAddReq{ + author: s, + pathName: s.req.pathName, + skipAuth: true, + }) + if res.err != nil { + return res.err + } + + defer res.path.readerRemove(pathReaderRemoveReq{author: s}) + + tracks, err := gatherOutgoingTracks(res.stream.medias()) + if err != nil { + return err + } + + offer, err := s.decodeOffer() + if err != nil { + return err + } + + pc, err := newPeerConnection( + "", + "", + s.parent.genICEServers(), + s.iceHostNAT1To1IPs, + s.iceUDPMux, + s.iceTCPMux, + s) + if err != nil { + return err + } + defer pc.close() + + for _, track := range tracks { + var err error + track.sender, err = pc.AddTrack(track.track) + if err != nil { + return err + } + } + + err = pc.SetRemoteDescription(*offer) + if err != nil { + return err + } + + answer, err := pc.CreateAnswer(nil) + if err != nil { + return err + } + + err = pc.SetLocalDescription(answer) + if err != nil { + return err + } + + err = s.waitGatheringDone(pc) + if err != nil { + return err + } + + err = s.writeAnswer(pc.LocalDescription()) + if err != nil { + return err + } + + go s.readRemoteCandidates(pc) + + err = s.waitUntilConnected(pc) + if err != nil { + return err + } + + ringBuffer, _ := ringbuffer.New(uint64(s.readBufferCount)) + defer ringBuffer.Close() + + writeError := make(chan error) + + for _, track := range tracks { + track.start(s.ctx, s, res.stream, ringBuffer, writeError) + } + + defer res.stream.readerRemove(s) + + s.Log(logger.Info, "is reading from path '%s', %s", + res.path.name, sourceMediaInfo(mediasOfOutgoingTracks(tracks))) + + go func() { + for { + item, ok := ringBuffer.Pull() + if !ok { + return + } + item.(func())() + } + }() + + select { + case <-pc.disconnected: + return fmt.Errorf("peer connection closed") + + case err := <-writeError: + return err + + case <-s.ctx.Done(): + return fmt.Errorf("terminated") + } +} + +func (s *webRTCSession) decodeOffer() (*webrtc.SessionDescription, error) { + var offer webrtc.SessionDescription + err := json.Unmarshal(s.req.offer, &offer) + if err != nil { + return nil, err + } + + if offer.Type != webrtc.SDPTypeOffer { + return nil, fmt.Errorf("received SDP is not an offer") + } + + return &offer, nil +} + +func (s *webRTCSession) waitGatheringDone(pc *peerConnection) error { + for { + select { + case <-pc.localCandidateRecv: + case <-pc.gatheringDone: + return nil + case <-s.ctx.Done(): + return fmt.Errorf("terminated") + } + } +} + +func (s *webRTCSession) writeAnswer(answer *webrtc.SessionDescription) error { + enc, err := json.Marshal(answer) + if err != nil { + return err + } + + select { + case s.req.res <- webRTCNewSessionRes{ + sx: s, + answer: enc, + }: + s.answerSent = true + case <-s.ctx.Done(): + return fmt.Errorf("terminated") + } + + return nil +} + +func (s *webRTCSession) waitUntilConnected(pc *peerConnection) error { + t := time.NewTimer(webrtcHandshakeTimeout) + defer t.Stop() + +outer: + for { + select { + case <-t.C: + return fmt.Errorf("deadline exceeded") + + case <-pc.connected: + break outer + + case <-s.ctx.Done(): + return fmt.Errorf("terminated") + } + } + + s.pcMutex.Lock() + s.pc = pc + s.pcMutex.Unlock() + + return nil +} + +func (s *webRTCSession) readRemoteCandidates(pc *peerConnection) { + for { + select { + case req := <-s.chAddRemoteCandidates: + for _, candidate := range req.candidates { + err := pc.AddICECandidate(*candidate) + if err != nil { + req.res <- webRTCSessionAddCandidatesRes{err: err} + } + } + req.res <- webRTCSessionAddCandidatesRes{} + + case <-s.ctx.Done(): + return + } + } +} + +func (s *webRTCSession) addRemoteCandidates( + req webRTCSessionAddCandidatesReq, +) webRTCSessionAddCandidatesRes { + select { + case s.chAddRemoteCandidates <- req: + return <-req.res + + case <-s.ctx.Done(): + return webRTCSessionAddCandidatesRes{err: fmt.Errorf("terminated")} + } +} + +// apiSourceDescribe implements sourceStaticImpl. +func (s *webRTCSession) apiSourceDescribe() pathAPISourceOrReader { + return pathAPISourceOrReader{ + Type: "webRTCSession", + ID: s.uuid.String(), + } +} + +// apiReaderDescribe implements reader. +func (s *webRTCSession) apiReaderDescribe() pathAPISourceOrReader { + return s.apiSourceDescribe() +}