feat(api): add support for multiple allowed origins with wildcards

Introduce a new `allowOrigins` configuration field (slice of strings) to
replace the single-string `allowOrigin`.
Add logic to validate origins against allowed patterns, supporting:
- Exact matches
- Wildcard domain matching (e.g., http://*.example.com)

Update configuration handling in core and test components to support the
new format.
Maintain backwards compatibility with the old `allowOrigin` field.
Update `openapi.yaml` to reflect the new configuration field.
This commit is contained in:
Kim Adrian Huynh 2025-11-02 11:13:00 +01:00 committed by aler9
parent 14ab95f39c
commit 61c67cc585
6 changed files with 190 additions and 11 deletions

View file

@ -124,8 +124,10 @@ components:
type: string
apiServerCert:
type: string
apiAllowOrigin:
type: string
apiAllowOrigins:
type: array
items:
type: string
apiTrustedProxies:
type: array
items:

View file

@ -6,8 +6,10 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"os"
"reflect"
"regexp"
"sort"
"strings"
"sync"
@ -76,6 +78,71 @@ func recordingsOfPath(
return ret
}
var errOriginNotAllowed = errors.New("origin not allowed")
func isOriginAllowed(origin string, allowOrigins []string) (string, error) {
if len(allowOrigins) == 0 {
return "", errOriginNotAllowed
}
for _, o := range allowOrigins {
if o == "*" {
return o, nil
}
}
if origin == "" {
return "", errOriginNotAllowed
}
originURL, err := url.Parse(origin)
if err != nil || originURL.Scheme == "" {
return "", errOriginNotAllowed
}
if originURL.Port() == "" && originURL.Scheme != "" {
switch originURL.Scheme {
case "http":
originURL.Host = net.JoinHostPort(originURL.Host, "80")
case "https":
originURL.Host = net.JoinHostPort(originURL.Host, "443")
}
}
for _, o := range allowOrigins {
allowedURL, errAllowed := url.Parse(o)
if errAllowed != nil {
continue
}
if allowedURL.Port() == "" {
switch allowedURL.Scheme {
case "http":
allowedURL.Host = net.JoinHostPort(allowedURL.Host, "80")
case "https":
allowedURL.Host = net.JoinHostPort(allowedURL.Host, "443")
}
}
if allowedURL.Scheme == originURL.Scheme &&
allowedURL.Host == originURL.Host &&
allowedURL.Port() == originURL.Port() {
return origin, nil
}
if strings.Contains(allowedURL.Host, "*") {
pattern := strings.ReplaceAll(allowedURL.Host, "*.", "(.*\\.)?")
pattern = strings.ReplaceAll(pattern, "*", ".*")
matched, errMatched := regexp.MatchString("^"+pattern+"$", originURL.Host)
if errMatched == nil && matched {
return origin, nil
}
}
}
return "", errOriginNotAllowed
}
type apiAuthManager interface {
Authenticate(req *auth.Request) *auth.Error
RefreshJWTJWKS()
@ -94,7 +161,7 @@ type API struct {
Encryption bool
ServerKey string
ServerCert string
AllowOrigin string
AllowOrigins []string
TrustedProxies conf.IPNetworks
ReadTimeout conf.Duration
WriteTimeout conf.Duration
@ -235,7 +302,12 @@ func (a *API) writeError(ctx *gin.Context, status int, err error) {
}
func (a *API) middlewareOrigin(ctx *gin.Context) {
ctx.Header("Access-Control-Allow-Origin", a.AllowOrigin)
origin, err := isOriginAllowed(ctx.Request.Header.Get("Origin"), a.AllowOrigins)
if err != nil {
return
}
ctx.Header("Access-Control-Allow-Origin", origin)
ctx.Header("Access-Control-Allow-Credentials", "true")
// preflight requests

View file

@ -3,6 +3,7 @@ package api
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@ -83,7 +84,7 @@ func checkError(t *testing.T, msg string, body io.Reader) {
func TestPreflightRequest(t *testing.T) {
api := API{
Address: "localhost:9997",
AllowOrigin: "*",
AllowOrigins: []string{"*"},
ReadTimeout: conf.Duration(10 * time.Second),
WriteTimeout: conf.Duration(10 * time.Second),
AuthManager: test.NilAuthManager,
@ -118,6 +119,98 @@ func TestPreflightRequest(t *testing.T) {
require.Equal(t, byts, []byte{})
}
func TestMiddlewareOrigin(t *testing.T) {
allowOrigins := []string{}
origin := ""
allowedOrigin, err := isOriginAllowed(origin, allowOrigins)
if err == nil {
t.Fatalf("expected error for empty origin, got nil")
}
if allowedOrigin != "" {
t.Fatalf("expected empty allowed origin, got %s", allowedOrigin)
}
allowOrigins = []string{"http://example.com"}
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
if err == nil {
t.Fatalf("expected error for empty origin with allowed origins, got nil")
}
if allowedOrigin != "" {
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
}
allowOrigins = []string{"*"}
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
if err != nil {
t.Fatalf("unexpected error for wildcard origin: %v", err)
}
if allowedOrigin != "*" {
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
}
origin = "http://example.com"
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
if err != nil {
t.Fatalf("unexpected error for matching wildcard: %v", err)
}
if allowedOrigin != "*" {
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
}
allowOrigins = []string{"http://example.com", "https://example.org"}
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
if err != nil {
t.Fatalf("unexpected error for matching origin: %v", err)
}
if allowedOrigin != origin {
t.Fatalf("expected empty allowed origin, got %s", allowedOrigin)
}
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
if err != nil {
t.Fatalf("unexpected error for matching origin: %v", err)
}
if allowedOrigin != origin {
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
}
origin = "https://example.org"
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
if err != nil {
t.Fatalf("unexpected error for matching origin: %v", err)
}
if allowedOrigin != origin {
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
}
allowedOrigin, err = isOriginAllowed("http://notallowed.com", allowOrigins)
if !errors.Is(err, errOriginNotAllowed) {
t.Fatalf("expected errOriginNotAllowed for disallowed origin, got %v", err)
}
if allowedOrigin != "" {
t.Fatalf("expected empty allowed origin, got %s", allowedOrigin)
}
allowOrigins = []string{"http://*.example.com"}
origin = "http://test.example.com"
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
if err != nil {
t.Fatalf("unexpected error for wildcard subdomain: %v", err)
}
if allowedOrigin != origin {
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
}
origin = "http://example.com"
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
if err != nil {
t.Fatalf("unexpected error for exact subdomain match: %v", err)
}
if allowedOrigin != origin {
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
}
}
func TestInfo(t *testing.T) {
cnf := tempConf(t, "api: yes\n")

View file

@ -182,7 +182,8 @@ type Conf struct {
APIEncryption bool `json:"apiEncryption"`
APIServerKey string `json:"apiServerKey"`
APIServerCert string `json:"apiServerCert"`
APIAllowOrigin string `json:"apiAllowOrigin"`
APIAllowOrigin *string `json:"apiAllowOrigin,omitempty"` // deprecated
APIAllowOrigins []string `json:"apiAllowOrigins"`
APITrustedProxies IPNetworks `json:"apiTrustedProxies"`
// Metrics
@ -340,7 +341,7 @@ func (conf *Conf) setDefaults() {
conf.APIAddress = ":9997"
conf.APIServerKey = "server.key"
conf.APIServerCert = "server.crt"
conf.APIAllowOrigin = "*"
conf.APIAllowOrigins = []string{"*"}
// Metrics
conf.MetricsAddress = ":9998"
@ -607,6 +608,13 @@ func (conf *Conf) Validate(l logger.Writer) error {
}
}
// Control API
if conf.APIAllowOrigin != nil {
l.Log(logger.Warn, "parameter 'apiAllowOrigin' is deprecated and has been replaced with 'apiAllowOrigins'")
conf.APIAllowOrigins = []string{*conf.APIAllowOrigin}
}
// RTSP
if conf.RTSPDisable != nil {

View file

@ -10,6 +10,7 @@ import (
"path/filepath"
"reflect"
"runtime"
"slices"
"strings"
"syscall"
"time"
@ -650,7 +651,7 @@ func (p *Core) createResources(initial bool) error {
Encryption: p.conf.APIEncryption,
ServerKey: p.conf.APIServerKey,
ServerCert: p.conf.APIServerCert,
AllowOrigin: p.conf.APIAllowOrigin,
AllowOrigins: p.conf.APIAllowOrigins,
TrustedProxies: p.conf.APITrustedProxies,
ReadTimeout: p.conf.ReadTimeout,
WriteTimeout: p.conf.WriteTimeout,
@ -911,7 +912,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) {
newConf.APIEncryption != p.conf.APIEncryption ||
newConf.APIServerKey != p.conf.APIServerKey ||
newConf.APIServerCert != p.conf.APIServerCert ||
newConf.APIAllowOrigin != p.conf.APIAllowOrigin ||
slices.Equal(newConf.APIAllowOrigins, p.conf.APIAllowOrigins) ||
!reflect.DeepEqual(newConf.APITrustedProxies, p.conf.APITrustedProxies) ||
newConf.ReadTimeout != p.conf.ReadTimeout ||
newConf.WriteTimeout != p.conf.WriteTimeout ||

View file

@ -158,8 +158,11 @@ apiEncryption: no
apiServerKey: server.key
# Path to the server certificate.
apiServerCert: server.crt
# Value of the Access-Control-Allow-Origin header provided in every HTTP response.
apiAllowOrigin: '*'
# List of allowed origins.
# Supports wildcards: ['http://*.example.com']
# If apiAllowOrigins is set to '*', the Access-Control-Allow-Origin response will be '*',
# even if no Origin was sent from the client.
apiAllowOrigins: ['*']
# List of IPs or CIDRs of proxies placed before the HTTP server.
# If the server receives a request from one of these entries, IP in logs
# will be taken from the X-Forwarded-For header.