From ade0cddeb3f6206c62bacc14ffb65702e520ed7c Mon Sep 17 00:00:00 2001 From: KHuynh <3639452+secit@users.noreply.github.com> Date: Fri, 21 Nov 2025 02:00:46 +0100 Subject: [PATCH] support multiple CORS origins (#5150) Co-authored-by: aler9 <46489434+aler9@users.noreply.github.com> --- api/openapi.yaml | 36 ++-- internal/api/api.go | 11 +- internal/api/api_test.go | 2 +- internal/conf/allowed_origins.go | 4 + internal/conf/conf.go | 175 ++++++++++++------ internal/core/core.go | 25 +-- internal/metrics/metrics.go | 11 +- internal/metrics/metrics_test.go | 8 +- internal/playback/server.go | 11 +- internal/playback/server_test.go | 2 +- internal/pprof/pprof.go | 11 +- internal/pprof/pprof_test.go | 6 +- .../httpp/handler_filter_requests_test.go | 42 +++++ internal/protocols/httpp/handler_origin.go | 88 +++++++++ .../protocols/httpp/handler_origin_test.go | 87 +++++++++ internal/protocols/httpp/server.go | 4 +- internal/protocols/httpp/server_test.go | 30 --- internal/servers/hls/http_server.go | 11 +- internal/servers/hls/server.go | 4 +- internal/servers/hls/server_test.go | 4 +- internal/servers/webrtc/http_server.go | 11 +- internal/servers/webrtc/server.go | 4 +- internal/servers/webrtc/server_test.go | 2 +- mediamtx.yml | 32 ++-- 24 files changed, 441 insertions(+), 180 deletions(-) create mode 100644 internal/conf/allowed_origins.go create mode 100644 internal/protocols/httpp/handler_filter_requests_test.go create mode 100644 internal/protocols/httpp/handler_origin.go create mode 100644 internal/protocols/httpp/handler_origin_test.go diff --git a/api/openapi.yaml b/api/openapi.yaml index 0428214a..7b475a8c 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -124,8 +124,10 @@ components: type: string apiServerCert: type: string - apiAllowOrigin: - type: string + apiAllowOrigins: + type: array + items: + type: string apiTrustedProxies: type: array items: @@ -142,8 +144,10 @@ components: type: string metricsServerCert: type: string - metricsAllowOrigin: - type: string + metricsAllowOrigins: + type: array + items: + type: string metricsTrustedProxies: type: array items: @@ -160,8 +164,10 @@ components: type: string pprofServerCert: type: string - pprofAllowOrigin: - type: string + pprofAllowOrigins: + type: array + items: + type: string pprofTrustedProxies: type: array items: @@ -178,8 +184,10 @@ components: type: string playbackServerCert: type: string - playbackAllowOrigin: - type: string + playbackAllowOrigins: + type: array + items: + type: string playbackTrustedProxies: type: array items: @@ -254,8 +262,10 @@ components: type: string hlsServerCert: type: string - hlsAllowOrigin: - type: string + hlsAllowOrigins: + type: array + items: + type: string hlsTrustedProxies: type: array items: @@ -289,8 +299,10 @@ components: type: string webrtcServerCert: type: string - webrtcAllowOrigin: - type: string + webrtcAllowOrigins: + type: array + items: + type: string webrtcTrustedProxies: type: array items: diff --git a/internal/api/api.go b/internal/api/api.go index bd6d33aa..f2e597bd 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -94,7 +94,7 @@ type API struct { Encryption bool ServerKey string ServerCert string - AllowOrigin string + AllowOrigins []string TrustedProxies conf.IPNetworks ReadTimeout conf.Duration WriteTimeout conf.Duration @@ -119,7 +119,7 @@ func (a *API) Initialize() error { router := gin.New() router.SetTrustedProxies(a.TrustedProxies.ToTrustedProxies()) //nolint:errcheck - router.Use(a.middlewareOrigin) + router.Use(a.middlewarePreflightRequests) router.Use(a.middlewareAuth) group := router.Group("/v3") @@ -195,6 +195,7 @@ func (a *API) Initialize() error { a.httpServer = &httpp.Server{ Address: a.Address, + AllowOrigins: a.AllowOrigins, ReadTimeout: time.Duration(a.ReadTimeout), WriteTimeout: time.Duration(a.WriteTimeout), Encryption: a.Encryption, @@ -234,11 +235,7 @@ func (a *API) writeError(ctx *gin.Context, status int, err error) { }) } -func (a *API) middlewareOrigin(ctx *gin.Context) { - ctx.Header("Access-Control-Allow-Origin", a.AllowOrigin) - ctx.Header("Access-Control-Allow-Credentials", "true") - - // preflight requests +func (a *API) middlewarePreflightRequests(ctx *gin.Context) { if ctx.Request.Method == http.MethodOptions && ctx.Request.Header.Get("Access-Control-Request-Method") != "" { ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET, POST, PATCH, DELETE") diff --git a/internal/api/api_test.go b/internal/api/api_test.go index d216e52b..0612018a 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -83,7 +83,7 @@ func checkError(t *testing.T, msg string, body io.Reader) { func TestPreflightRequest(t *testing.T) { api := API{ Address: "localhost:9997", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: test.NilAuthManager, diff --git a/internal/conf/allowed_origins.go b/internal/conf/allowed_origins.go new file mode 100644 index 00000000..77db4fba --- /dev/null +++ b/internal/conf/allowed_origins.go @@ -0,0 +1,4 @@ +package conf + +// AllowedOrigins is a list of allowed CORS origins. +type AllowedOrigins []string diff --git a/internal/conf/conf.go b/internal/conf/conf.go index 8595adda..5aa7c5c3 100644 --- a/internal/conf/conf.go +++ b/internal/conf/conf.go @@ -177,40 +177,44 @@ type Conf struct { AuthJWTInHTTPQuery bool `json:"authJWTInHTTPQuery"` // Control API - API bool `json:"api"` - APIAddress string `json:"apiAddress"` - APIEncryption bool `json:"apiEncryption"` - APIServerKey string `json:"apiServerKey"` - APIServerCert string `json:"apiServerCert"` - APIAllowOrigin string `json:"apiAllowOrigin"` - APITrustedProxies IPNetworks `json:"apiTrustedProxies"` + API bool `json:"api"` + APIAddress string `json:"apiAddress"` + APIEncryption bool `json:"apiEncryption"` + APIServerKey string `json:"apiServerKey"` + APIServerCert string `json:"apiServerCert"` + APIAllowOrigin *string `json:"apiAllowOrigin,omitempty"` // deprecated + APIAllowOrigins AllowedOrigins `json:"apiAllowOrigins"` + APITrustedProxies IPNetworks `json:"apiTrustedProxies"` // Metrics - Metrics bool `json:"metrics"` - MetricsAddress string `json:"metricsAddress"` - MetricsEncryption bool `json:"metricsEncryption"` - MetricsServerKey string `json:"metricsServerKey"` - MetricsServerCert string `json:"metricsServerCert"` - MetricsAllowOrigin string `json:"metricsAllowOrigin"` - MetricsTrustedProxies IPNetworks `json:"metricsTrustedProxies"` + Metrics bool `json:"metrics"` + MetricsAddress string `json:"metricsAddress"` + MetricsEncryption bool `json:"metricsEncryption"` + MetricsServerKey string `json:"metricsServerKey"` + MetricsServerCert string `json:"metricsServerCert"` + MetricsAllowOrigin *string `json:"metricsAllowOrigin,omitempty"` // deprecated + MetricsAllowOrigins AllowedOrigins `json:"metricsAllowOrigins"` + MetricsTrustedProxies IPNetworks `json:"metricsTrustedProxies"` // PPROF - PPROF bool `json:"pprof"` - PPROFAddress string `json:"pprofAddress"` - PPROFEncryption bool `json:"pprofEncryption"` - PPROFServerKey string `json:"pprofServerKey"` - PPROFServerCert string `json:"pprofServerCert"` - PPROFAllowOrigin string `json:"pprofAllowOrigin"` - PPROFTrustedProxies IPNetworks `json:"pprofTrustedProxies"` + PPROF bool `json:"pprof"` + PPROFAddress string `json:"pprofAddress"` + PPROFEncryption bool `json:"pprofEncryption"` + PPROFServerKey string `json:"pprofServerKey"` + PPROFServerCert string `json:"pprofServerCert"` + PPROFAllowOrigin *string `json:"pprofAllowOrigin,omitempty"` // deprecated + PPROFAllowOrigins AllowedOrigins `json:"pprofAllowOrigins"` + PPROFTrustedProxies IPNetworks `json:"pprofTrustedProxies"` // Playback - Playback bool `json:"playback"` - PlaybackAddress string `json:"playbackAddress"` - PlaybackEncryption bool `json:"playbackEncryption"` - PlaybackServerKey string `json:"playbackServerKey"` - PlaybackServerCert string `json:"playbackServerCert"` - PlaybackAllowOrigin string `json:"playbackAllowOrigin"` - PlaybackTrustedProxies IPNetworks `json:"playbackTrustedProxies"` + Playback bool `json:"playback"` + PlaybackAddress string `json:"playbackAddress"` + PlaybackEncryption bool `json:"playbackEncryption"` + PlaybackServerKey string `json:"playbackServerKey"` + PlaybackServerCert string `json:"playbackServerCert"` + PlaybackAllowOrigin *string `json:"playbackAllowOrigin,omitempty"` // deprecated + PlaybackAllowOrigins AllowedOrigins `json:"playbackAllowOrigins"` + PlaybackTrustedProxies IPNetworks `json:"playbackTrustedProxies"` // RTSP server RTSP bool `json:"rtsp"` @@ -248,22 +252,23 @@ type Conf struct { RTMPServerCert string `json:"rtmpServerCert"` // HLS server - HLS bool `json:"hls"` - HLSDisable *bool `json:"hlsDisable,omitempty"` // deprecated - HLSAddress string `json:"hlsAddress"` - HLSEncryption bool `json:"hlsEncryption"` - HLSServerKey string `json:"hlsServerKey"` - HLSServerCert string `json:"hlsServerCert"` - HLSAllowOrigin string `json:"hlsAllowOrigin"` - HLSTrustedProxies IPNetworks `json:"hlsTrustedProxies"` - HLSAlwaysRemux bool `json:"hlsAlwaysRemux"` - HLSVariant HLSVariant `json:"hlsVariant"` - HLSSegmentCount int `json:"hlsSegmentCount"` - HLSSegmentDuration Duration `json:"hlsSegmentDuration"` - HLSPartDuration Duration `json:"hlsPartDuration"` - HLSSegmentMaxSize StringSize `json:"hlsSegmentMaxSize"` - HLSDirectory string `json:"hlsDirectory"` - HLSMuxerCloseAfter Duration `json:"hlsMuxerCloseAfter"` + HLS bool `json:"hls"` + HLSDisable *bool `json:"hlsDisable,omitempty"` // deprecated + HLSAddress string `json:"hlsAddress"` + HLSEncryption bool `json:"hlsEncryption"` + HLSServerKey string `json:"hlsServerKey"` + HLSServerCert string `json:"hlsServerCert"` + HLSAllowOrigin *string `json:"hlsAllowOrigin,omitempty"` // deprecated + HLSAllowOrigins AllowedOrigins `json:"hlsAllowOrigins"` + HLSTrustedProxies IPNetworks `json:"hlsTrustedProxies"` + HLSAlwaysRemux bool `json:"hlsAlwaysRemux"` + HLSVariant HLSVariant `json:"hlsVariant"` + HLSSegmentCount int `json:"hlsSegmentCount"` + HLSSegmentDuration Duration `json:"hlsSegmentDuration"` + HLSPartDuration Duration `json:"hlsPartDuration"` + HLSSegmentMaxSize StringSize `json:"hlsSegmentMaxSize"` + HLSDirectory string `json:"hlsDirectory"` + HLSMuxerCloseAfter Duration `json:"hlsMuxerCloseAfter"` // WebRTC server WebRTC bool `json:"webrtc"` @@ -272,7 +277,8 @@ type Conf struct { WebRTCEncryption bool `json:"webrtcEncryption"` WebRTCServerKey string `json:"webrtcServerKey"` WebRTCServerCert string `json:"webrtcServerCert"` - WebRTCAllowOrigin string `json:"webrtcAllowOrigin"` + WebRTCAllowOrigin *string `json:"webrtcAllowOrigin,omitempty"` // deprecated + WebRTCAllowOrigins AllowedOrigins `json:"webrtcAllowOrigins"` WebRTCTrustedProxies IPNetworks `json:"webrtcTrustedProxies"` WebRTCLocalUDPAddress string `json:"webrtcLocalUDPAddress"` WebRTCLocalTCPAddress string `json:"webrtcLocalTCPAddress"` @@ -340,25 +346,25 @@ func (conf *Conf) setDefaults() { conf.APIAddress = ":9997" conf.APIServerKey = "server.key" conf.APIServerCert = "server.crt" - conf.APIAllowOrigin = "*" + conf.APIAllowOrigins = []string{"*"} // Metrics conf.MetricsAddress = ":9998" conf.MetricsServerKey = "server.key" conf.MetricsServerCert = "server.crt" - conf.MetricsAllowOrigin = "*" + conf.MetricsAllowOrigins = []string{"*"} // PPROF conf.PPROFAddress = ":9999" conf.PPROFServerKey = "server.key" conf.PPROFServerCert = "server.crt" - conf.PPROFAllowOrigin = "*" + conf.PPROFAllowOrigins = []string{"*"} // Playback server conf.PlaybackAddress = ":9996" conf.PlaybackServerKey = "server.key" conf.PlaybackServerCert = "server.crt" - conf.PlaybackAllowOrigin = "*" + conf.PlaybackAllowOrigins = []string{"*"} // RTSP server conf.RTSP = true @@ -394,7 +400,7 @@ func (conf *Conf) setDefaults() { conf.HLSAddress = ":8888" conf.HLSServerKey = "server.key" conf.HLSServerCert = "server.crt" - conf.HLSAllowOrigin = "*" + conf.HLSAllowOrigins = []string{"*"} conf.HLSVariant = HLSVariant(gohlslib.MuxerVariantLowLatency) conf.HLSSegmentCount = 7 conf.HLSSegmentDuration = 1 * Duration(time.Second) @@ -407,7 +413,7 @@ func (conf *Conf) setDefaults() { conf.WebRTCAddress = ":8889" conf.WebRTCServerKey = "server.key" conf.WebRTCServerCert = "server.crt" - conf.WebRTCAllowOrigin = "*" + conf.WebRTCAllowOrigins = []string{"*"} conf.WebRTCLocalUDPAddress = ":8189" conf.WebRTCIPsFromInterfaces = true conf.WebRTCIPsFromInterfacesList = []string{} @@ -522,16 +528,20 @@ func (conf *Conf) Validate(l logger.Writer) error { if conf.ReadTimeout <= 0 { return fmt.Errorf("'readTimeout' must be greater than zero") } + if conf.WriteTimeout <= 0 { return fmt.Errorf("'writeTimeout' must be greater than zero") } + if conf.ReadBufferCount != nil { l.Log(logger.Warn, "parameter 'readBufferCount' is deprecated and has been replaced with 'writeQueueSize'") conf.WriteQueueSize = *conf.ReadBufferCount } + if (conf.WriteQueueSize & (conf.WriteQueueSize - 1)) != 0 { return fmt.Errorf("'writeQueueSize' must be a power of two") } + if conf.UDPMaxPayloadSize > 1472 { return fmt.Errorf("'udpMaxPayloadSize' must be less than 1472") } @@ -544,16 +554,19 @@ func (conf *Conf) Validate(l logger.Writer) error { conf.AuthMethod = AuthMethodHTTP conf.AuthHTTPAddress = *conf.ExternalAuthenticationURL } + if conf.AuthHTTPAddress != "" && !strings.HasPrefix(conf.AuthHTTPAddress, "http://") && !strings.HasPrefix(conf.AuthHTTPAddress, "https://") { return fmt.Errorf("'externalAuthenticationURL' must be a HTTP URL") } + if conf.AuthJWTJWKS != "" && !strings.HasPrefix(conf.AuthJWTJWKS, "http://") && !strings.HasPrefix(conf.AuthJWTJWKS, "https://") { return fmt.Errorf("'authJWTJWKS' must be a HTTP URL") } + deprecatedCredentialsMode := false if anyPathHasDeprecatedCredentials(conf.PathDefaults, conf.OptionalPaths) { l.Log(logger.Warn, "you are using one or more authentication-related deprecated parameters "+ @@ -592,6 +605,7 @@ func (conf *Conf) Validate(l logger.Writer) error { } deprecatedCredentialsMode = true } + switch conf.AuthMethod { case AuthMethodHTTP: if conf.AuthHTTPAddress == "" { @@ -607,24 +621,56 @@ func (conf *Conf) Validate(l logger.Writer) error { } } - // RTSP + // Control API + + if conf.APIAllowOrigin != nil { + l.Log(logger.Warn, "parameter 'apiAllowOrigin' is deprecated and has been replaced with 'apiAllowOrigins'") + conf.APIAllowOrigins = []string{*conf.APIAllowOrigin} + } + + // Metrics + + if conf.MetricsAllowOrigin != nil { + l.Log(logger.Warn, "parameter 'metricsAllowOrigin' is deprecated and has been replaced with 'metricsAllowOrigins'") + conf.MetricsAllowOrigins = []string{*conf.MetricsAllowOrigin} + } + + // PPROF + + if conf.PPROFAllowOrigin != nil { + l.Log(logger.Warn, "parameter 'pprofAllowOrigin' is deprecated and has been replaced with 'pprofAllowOrigins'") + conf.PPROFAllowOrigins = []string{*conf.PPROFAllowOrigin} + } + + // Playback + + if conf.PlaybackAllowOrigin != nil { + l.Log(logger.Warn, "parameter 'playbackAllowOrigin' is deprecated and has been replaced with 'playbackAllowOrigins'") + conf.PlaybackAllowOrigins = []string{*conf.PlaybackAllowOrigin} + } + + // RTSP server if conf.RTSPDisable != nil { l.Log(logger.Warn, "parameter 'rtspDisabled' is deprecated and has been replaced with 'rtsp'") conf.RTSP = !*conf.RTSPDisable } + if conf.Protocols != nil { l.Log(logger.Warn, "parameter 'protocols' is deprecated and has been replaced with 'rtspTransports'") conf.RTSPTransports = *conf.Protocols } + if conf.Encryption != nil { l.Log(logger.Warn, "parameter 'encryption' is deprecated and has been replaced with 'rtspEncryption'") conf.RTSPEncryption = *conf.Encryption } + if conf.AuthMethods != nil { l.Log(logger.Warn, "parameter 'authMethods' is deprecated and has been replaced with 'rtspAuthMethods'") conf.RTSPAuthMethods = *conf.AuthMethods } + if slices.Contains(conf.RTSPAuthMethods, auth.VerifyMethodDigestMD5) { if conf.AuthMethod != AuthMethodInternal { return fmt.Errorf("when RTSP digest is enabled, the only supported auth method is 'internal'") @@ -635,14 +681,17 @@ func (conf *Conf) Validate(l logger.Writer) error { } } } + if conf.ServerCert != nil { l.Log(logger.Warn, "parameter 'serverCert' is deprecated and has been replaced with 'rtspServerCert'") conf.RTSPServerCert = *conf.ServerCert } + if conf.ServerKey != nil { l.Log(logger.Warn, "parameter 'serverKey' is deprecated and has been replaced with 'rtspServerKey'") conf.RTSPServerKey = *conf.ServerKey } + if len(conf.RTSPAuthMethods) == 0 { return fmt.Errorf("at least one 'rtspAuthMethods' must be provided") } @@ -661,27 +710,36 @@ func (conf *Conf) Validate(l logger.Writer) error { conf.HLS = !*conf.HLSDisable } + if conf.HLSAllowOrigin != nil { + l.Log(logger.Warn, "parameter 'hlsAllowOrigin' is deprecated and has been replaced with 'hlsAllowOrigins'") + conf.HLSAllowOrigins = []string{*conf.HLSAllowOrigin} + } + // WebRTC if conf.WebRTCDisable != nil { l.Log(logger.Warn, "parameter 'webrtcDisable' is deprecated and has been replaced with 'webrtc'") conf.WebRTC = !*conf.WebRTCDisable } + if conf.WebRTCICEUDPMuxAddress != nil { l.Log(logger.Warn, "parameter 'webrtcICEUDPMuxAdderss' is deprecated "+ "and has been replaced with 'webrtcLocalUDPAddress'") conf.WebRTCLocalUDPAddress = *conf.WebRTCICEUDPMuxAddress } + if conf.WebRTCICETCPMuxAddress != nil { l.Log(logger.Warn, "parameter 'webrtcICETCPMuxAddress' is deprecated "+ "and has been replaced with 'webrtcLocalTCPAddress'") conf.WebRTCLocalTCPAddress = *conf.WebRTCICETCPMuxAddress } + if conf.WebRTCICEHostNAT1To1IPs != nil { l.Log(logger.Warn, "parameter 'webrtcICEHostNAT1To1IPs' is deprecated "+ "and has been replaced with 'webrtcAdditionalHosts'") conf.WebRTCAdditionalHosts = *conf.WebRTCICEHostNAT1To1IPs } + if conf.WebRTCICEServers != nil { l.Log(logger.Warn, "parameter 'webrtcICEServers' is deprecated "+ "and has been replaced with 'webrtcICEServers2'") @@ -701,6 +759,7 @@ func (conf *Conf) Validate(l logger.Writer) error { } } } + for _, server := range conf.WebRTCICEServers2 { if !strings.HasPrefix(server.URL, "stun:") && !strings.HasPrefix(server.URL, "turn:") && @@ -708,18 +767,25 @@ func (conf *Conf) Validate(l logger.Writer) error { return fmt.Errorf("invalid ICE server: '%s'", server.URL) } } + if conf.WebRTCLocalUDPAddress == "" && conf.WebRTCLocalTCPAddress == "" && len(conf.WebRTCICEServers2) == 0 { return fmt.Errorf("at least one between 'webrtcLocalUDPAddress'," + " 'webrtcLocalTCPAddress' or 'webrtcICEServers2' must be filled") } + if conf.WebRTCLocalUDPAddress != "" || conf.WebRTCLocalTCPAddress != "" { if !conf.WebRTCIPsFromInterfaces && len(conf.WebRTCAdditionalHosts) == 0 { return fmt.Errorf("at least one between 'webrtcIPsFromInterfaces' or 'webrtcAdditionalHosts' must be filled") } } + if conf.WebRTCAllowOrigin != nil { + l.Log(logger.Warn, "parameter 'webrtcAllowOrigin' is deprecated and has been replaced with 'webrtcAllowOrigins'") + conf.WebRTCAllowOrigins = []string{*conf.WebRTCAllowOrigin} + } + // Record (deprecated) if conf.Record != nil { @@ -727,26 +793,31 @@ func (conf *Conf) Validate(l logger.Writer) error { "and has been replaced with 'pathDefaults.record'") conf.PathDefaults.Record = *conf.Record } + if conf.RecordPath != nil { l.Log(logger.Warn, "parameter 'recordPath' is deprecated "+ "and has been replaced with 'pathDefaults.recordPath'") conf.PathDefaults.RecordPath = *conf.RecordPath } + if conf.RecordFormat != nil { l.Log(logger.Warn, "parameter 'recordFormat' is deprecated "+ "and has been replaced with 'pathDefaults.recordFormat'") conf.PathDefaults.RecordFormat = *conf.RecordFormat } + if conf.RecordPartDuration != nil { l.Log(logger.Warn, "parameter 'recordPartDuration' is deprecated "+ "and has been replaced with 'pathDefaults.recordPartDuration'") conf.PathDefaults.RecordPartDuration = *conf.RecordPartDuration } + if conf.RecordSegmentDuration != nil { l.Log(logger.Warn, "parameter 'recordSegmentDuration' is deprecated "+ "and has been replaced with 'pathDefaults.recordSegmentDuration'") conf.PathDefaults.RecordSegmentDuration = *conf.RecordSegmentDuration } + if conf.RecordDeleteAfter != nil { l.Log(logger.Warn, "parameter 'recordDeleteAfter' is deprecated "+ "and has been replaced with 'pathDefaults.recordDeleteAfter'") diff --git a/internal/core/core.go b/internal/core/core.go index b6907dc0..5c60fb1c 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -10,6 +10,7 @@ import ( "path/filepath" "reflect" "runtime" + "slices" "strings" "syscall" "time" @@ -323,7 +324,7 @@ func (p *Core) createResources(initial bool) error { Encryption: p.conf.MetricsEncryption, ServerKey: p.conf.MetricsServerKey, ServerCert: p.conf.MetricsServerCert, - AllowOrigin: p.conf.MetricsAllowOrigin, + AllowOrigins: p.conf.MetricsAllowOrigins, TrustedProxies: p.conf.MetricsTrustedProxies, ReadTimeout: p.conf.ReadTimeout, WriteTimeout: p.conf.WriteTimeout, @@ -344,7 +345,7 @@ func (p *Core) createResources(initial bool) error { Encryption: p.conf.PPROFEncryption, ServerKey: p.conf.PPROFServerKey, ServerCert: p.conf.PPROFServerCert, - AllowOrigin: p.conf.PPROFAllowOrigin, + AllowOrigins: p.conf.PPROFAllowOrigins, TrustedProxies: p.conf.PPROFTrustedProxies, ReadTimeout: p.conf.ReadTimeout, WriteTimeout: p.conf.WriteTimeout, @@ -374,7 +375,7 @@ func (p *Core) createResources(initial bool) error { Encryption: p.conf.PlaybackEncryption, ServerKey: p.conf.PlaybackServerKey, ServerCert: p.conf.PlaybackServerCert, - AllowOrigin: p.conf.PlaybackAllowOrigin, + AllowOrigins: p.conf.PlaybackAllowOrigins, TrustedProxies: p.conf.PlaybackTrustedProxies, ReadTimeout: p.conf.ReadTimeout, WriteTimeout: p.conf.WriteTimeout, @@ -562,7 +563,7 @@ func (p *Core) createResources(initial bool) error { Encryption: p.conf.HLSEncryption, ServerKey: p.conf.HLSServerKey, ServerCert: p.conf.HLSServerCert, - AllowOrigin: p.conf.HLSAllowOrigin, + AllowOrigins: p.conf.HLSAllowOrigins, TrustedProxies: p.conf.HLSTrustedProxies, AlwaysRemux: p.conf.HLSAlwaysRemux, Variant: p.conf.HLSVariant, @@ -592,7 +593,7 @@ func (p *Core) createResources(initial bool) error { Encryption: p.conf.WebRTCEncryption, ServerKey: p.conf.WebRTCServerKey, ServerCert: p.conf.WebRTCServerCert, - AllowOrigin: p.conf.WebRTCAllowOrigin, + AllowOrigins: p.conf.WebRTCAllowOrigins, TrustedProxies: p.conf.WebRTCTrustedProxies, ReadTimeout: p.conf.ReadTimeout, WriteTimeout: p.conf.WriteTimeout, @@ -650,7 +651,7 @@ func (p *Core) createResources(initial bool) error { Encryption: p.conf.APIEncryption, ServerKey: p.conf.APIServerKey, ServerCert: p.conf.APIServerCert, - AllowOrigin: p.conf.APIAllowOrigin, + AllowOrigins: p.conf.APIAllowOrigins, TrustedProxies: p.conf.APITrustedProxies, ReadTimeout: p.conf.ReadTimeout, WriteTimeout: p.conf.WriteTimeout, @@ -712,7 +713,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) { newConf.MetricsEncryption != p.conf.MetricsEncryption || newConf.MetricsServerKey != p.conf.MetricsServerKey || newConf.MetricsServerCert != p.conf.MetricsServerCert || - newConf.MetricsAllowOrigin != p.conf.MetricsAllowOrigin || + !slices.Equal(newConf.MetricsAllowOrigins, p.conf.MetricsAllowOrigins) || !reflect.DeepEqual(newConf.MetricsTrustedProxies, p.conf.MetricsTrustedProxies) || newConf.ReadTimeout != p.conf.ReadTimeout || newConf.WriteTimeout != p.conf.WriteTimeout || @@ -725,7 +726,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) { newConf.PPROFEncryption != p.conf.PPROFEncryption || newConf.PPROFServerKey != p.conf.PPROFServerKey || newConf.PPROFServerCert != p.conf.PPROFServerCert || - newConf.PPROFAllowOrigin != p.conf.PPROFAllowOrigin || + !slices.Equal(newConf.PPROFAllowOrigins, p.conf.PPROFAllowOrigins) || !reflect.DeepEqual(newConf.PPROFTrustedProxies, p.conf.PPROFTrustedProxies) || newConf.ReadTimeout != p.conf.ReadTimeout || newConf.WriteTimeout != p.conf.WriteTimeout || @@ -745,7 +746,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) { newConf.PlaybackEncryption != p.conf.PlaybackEncryption || newConf.PlaybackServerKey != p.conf.PlaybackServerKey || newConf.PlaybackServerCert != p.conf.PlaybackServerCert || - newConf.PlaybackAllowOrigin != p.conf.PlaybackAllowOrigin || + !slices.Equal(newConf.PlaybackAllowOrigins, p.conf.PlaybackAllowOrigins) || !reflect.DeepEqual(newConf.PlaybackTrustedProxies, p.conf.PlaybackTrustedProxies) || newConf.ReadTimeout != p.conf.ReadTimeout || newConf.WriteTimeout != p.conf.WriteTimeout || @@ -852,7 +853,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) { newConf.HLSEncryption != p.conf.HLSEncryption || newConf.HLSServerKey != p.conf.HLSServerKey || newConf.HLSServerCert != p.conf.HLSServerCert || - newConf.HLSAllowOrigin != p.conf.HLSAllowOrigin || + !slices.Equal(newConf.HLSAllowOrigins, p.conf.HLSAllowOrigins) || !reflect.DeepEqual(newConf.HLSTrustedProxies, p.conf.HLSTrustedProxies) || newConf.HLSAlwaysRemux != p.conf.HLSAlwaysRemux || newConf.HLSVariant != p.conf.HLSVariant || @@ -874,7 +875,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) { newConf.WebRTCEncryption != p.conf.WebRTCEncryption || newConf.WebRTCServerKey != p.conf.WebRTCServerKey || newConf.WebRTCServerCert != p.conf.WebRTCServerCert || - newConf.WebRTCAllowOrigin != p.conf.WebRTCAllowOrigin || + !slices.Equal(newConf.WebRTCAllowOrigins, p.conf.WebRTCAllowOrigins) || !reflect.DeepEqual(newConf.WebRTCTrustedProxies, p.conf.WebRTCTrustedProxies) || newConf.ReadTimeout != p.conf.ReadTimeout || newConf.WriteTimeout != p.conf.WriteTimeout || @@ -911,7 +912,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) { newConf.APIEncryption != p.conf.APIEncryption || newConf.APIServerKey != p.conf.APIServerKey || newConf.APIServerCert != p.conf.APIServerCert || - newConf.APIAllowOrigin != p.conf.APIAllowOrigin || + !slices.Equal(newConf.APIAllowOrigins, p.conf.APIAllowOrigins) || !reflect.DeepEqual(newConf.APITrustedProxies, p.conf.APITrustedProxies) || newConf.ReadTimeout != p.conf.ReadTimeout || newConf.WriteTimeout != p.conf.WriteTimeout || diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 0ad67f8f..0934870f 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -74,7 +74,7 @@ type Metrics struct { Encryption bool ServerKey string ServerCert string - AllowOrigin string + AllowOrigins []string TrustedProxies conf.IPNetworks ReadTimeout conf.Duration WriteTimeout conf.Duration @@ -98,13 +98,14 @@ func (m *Metrics) Initialize() error { router := gin.New() router.SetTrustedProxies(m.TrustedProxies.ToTrustedProxies()) //nolint:errcheck - router.Use(m.middlewareOrigin) + router.Use(m.middlewarePreflightRequests) router.Use(m.middlewareAuth) router.GET("/metrics", m.onMetrics) m.httpServer = &httpp.Server{ Address: m.Address, + AllowOrigins: m.AllowOrigins, ReadTimeout: time.Duration(m.ReadTimeout), WriteTimeout: time.Duration(m.WriteTimeout), Encryption: m.Encryption, @@ -134,11 +135,7 @@ func (m *Metrics) Log(level logger.Level, format string, args ...any) { m.Parent.Log(level, "[metrics] "+format, args...) } -func (m *Metrics) middlewareOrigin(ctx *gin.Context) { - ctx.Header("Access-Control-Allow-Origin", m.AllowOrigin) - ctx.Header("Access-Control-Allow-Credentials", "true") - - // preflight requests +func (m *Metrics) middlewarePreflightRequests(ctx *gin.Context) { if ctx.Request.Method == http.MethodOptions && ctx.Request.Header.Get("Access-Control-Request-Method") != "" { ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET") diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index 45d9d4e4..9be2c982 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -192,7 +192,7 @@ func (dummyWebRTCServer) APISessionsKick(uuid.UUID) error { func TestPreflightRequest(t *testing.T) { m := Metrics{ Address: "localhost:9998", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: test.NilAuthManager, @@ -232,7 +232,7 @@ func TestMetrics(t *testing.T) { m := Metrics{ Address: "localhost:9998", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: &test.AuthManager{ @@ -368,7 +368,7 @@ func TestAuthError(t *testing.T) { m := Metrics{ Address: "localhost:9998", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: &test.AuthManager{ @@ -428,7 +428,7 @@ func TestFilter(t *testing.T) { t.Run(ca, func(t *testing.T) { m := Metrics{ Address: "localhost:9998", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: test.NilAuthManager, diff --git a/internal/playback/server.go b/internal/playback/server.go index ca930a81..c5b7f391 100644 --- a/internal/playback/server.go +++ b/internal/playback/server.go @@ -24,7 +24,7 @@ type Server struct { Encryption bool ServerKey string ServerCert string - AllowOrigin string + AllowOrigins []string TrustedProxies conf.IPNetworks ReadTimeout conf.Duration WriteTimeout conf.Duration @@ -41,13 +41,14 @@ func (s *Server) Initialize() error { router := gin.New() router.SetTrustedProxies(s.TrustedProxies.ToTrustedProxies()) //nolint:errcheck - router.Use(s.middlewareOrigin) + router.Use(s.middlewarePreflightRequests) router.GET("/list", s.onList) router.GET("/get", s.onGet) s.httpServer = &httpp.Server{ Address: s.Address, + AllowOrigins: s.AllowOrigins, ReadTimeout: time.Duration(s.ReadTimeout), WriteTimeout: time.Duration(s.WriteTimeout), Encryption: s.Encryption, @@ -100,11 +101,7 @@ func (s *Server) safeFindPathConf(name string) (*conf.Path, error) { return pathConf, err } -func (s *Server) middlewareOrigin(ctx *gin.Context) { - ctx.Header("Access-Control-Allow-Origin", s.AllowOrigin) - ctx.Header("Access-Control-Allow-Credentials", "true") - - // preflight requests +func (s *Server) middlewarePreflightRequests(ctx *gin.Context) { if ctx.Request.Method == http.MethodOptions && ctx.Request.Header.Get("Access-Control-Request-Method") != "" { ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET") diff --git a/internal/playback/server_test.go b/internal/playback/server_test.go index aebcf119..7db52032 100644 --- a/internal/playback/server_test.go +++ b/internal/playback/server_test.go @@ -18,7 +18,7 @@ import ( func TestPreflightRequest(t *testing.T) { s := &Server{ Address: "127.0.0.1:9996", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), Parent: test.NilLogger, diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go index 945a5cc4..5b171935 100644 --- a/internal/pprof/pprof.go +++ b/internal/pprof/pprof.go @@ -29,7 +29,7 @@ type PPROF struct { Encryption bool ServerKey string ServerCert string - AllowOrigin string + AllowOrigins []string TrustedProxies conf.IPNetworks ReadTimeout conf.Duration WriteTimeout conf.Duration @@ -44,13 +44,14 @@ func (pp *PPROF) Initialize() error { router := gin.New() router.SetTrustedProxies(pp.TrustedProxies.ToTrustedProxies()) //nolint:errcheck - router.Use(pp.middlewareOrigin) + router.Use(pp.middlewarePreflightRequests) router.Use(pp.middlewareAuth) pprof.Register(router) pp.httpServer = &httpp.Server{ Address: pp.Address, + AllowOrigins: pp.AllowOrigins, ReadTimeout: time.Duration(pp.ReadTimeout), WriteTimeout: time.Duration(pp.WriteTimeout), Encryption: pp.Encryption, @@ -80,11 +81,7 @@ func (pp *PPROF) Log(level logger.Level, format string, args ...any) { pp.Parent.Log(level, "[pprof] "+format, args...) } -func (pp *PPROF) middlewareOrigin(ctx *gin.Context) { - ctx.Header("Access-Control-Allow-Origin", pp.AllowOrigin) - ctx.Header("Access-Control-Allow-Credentials", "true") - - // preflight requests +func (pp *PPROF) middlewarePreflightRequests(ctx *gin.Context) { if ctx.Request.Method == http.MethodOptions && ctx.Request.Header.Get("Access-Control-Request-Method") != "" { ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET") diff --git a/internal/pprof/pprof_test.go b/internal/pprof/pprof_test.go index 1d61c9e6..0a57bfdb 100644 --- a/internal/pprof/pprof_test.go +++ b/internal/pprof/pprof_test.go @@ -17,7 +17,7 @@ import ( func TestPreflightRequest(t *testing.T) { s := &PPROF{ Address: "127.0.0.1:9999", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), Parent: test.NilLogger, @@ -56,7 +56,7 @@ func TestPprof(t *testing.T) { s := &PPROF{ Address: "127.0.0.1:9999", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: &test.AuthManager{ @@ -99,7 +99,7 @@ func TestAuthError(t *testing.T) { s := &PPROF{ Address: "127.0.0.1:9999", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: &test.AuthManager{ diff --git a/internal/protocols/httpp/handler_filter_requests_test.go b/internal/protocols/httpp/handler_filter_requests_test.go new file mode 100644 index 00000000..02d5e24b --- /dev/null +++ b/internal/protocols/httpp/handler_filter_requests_test.go @@ -0,0 +1,42 @@ +package httpp + +import ( + "net" + "net/http" + "strings" + "testing" + "time" + + "github.com/bluenviron/mediamtx/internal/test" + "github.com/stretchr/testify/require" +) + +func TestHandlerFilterRequests(t *testing.T) { + s := &Server{ + Address: "localhost:4555", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + Parent: test.NilLogger, + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }), + } + err := s.Initialize() + require.NoError(t, err) + defer s.Close() + + conn, err := net.Dial("tcp", "localhost:4555") + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte("OPTIONS / HTTP/1.1\n" + + "Host: localhost:8889\n\n")) + require.NoError(t, err) + + buf := make([]byte, 200) + n, err := conn.Read(buf) + require.NoError(t, err) + + res := strings.Split(string(buf[:n]), "\r\n") + require.Equal(t, "HTTP/1.1 200 OK", res[0]) +} diff --git a/internal/protocols/httpp/handler_origin.go b/internal/protocols/httpp/handler_origin.go new file mode 100644 index 00000000..f7f6a8e1 --- /dev/null +++ b/internal/protocols/httpp/handler_origin.go @@ -0,0 +1,88 @@ +package httpp + +import ( + "net" + "net/http" + "net/url" + "regexp" + "strings" +) + +func isOriginAllowed(origin string, allowOrigins []string) (string, bool) { + if len(allowOrigins) == 0 { + return "", false + } + + for _, o := range allowOrigins { + if o == "*" { + return o, true + } + } + + if origin == "" { + return "", false + } + + originURL, err := url.Parse(origin) + if err != nil || originURL.Scheme == "" { + return "", false + } + + if originURL.Port() == "" && originURL.Scheme != "" { + switch originURL.Scheme { + case "http": + originURL.Host = net.JoinHostPort(originURL.Host, "80") + case "https": + originURL.Host = net.JoinHostPort(originURL.Host, "443") + } + } + + for _, o := range allowOrigins { + allowedURL, errAllowed := url.Parse(o) + if errAllowed != nil { + continue + } + + if allowedURL.Port() == "" { + switch allowedURL.Scheme { + case "http": + allowedURL.Host = net.JoinHostPort(allowedURL.Host, "80") + case "https": + allowedURL.Host = net.JoinHostPort(allowedURL.Host, "443") + } + } + + if allowedURL.Scheme == originURL.Scheme && + allowedURL.Host == originURL.Host && + allowedURL.Port() == originURL.Port() { + return origin, true + } + + if strings.Contains(allowedURL.Host, "*") { + pattern := strings.ReplaceAll(allowedURL.Host, "*.", "(.*\\.)?") + pattern = strings.ReplaceAll(pattern, "*", ".*") + matched, errMatched := regexp.MatchString("^"+pattern+"$", originURL.Host) + if errMatched == nil && matched { + return origin, true + } + } + } + + return "", false +} + +// add Access-Control-Allow-Origin and Access-Control-Allow-Credentials headers. +type handlerOrigin struct { + h http.Handler + allowOrigins []string +} + +func (h *handlerOrigin) ServeHTTP(w http.ResponseWriter, r *http.Request) { + origin, ok := isOriginAllowed(r.Header.Get("Origin"), h.allowOrigins) + if ok { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + + h.h.ServeHTTP(w, r) +} diff --git a/internal/protocols/httpp/handler_origin_test.go b/internal/protocols/httpp/handler_origin_test.go new file mode 100644 index 00000000..1f07878d --- /dev/null +++ b/internal/protocols/httpp/handler_origin_test.go @@ -0,0 +1,87 @@ +package httpp + +import ( + "net/http" + "testing" + "time" + + "github.com/bluenviron/mediamtx/internal/test" + "github.com/stretchr/testify/require" +) + +func TestHandlerOrigin(t *testing.T) { + for _, ca := range []struct { + name string + origin string + allowedOrigins []string + expected string + }{ + { + "empty", + "", + []string{}, + "", + }, + { + "not allowed", + "http://another.com", + []string{"http://example.com"}, + "", + }, + { + "everything allowed, no origin", + "", + []string{"*"}, + "*", + }, + { + "everything allowed, with origin", + "https://example.com", + []string{"*"}, + "*", + }, + { + "allowed", + "https://example.org", + []string{"http://example.com", "https://example.org"}, + "https://example.org", + }, + { + "wildcard", + "https://test.example.org", + []string{"https://*.example.org"}, + "https://test.example.org", + }, + } { + t.Run(ca.name, func(t *testing.T) { + s := &Server{ + Address: "localhost:4555", + AllowOrigins: ca.allowedOrigins, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + Parent: test.NilLogger, + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }), + } + err := s.Initialize() + require.NoError(t, err) + defer s.Close() + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + req, err := http.NewRequest(http.MethodGet, "http://localhost:4555", nil) + require.NoError(t, err) + + req.Header.Set("Origin", ca.origin) + + res, err := hc.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, ca.expected, res.Header.Get("Access-Control-Allow-Origin")) + }) + } +} diff --git a/internal/protocols/httpp/server.go b/internal/protocols/httpp/server.go index 4efffe3b..c4f02176 100644 --- a/internal/protocols/httpp/server.go +++ b/internal/protocols/httpp/server.go @@ -32,6 +32,7 @@ func (nilWriter) Write(p []byte) (int, error) { // - filtering of invalid requests type Server struct { Address string + AllowOrigins []string ReadTimeout time.Duration WriteTimeout time.Duration Encryption bool @@ -100,8 +101,9 @@ func (s *Server) Initialize() error { } h := s.Handler - h = &handlerFilterRequests{h} + h = &handlerOrigin{h, s.AllowOrigins} h = &handlerServerHeader{h} + h = &handlerFilterRequests{h} h = &handlerLogger{h, s.Parent} h = &handlerExitOnPanic{h} h = &handlerWriteTimeout{h, s.WriteTimeout} diff --git a/internal/protocols/httpp/server_test.go b/internal/protocols/httpp/server_test.go index 8753c39d..6359e1f3 100644 --- a/internal/protocols/httpp/server_test.go +++ b/internal/protocols/httpp/server_test.go @@ -13,36 +13,6 @@ import ( "github.com/bluenviron/mediamtx/internal/test" ) -func TestFilterEmptyPath(t *testing.T) { - s := &Server{ - Address: "localhost:4555", - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - Parent: test.NilLogger, - Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - }), - } - err := s.Initialize() - require.NoError(t, err) - defer s.Close() - - conn, err := net.Dial("tcp", "localhost:4555") - require.NoError(t, err) - defer conn.Close() - - _, err = conn.Write([]byte("OPTIONS / HTTP/1.1\n" + - "Host: localhost:8889\n\n")) - require.NoError(t, err) - - buf := make([]byte, 200) - n, err := conn.Read(buf) - require.NoError(t, err) - - res := strings.Split(string(buf[:n]), "\r\n") - require.Equal(t, "HTTP/1.1 200 OK", res[0]) -} - func TestUnixSocket(t *testing.T) { s := &Server{ Address: "unix://http.sock", diff --git a/internal/servers/hls/http_server.go b/internal/servers/hls/http_server.go index ac5436ee..fcf429ec 100644 --- a/internal/servers/hls/http_server.go +++ b/internal/servers/hls/http_server.go @@ -39,7 +39,7 @@ type httpServer struct { encryption bool serverKey string serverCert string - allowOrigin string + allowOrigins []string trustedProxies conf.IPNetworks readTimeout conf.Duration writeTimeout conf.Duration @@ -53,12 +53,13 @@ func (s *httpServer) initialize() error { router := gin.New() router.SetTrustedProxies(s.trustedProxies.ToTrustedProxies()) //nolint:errcheck - router.Use(s.middlewareOrigin) + router.Use(s.middlewarePreflightRequests) router.Use(s.onRequest) s.inner = &httpp.Server{ Address: s.address, + AllowOrigins: s.allowOrigins, ReadTimeout: time.Duration(s.readTimeout), WriteTimeout: time.Duration(s.writeTimeout), Encryption: s.encryption, @@ -84,11 +85,7 @@ func (s *httpServer) close() { s.inner.Close() } -func (s *httpServer) middlewareOrigin(ctx *gin.Context) { - ctx.Header("Access-Control-Allow-Origin", s.allowOrigin) - ctx.Header("Access-Control-Allow-Credentials", "true") - - // preflight requests +func (s *httpServer) middlewarePreflightRequests(ctx *gin.Context) { if ctx.Request.Method == http.MethodOptions && ctx.Request.Header.Get("Access-Control-Request-Method") != "" { ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET") diff --git a/internal/servers/hls/server.go b/internal/servers/hls/server.go index 74175af9..d4edffbe 100644 --- a/internal/servers/hls/server.go +++ b/internal/servers/hls/server.go @@ -74,7 +74,7 @@ type Server struct { Encryption bool ServerKey string ServerCert string - AllowOrigin string + AllowOrigins []string TrustedProxies conf.IPNetworks AlwaysRemux bool Variant conf.HLSVariant @@ -124,7 +124,7 @@ func (s *Server) Initialize() error { encryption: s.Encryption, serverKey: s.ServerKey, serverCert: s.ServerCert, - allowOrigin: s.AllowOrigin, + allowOrigins: s.AllowOrigins, trustedProxies: s.TrustedProxies, readTimeout: s.ReadTimeout, writeTimeout: s.WriteTimeout, diff --git a/internal/servers/hls/server_test.go b/internal/servers/hls/server_test.go index 356f0ace..d257df73 100644 --- a/internal/servers/hls/server_test.go +++ b/internal/servers/hls/server_test.go @@ -68,7 +68,7 @@ func (pa *dummyPath) RemoveReader(_ defs.PathRemoveReaderReq) { func TestServerPreflightRequest(t *testing.T) { s := &Server{ Address: "127.0.0.1:8888", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), PathManager: &dummyPathManager{}, @@ -131,7 +131,6 @@ func TestServerNotFound(t *testing.T) { SegmentDuration: conf.Duration(1 * time.Second), PartDuration: conf.Duration(200 * time.Millisecond), SegmentMaxSize: 50 * 1024 * 1024, - AllowOrigin: "", TrustedProxies: conf.IPNetworks{}, Directory: "", ReadTimeout: conf.Duration(10 * time.Second), @@ -433,7 +432,6 @@ func TestServerDirectory(t *testing.T) { SegmentDuration: conf.Duration(1 * time.Second), PartDuration: conf.Duration(200 * time.Millisecond), SegmentMaxSize: 50 * 1024 * 1024, - AllowOrigin: "", TrustedProxies: conf.IPNetworks{}, Directory: filepath.Join(dir, "mydir"), ReadTimeout: conf.Duration(10 * time.Second), diff --git a/internal/servers/webrtc/http_server.go b/internal/servers/webrtc/http_server.go index fb7332c2..f2b89c94 100644 --- a/internal/servers/webrtc/http_server.go +++ b/internal/servers/webrtc/http_server.go @@ -76,7 +76,7 @@ type httpServer struct { encryption bool serverKey string serverCert string - allowOrigin string + allowOrigins []string trustedProxies conf.IPNetworks readTimeout conf.Duration writeTimeout conf.Duration @@ -90,12 +90,13 @@ func (s *httpServer) initialize() error { router := gin.New() router.SetTrustedProxies(s.trustedProxies.ToTrustedProxies()) //nolint:errcheck - router.Use(s.middlewareOrigin) + router.Use(s.middlewarePreflightRequests) router.Use(s.onRequest) s.inner = &httpp.Server{ Address: s.address, + AllowOrigins: s.allowOrigins, ReadTimeout: time.Duration(s.readTimeout), WriteTimeout: time.Duration(s.writeTimeout), Encryption: s.encryption, @@ -319,11 +320,7 @@ func (s *httpServer) onPage(ctx *gin.Context, pathName string, publish bool) { } } -func (s *httpServer) middlewareOrigin(ctx *gin.Context) { - ctx.Header("Access-Control-Allow-Origin", s.allowOrigin) - ctx.Header("Access-Control-Allow-Credentials", "true") - - // preflight requests +func (s *httpServer) middlewarePreflightRequests(ctx *gin.Context) { if ctx.Request.Method == http.MethodOptions && ctx.Request.Header.Get("Access-Control-Request-Method") != "" { ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET, POST, PATCH, DELETE") diff --git a/internal/servers/webrtc/server.go b/internal/servers/webrtc/server.go index 1851d12e..3711ac1e 100644 --- a/internal/servers/webrtc/server.go +++ b/internal/servers/webrtc/server.go @@ -190,7 +190,7 @@ type Server struct { Encryption bool ServerKey string ServerCert string - AllowOrigin string + AllowOrigins []string TrustedProxies conf.IPNetworks ReadTimeout conf.Duration WriteTimeout conf.Duration @@ -254,7 +254,7 @@ func (s *Server) Initialize() error { encryption: s.Encryption, serverKey: s.ServerKey, serverCert: s.ServerCert, - allowOrigin: s.AllowOrigin, + allowOrigins: s.AllowOrigins, trustedProxies: s.TrustedProxies, readTimeout: s.ReadTimeout, writeTimeout: s.WriteTimeout, diff --git a/internal/servers/webrtc/server_test.go b/internal/servers/webrtc/server_test.go index 63b02f43..0f12949c 100644 --- a/internal/servers/webrtc/server_test.go +++ b/internal/servers/webrtc/server_test.go @@ -66,7 +66,7 @@ func initializeTestServer(t *testing.T) *Server { s := &Server{ Address: "127.0.0.1:8886", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, TrustedProxies: conf.IPNetworks{}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), diff --git a/mediamtx.yml b/mediamtx.yml index 6bc9d9f0..8b20bfa6 100644 --- a/mediamtx.yml +++ b/mediamtx.yml @@ -158,8 +158,9 @@ apiEncryption: no apiServerKey: server.key # Path to the server certificate. apiServerCert: server.crt -# Value of the Access-Control-Allow-Origin header provided in every HTTP response. -apiAllowOrigin: '*' +# List of allowed CORS origins. +# Supports wildcards: ['http://*.example.com'] +apiAllowOrigins: ['*'] # List of IPs or CIDRs of proxies placed before the HTTP server. # If the server receives a request from one of these entries, IP in logs # will be taken from the X-Forwarded-For header. @@ -181,8 +182,9 @@ metricsEncryption: no metricsServerKey: server.key # Path to the server certificate. metricsServerCert: server.crt -# Value of the Access-Control-Allow-Origin header provided in every HTTP response. -metricsAllowOrigin: '*' +# List of allowed CORS origins. +# Supports wildcards: ['http://*.example.com'] +metricsAllowOrigins: ['*'] # List of IPs or CIDRs of proxies placed before the HTTP server. # If the server receives a request from one of these entries, IP in logs # will be taken from the X-Forwarded-For header. @@ -204,8 +206,9 @@ pprofEncryption: no pprofServerKey: server.key # Path to the server certificate. pprofServerCert: server.crt -# Value of the Access-Control-Allow-Origin header provided in every HTTP response. -pprofAllowOrigin: '*' +# List of allowed CORS origins. +# Supports wildcards: ['http://*.example.com'] +pprofAllowOrigins: ['*'] # List of IPs or CIDRs of proxies placed before the HTTP server. # If the server receives a request from one of these entries, IP in logs # will be taken from the X-Forwarded-For header. @@ -227,8 +230,9 @@ playbackEncryption: no playbackServerKey: server.key # Path to the server certificate. playbackServerCert: server.crt -# Value of the Access-Control-Allow-Origin header provided in every HTTP response. -playbackAllowOrigin: '*' +# List of allowed CORS origins. +# Supports wildcards: ['http://*.example.com'] +playbackAllowOrigins: ['*'] # List of IPs or CIDRs of proxies placed before the HTTP server. # If the server receives a request from one of these entries, IP in logs # will be taken from the X-Forwarded-For header. @@ -319,9 +323,9 @@ hlsEncryption: no hlsServerKey: server.key # Path to the server certificate. hlsServerCert: server.crt -# Value of the Access-Control-Allow-Origin header provided in every HTTP response. -# This allows to play the HLS stream from an external website. -hlsAllowOrigin: '*' +# List of allowed CORS origins. +# Supports wildcards: ['http://*.example.com'] +hlsAllowOrigins: ['*'] # List of IPs or CIDRs of proxies placed before the HLS server. # If the server receives a request from one of these entries, IP in logs # will be taken from the X-Forwarded-For header. @@ -377,9 +381,9 @@ webrtcEncryption: no webrtcServerKey: server.key # Path to the server certificate. webrtcServerCert: server.crt -# Value of the Access-Control-Allow-Origin header provided in every HTTP response. -# This allows to play the WebRTC stream from an external website. -webrtcAllowOrigin: '*' +# List of allowed CORS origins. +# Supports wildcards: ['http://*.example.com'] +webrtcAllowOrigins: ['*'] # List of IPs or CIDRs of proxies placed before the WebRTC server. # If the server receives a request from one of these entries, IP in logs # will be taken from the X-Forwarded-For header.