mirror of
https://github.com/bluenviron/mediamtx.git
synced 2025-12-29 06:22:00 -08:00
feat(api): add support for multiple allowed origins with wildcards
Introduce a new `allowOrigins` configuration field (slice of strings) to replace the single-string `allowOrigin`. Add logic to validate origins against allowed patterns, supporting: - Exact matches - Wildcard domain matching (e.g., http://*.example.com) Update configuration handling in core and test components to support the new format. Maintain backwards compatibility with the old `allowOrigin` field. Update `openapi.yaml` to reflect the new configuration field.
This commit is contained in:
parent
14ab95f39c
commit
61c67cc585
6 changed files with 190 additions and 11 deletions
|
|
@ -124,8 +124,10 @@ components:
|
|||
type: string
|
||||
apiServerCert:
|
||||
type: string
|
||||
apiAllowOrigin:
|
||||
type: string
|
||||
apiAllowOrigins:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
apiTrustedProxies:
|
||||
type: array
|
||||
items:
|
||||
|
|
|
|||
|
|
@ -6,8 +6,10 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -76,6 +78,71 @@ func recordingsOfPath(
|
|||
return ret
|
||||
}
|
||||
|
||||
var errOriginNotAllowed = errors.New("origin not allowed")
|
||||
|
||||
func isOriginAllowed(origin string, allowOrigins []string) (string, error) {
|
||||
if len(allowOrigins) == 0 {
|
||||
return "", errOriginNotAllowed
|
||||
}
|
||||
|
||||
for _, o := range allowOrigins {
|
||||
if o == "*" {
|
||||
return o, nil
|
||||
}
|
||||
}
|
||||
|
||||
if origin == "" {
|
||||
return "", errOriginNotAllowed
|
||||
}
|
||||
|
||||
originURL, err := url.Parse(origin)
|
||||
if err != nil || originURL.Scheme == "" {
|
||||
return "", errOriginNotAllowed
|
||||
}
|
||||
|
||||
if originURL.Port() == "" && originURL.Scheme != "" {
|
||||
switch originURL.Scheme {
|
||||
case "http":
|
||||
originURL.Host = net.JoinHostPort(originURL.Host, "80")
|
||||
case "https":
|
||||
originURL.Host = net.JoinHostPort(originURL.Host, "443")
|
||||
}
|
||||
}
|
||||
|
||||
for _, o := range allowOrigins {
|
||||
allowedURL, errAllowed := url.Parse(o)
|
||||
if errAllowed != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if allowedURL.Port() == "" {
|
||||
switch allowedURL.Scheme {
|
||||
case "http":
|
||||
allowedURL.Host = net.JoinHostPort(allowedURL.Host, "80")
|
||||
case "https":
|
||||
allowedURL.Host = net.JoinHostPort(allowedURL.Host, "443")
|
||||
}
|
||||
}
|
||||
|
||||
if allowedURL.Scheme == originURL.Scheme &&
|
||||
allowedURL.Host == originURL.Host &&
|
||||
allowedURL.Port() == originURL.Port() {
|
||||
return origin, nil
|
||||
}
|
||||
|
||||
if strings.Contains(allowedURL.Host, "*") {
|
||||
pattern := strings.ReplaceAll(allowedURL.Host, "*.", "(.*\\.)?")
|
||||
pattern = strings.ReplaceAll(pattern, "*", ".*")
|
||||
matched, errMatched := regexp.MatchString("^"+pattern+"$", originURL.Host)
|
||||
if errMatched == nil && matched {
|
||||
return origin, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", errOriginNotAllowed
|
||||
}
|
||||
|
||||
type apiAuthManager interface {
|
||||
Authenticate(req *auth.Request) *auth.Error
|
||||
RefreshJWTJWKS()
|
||||
|
|
@ -94,7 +161,7 @@ type API struct {
|
|||
Encryption bool
|
||||
ServerKey string
|
||||
ServerCert string
|
||||
AllowOrigin string
|
||||
AllowOrigins []string
|
||||
TrustedProxies conf.IPNetworks
|
||||
ReadTimeout conf.Duration
|
||||
WriteTimeout conf.Duration
|
||||
|
|
@ -235,7 +302,12 @@ func (a *API) writeError(ctx *gin.Context, status int, err error) {
|
|||
}
|
||||
|
||||
func (a *API) middlewareOrigin(ctx *gin.Context) {
|
||||
ctx.Header("Access-Control-Allow-Origin", a.AllowOrigin)
|
||||
origin, err := isOriginAllowed(ctx.Request.Header.Get("Origin"), a.AllowOrigins)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Header("Access-Control-Allow-Origin", origin)
|
||||
ctx.Header("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
// preflight requests
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package api
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
|
@ -83,7 +84,7 @@ func checkError(t *testing.T, msg string, body io.Reader) {
|
|||
func TestPreflightRequest(t *testing.T) {
|
||||
api := API{
|
||||
Address: "localhost:9997",
|
||||
AllowOrigin: "*",
|
||||
AllowOrigins: []string{"*"},
|
||||
ReadTimeout: conf.Duration(10 * time.Second),
|
||||
WriteTimeout: conf.Duration(10 * time.Second),
|
||||
AuthManager: test.NilAuthManager,
|
||||
|
|
@ -118,6 +119,98 @@ func TestPreflightRequest(t *testing.T) {
|
|||
require.Equal(t, byts, []byte{})
|
||||
}
|
||||
|
||||
func TestMiddlewareOrigin(t *testing.T) {
|
||||
allowOrigins := []string{}
|
||||
origin := ""
|
||||
allowedOrigin, err := isOriginAllowed(origin, allowOrigins)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for empty origin, got nil")
|
||||
}
|
||||
if allowedOrigin != "" {
|
||||
t.Fatalf("expected empty allowed origin, got %s", allowedOrigin)
|
||||
}
|
||||
|
||||
allowOrigins = []string{"http://example.com"}
|
||||
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for empty origin with allowed origins, got nil")
|
||||
}
|
||||
if allowedOrigin != "" {
|
||||
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
|
||||
}
|
||||
|
||||
allowOrigins = []string{"*"}
|
||||
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for wildcard origin: %v", err)
|
||||
}
|
||||
if allowedOrigin != "*" {
|
||||
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
|
||||
}
|
||||
|
||||
origin = "http://example.com"
|
||||
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for matching wildcard: %v", err)
|
||||
}
|
||||
if allowedOrigin != "*" {
|
||||
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
|
||||
}
|
||||
|
||||
allowOrigins = []string{"http://example.com", "https://example.org"}
|
||||
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for matching origin: %v", err)
|
||||
}
|
||||
if allowedOrigin != origin {
|
||||
t.Fatalf("expected empty allowed origin, got %s", allowedOrigin)
|
||||
}
|
||||
|
||||
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for matching origin: %v", err)
|
||||
}
|
||||
if allowedOrigin != origin {
|
||||
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
|
||||
}
|
||||
|
||||
origin = "https://example.org"
|
||||
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for matching origin: %v", err)
|
||||
}
|
||||
if allowedOrigin != origin {
|
||||
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
|
||||
}
|
||||
|
||||
allowedOrigin, err = isOriginAllowed("http://notallowed.com", allowOrigins)
|
||||
if !errors.Is(err, errOriginNotAllowed) {
|
||||
t.Fatalf("expected errOriginNotAllowed for disallowed origin, got %v", err)
|
||||
}
|
||||
if allowedOrigin != "" {
|
||||
t.Fatalf("expected empty allowed origin, got %s", allowedOrigin)
|
||||
}
|
||||
|
||||
allowOrigins = []string{"http://*.example.com"}
|
||||
origin = "http://test.example.com"
|
||||
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for wildcard subdomain: %v", err)
|
||||
}
|
||||
if allowedOrigin != origin {
|
||||
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
|
||||
}
|
||||
|
||||
origin = "http://example.com"
|
||||
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for exact subdomain match: %v", err)
|
||||
}
|
||||
if allowedOrigin != origin {
|
||||
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInfo(t *testing.T) {
|
||||
cnf := tempConf(t, "api: yes\n")
|
||||
|
||||
|
|
|
|||
|
|
@ -182,7 +182,8 @@ type Conf struct {
|
|||
APIEncryption bool `json:"apiEncryption"`
|
||||
APIServerKey string `json:"apiServerKey"`
|
||||
APIServerCert string `json:"apiServerCert"`
|
||||
APIAllowOrigin string `json:"apiAllowOrigin"`
|
||||
APIAllowOrigin *string `json:"apiAllowOrigin,omitempty"` // deprecated
|
||||
APIAllowOrigins []string `json:"apiAllowOrigins"`
|
||||
APITrustedProxies IPNetworks `json:"apiTrustedProxies"`
|
||||
|
||||
// Metrics
|
||||
|
|
@ -340,7 +341,7 @@ func (conf *Conf) setDefaults() {
|
|||
conf.APIAddress = ":9997"
|
||||
conf.APIServerKey = "server.key"
|
||||
conf.APIServerCert = "server.crt"
|
||||
conf.APIAllowOrigin = "*"
|
||||
conf.APIAllowOrigins = []string{"*"}
|
||||
|
||||
// Metrics
|
||||
conf.MetricsAddress = ":9998"
|
||||
|
|
@ -607,6 +608,13 @@ func (conf *Conf) Validate(l logger.Writer) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Control API
|
||||
|
||||
if conf.APIAllowOrigin != nil {
|
||||
l.Log(logger.Warn, "parameter 'apiAllowOrigin' is deprecated and has been replaced with 'apiAllowOrigins'")
|
||||
conf.APIAllowOrigins = []string{*conf.APIAllowOrigin}
|
||||
}
|
||||
|
||||
// RTSP
|
||||
|
||||
if conf.RTSPDisable != nil {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
|
@ -650,7 +651,7 @@ func (p *Core) createResources(initial bool) error {
|
|||
Encryption: p.conf.APIEncryption,
|
||||
ServerKey: p.conf.APIServerKey,
|
||||
ServerCert: p.conf.APIServerCert,
|
||||
AllowOrigin: p.conf.APIAllowOrigin,
|
||||
AllowOrigins: p.conf.APIAllowOrigins,
|
||||
TrustedProxies: p.conf.APITrustedProxies,
|
||||
ReadTimeout: p.conf.ReadTimeout,
|
||||
WriteTimeout: p.conf.WriteTimeout,
|
||||
|
|
@ -911,7 +912,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) {
|
|||
newConf.APIEncryption != p.conf.APIEncryption ||
|
||||
newConf.APIServerKey != p.conf.APIServerKey ||
|
||||
newConf.APIServerCert != p.conf.APIServerCert ||
|
||||
newConf.APIAllowOrigin != p.conf.APIAllowOrigin ||
|
||||
slices.Equal(newConf.APIAllowOrigins, p.conf.APIAllowOrigins) ||
|
||||
!reflect.DeepEqual(newConf.APITrustedProxies, p.conf.APITrustedProxies) ||
|
||||
newConf.ReadTimeout != p.conf.ReadTimeout ||
|
||||
newConf.WriteTimeout != p.conf.WriteTimeout ||
|
||||
|
|
|
|||
|
|
@ -158,8 +158,11 @@ apiEncryption: no
|
|||
apiServerKey: server.key
|
||||
# Path to the server certificate.
|
||||
apiServerCert: server.crt
|
||||
# Value of the Access-Control-Allow-Origin header provided in every HTTP response.
|
||||
apiAllowOrigin: '*'
|
||||
# List of allowed origins.
|
||||
# Supports wildcards: ['http://*.example.com']
|
||||
# If apiAllowOrigins is set to '*', the Access-Control-Allow-Origin response will be '*',
|
||||
# even if no Origin was sent from the client.
|
||||
apiAllowOrigins: ['*']
|
||||
# List of IPs or CIDRs of proxies placed before the HTTP server.
|
||||
# If the server receives a request from one of these entries, IP in logs
|
||||
# will be taken from the X-Forwarded-For header.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue