diff --git a/api/openapi.yaml b/api/openapi.yaml index 0428214a..bb70706d 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -124,8 +124,10 @@ components: type: string apiServerCert: type: string - apiAllowOrigin: - type: string + apiAllowOrigins: + type: array + items: + type: string apiTrustedProxies: type: array items: diff --git a/internal/api/api.go b/internal/api/api.go index bd6d33aa..065f277a 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -6,8 +6,10 @@ import ( "fmt" "net" "net/http" + "net/url" "os" "reflect" + "regexp" "sort" "strings" "sync" @@ -76,6 +78,71 @@ func recordingsOfPath( return ret } +var errOriginNotAllowed = errors.New("origin not allowed") + +func isOriginAllowed(origin string, allowOrigins []string) (string, error) { + if len(allowOrigins) == 0 { + return "", errOriginNotAllowed + } + + for _, o := range allowOrigins { + if o == "*" { + return o, nil + } + } + + if origin == "" { + return "", errOriginNotAllowed + } + + originURL, err := url.Parse(origin) + if err != nil || originURL.Scheme == "" { + return "", errOriginNotAllowed + } + + if originURL.Port() == "" && originURL.Scheme != "" { + switch originURL.Scheme { + case "http": + originURL.Host = net.JoinHostPort(originURL.Host, "80") + case "https": + originURL.Host = net.JoinHostPort(originURL.Host, "443") + } + } + + for _, o := range allowOrigins { + allowedURL, errAllowed := url.Parse(o) + if errAllowed != nil { + continue + } + + if allowedURL.Port() == "" { + switch allowedURL.Scheme { + case "http": + allowedURL.Host = net.JoinHostPort(allowedURL.Host, "80") + case "https": + allowedURL.Host = net.JoinHostPort(allowedURL.Host, "443") + } + } + + if allowedURL.Scheme == originURL.Scheme && + allowedURL.Host == originURL.Host && + allowedURL.Port() == originURL.Port() { + return origin, nil + } + + if strings.Contains(allowedURL.Host, "*") { + pattern := strings.ReplaceAll(allowedURL.Host, "*.", "(.*\\.)?") + pattern = strings.ReplaceAll(pattern, "*", ".*") + matched, errMatched := regexp.MatchString("^"+pattern+"$", originURL.Host) + if errMatched == nil && matched { + return origin, nil + } + } + } + + return "", errOriginNotAllowed +} + type apiAuthManager interface { Authenticate(req *auth.Request) *auth.Error RefreshJWTJWKS() @@ -94,7 +161,7 @@ type API struct { Encryption bool ServerKey string ServerCert string - AllowOrigin string + AllowOrigins []string TrustedProxies conf.IPNetworks ReadTimeout conf.Duration WriteTimeout conf.Duration @@ -235,7 +302,12 @@ func (a *API) writeError(ctx *gin.Context, status int, err error) { } func (a *API) middlewareOrigin(ctx *gin.Context) { - ctx.Header("Access-Control-Allow-Origin", a.AllowOrigin) + origin, err := isOriginAllowed(ctx.Request.Header.Get("Origin"), a.AllowOrigins) + if err != nil { + return + } + + ctx.Header("Access-Control-Allow-Origin", origin) ctx.Header("Access-Control-Allow-Credentials", "true") // preflight requests diff --git a/internal/api/api_test.go b/internal/api/api_test.go index d216e52b..46c4e3ca 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -3,6 +3,7 @@ package api import ( "bytes" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -83,7 +84,7 @@ func checkError(t *testing.T, msg string, body io.Reader) { func TestPreflightRequest(t *testing.T) { api := API{ Address: "localhost:9997", - AllowOrigin: "*", + AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), AuthManager: test.NilAuthManager, @@ -118,6 +119,98 @@ func TestPreflightRequest(t *testing.T) { require.Equal(t, byts, []byte{}) } +func TestMiddlewareOrigin(t *testing.T) { + allowOrigins := []string{} + origin := "" + allowedOrigin, err := isOriginAllowed(origin, allowOrigins) + if err == nil { + t.Fatalf("expected error for empty origin, got nil") + } + if allowedOrigin != "" { + t.Fatalf("expected empty allowed origin, got %s", allowedOrigin) + } + + allowOrigins = []string{"http://example.com"} + allowedOrigin, err = isOriginAllowed(origin, allowOrigins) + if err == nil { + t.Fatalf("expected error for empty origin with allowed origins, got nil") + } + if allowedOrigin != "" { + t.Fatalf("unexpected allowed origin: %s", allowedOrigin) + } + + allowOrigins = []string{"*"} + allowedOrigin, err = isOriginAllowed(origin, allowOrigins) + if err != nil { + t.Fatalf("unexpected error for wildcard origin: %v", err) + } + if allowedOrigin != "*" { + t.Fatalf("unexpected allowed origin: %s", allowedOrigin) + } + + origin = "http://example.com" + allowedOrigin, err = isOriginAllowed(origin, allowOrigins) + if err != nil { + t.Fatalf("unexpected error for matching wildcard: %v", err) + } + if allowedOrigin != "*" { + t.Fatalf("unexpected allowed origin: %s", allowedOrigin) + } + + allowOrigins = []string{"http://example.com", "https://example.org"} + allowedOrigin, err = isOriginAllowed(origin, allowOrigins) + if err != nil { + t.Fatalf("unexpected error for matching origin: %v", err) + } + if allowedOrigin != origin { + t.Fatalf("expected empty allowed origin, got %s", allowedOrigin) + } + + allowedOrigin, err = isOriginAllowed(origin, allowOrigins) + if err != nil { + t.Fatalf("unexpected error for matching origin: %v", err) + } + if allowedOrigin != origin { + t.Fatalf("unexpected allowed origin: %s", allowedOrigin) + } + + origin = "https://example.org" + allowedOrigin, err = isOriginAllowed(origin, allowOrigins) + if err != nil { + t.Fatalf("unexpected error for matching origin: %v", err) + } + if allowedOrigin != origin { + t.Fatalf("unexpected allowed origin: %s", allowedOrigin) + } + + allowedOrigin, err = isOriginAllowed("http://notallowed.com", allowOrigins) + if !errors.Is(err, errOriginNotAllowed) { + t.Fatalf("expected errOriginNotAllowed for disallowed origin, got %v", err) + } + if allowedOrigin != "" { + t.Fatalf("expected empty allowed origin, got %s", allowedOrigin) + } + + allowOrigins = []string{"http://*.example.com"} + origin = "http://test.example.com" + allowedOrigin, err = isOriginAllowed(origin, allowOrigins) + if err != nil { + t.Fatalf("unexpected error for wildcard subdomain: %v", err) + } + if allowedOrigin != origin { + t.Fatalf("unexpected allowed origin: %s", allowedOrigin) + } + + origin = "http://example.com" + allowedOrigin, err = isOriginAllowed(origin, allowOrigins) + if err != nil { + t.Fatalf("unexpected error for exact subdomain match: %v", err) + } + if allowedOrigin != origin { + t.Fatalf("unexpected allowed origin: %s", allowedOrigin) + } +} + func TestInfo(t *testing.T) { cnf := tempConf(t, "api: yes\n") diff --git a/internal/conf/conf.go b/internal/conf/conf.go index 8595adda..1a64b4cb 100644 --- a/internal/conf/conf.go +++ b/internal/conf/conf.go @@ -182,7 +182,8 @@ type Conf struct { APIEncryption bool `json:"apiEncryption"` APIServerKey string `json:"apiServerKey"` APIServerCert string `json:"apiServerCert"` - APIAllowOrigin string `json:"apiAllowOrigin"` + APIAllowOrigin *string `json:"apiAllowOrigin,omitempty"` // deprecated + APIAllowOrigins []string `json:"apiAllowOrigins"` APITrustedProxies IPNetworks `json:"apiTrustedProxies"` // Metrics @@ -340,7 +341,7 @@ func (conf *Conf) setDefaults() { conf.APIAddress = ":9997" conf.APIServerKey = "server.key" conf.APIServerCert = "server.crt" - conf.APIAllowOrigin = "*" + conf.APIAllowOrigins = []string{"*"} // Metrics conf.MetricsAddress = ":9998" @@ -607,6 +608,13 @@ func (conf *Conf) Validate(l logger.Writer) error { } } + // Control API + + if conf.APIAllowOrigin != nil { + l.Log(logger.Warn, "parameter 'apiAllowOrigin' is deprecated and has been replaced with 'apiAllowOrigins'") + conf.APIAllowOrigins = []string{*conf.APIAllowOrigin} + } + // RTSP if conf.RTSPDisable != nil { diff --git a/internal/core/core.go b/internal/core/core.go index b6907dc0..62c5c023 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -10,6 +10,7 @@ import ( "path/filepath" "reflect" "runtime" + "slices" "strings" "syscall" "time" @@ -650,7 +651,7 @@ func (p *Core) createResources(initial bool) error { Encryption: p.conf.APIEncryption, ServerKey: p.conf.APIServerKey, ServerCert: p.conf.APIServerCert, - AllowOrigin: p.conf.APIAllowOrigin, + AllowOrigins: p.conf.APIAllowOrigins, TrustedProxies: p.conf.APITrustedProxies, ReadTimeout: p.conf.ReadTimeout, WriteTimeout: p.conf.WriteTimeout, @@ -911,7 +912,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) { newConf.APIEncryption != p.conf.APIEncryption || newConf.APIServerKey != p.conf.APIServerKey || newConf.APIServerCert != p.conf.APIServerCert || - newConf.APIAllowOrigin != p.conf.APIAllowOrigin || + slices.Equal(newConf.APIAllowOrigins, p.conf.APIAllowOrigins) || !reflect.DeepEqual(newConf.APITrustedProxies, p.conf.APITrustedProxies) || newConf.ReadTimeout != p.conf.ReadTimeout || newConf.WriteTimeout != p.conf.WriteTimeout || diff --git a/mediamtx.yml b/mediamtx.yml index 6bc9d9f0..ae020731 100644 --- a/mediamtx.yml +++ b/mediamtx.yml @@ -158,8 +158,11 @@ apiEncryption: no apiServerKey: server.key # Path to the server certificate. apiServerCert: server.crt -# Value of the Access-Control-Allow-Origin header provided in every HTTP response. -apiAllowOrigin: '*' +# List of allowed origins. +# Supports wildcards: ['http://*.example.com'] +# If apiAllowOrigins is set to '*', the Access-Control-Allow-Origin response will be '*', +# even if no Origin was sent from the client. +apiAllowOrigins: ['*'] # List of IPs or CIDRs of proxies placed before the HTTP server. # If the server receives a request from one of these entries, IP in logs # will be taken from the X-Forwarded-For header.