mediamtx/internal/protocols/httpp/handler_origin.go
KHuynh ade0cddeb3
support multiple CORS origins (#5150)
Co-authored-by: aler9 <46489434+aler9@users.noreply.github.com>
2025-11-21 02:00:46 +01:00

88 lines
1.9 KiB
Go

package httpp
import (
"net"
"net/http"
"net/url"
"regexp"
"strings"
)
// isOriginAllowed reports whether the request origin is permitted by
// allowOrigins and returns the value to use for Access-Control-Allow-Origin.
//
// Matching rules:
//   - an empty allowOrigins list allows nothing;
//   - an entry equal to "*" allows every origin (even an empty one) and the
//     returned value is "*";
//   - otherwise origin must parse as a URL with a scheme. It matches an entry
//     when scheme, host and port are equal. A missing port defaults to 80 for
//     http and 443 for https, and hosts are compared case-insensitively;
//   - an entry whose host contains "*" is a wildcard. "*." matches zero or more
//     leading labels, and any other "*" matches any sequence of characters.
//     The scheme must still be equal and the port must still match.
//
// On a match the original origin string is returned unchanged.
func isOriginAllowed(origin string, allowOrigins []string) (string, bool) {
	if len(allowOrigins) == 0 {
		return "", false
	}

	for _, o := range allowOrigins {
		if o == "*" {
			return o, true
		}
	}

	if origin == "" {
		return "", false
	}

	originURL, err := url.Parse(origin)
	if err != nil || originURL.Scheme == "" {
		return "", false
	}
	originHost := originHostWithPort(originURL)

	for _, o := range allowOrigins {
		allowedURL, errAllowed := url.Parse(o)
		// Both exact and wildcard entries require the same scheme.
		if errAllowed != nil || allowedURL.Scheme != originURL.Scheme {
			continue
		}
		allowedHost := originHostWithPort(allowedURL)

		if allowedHost == originHost || matchOriginWildcard(allowedHost, originHost) {
			return origin, true
		}
	}

	return "", false
}

// originHostWithPort returns the lowercased host of u joined with its port.
// It fills in the default port for http (80) and https (443). For other
// schemes without an explicit port, the lowercased u.Host is returned as is.
func originHostWithPort(u *url.URL) string {
	port := u.Port()
	if port == "" {
		switch u.Scheme {
		case "http":
			port = "80"
		case "https":
			port = "443"
		default:
			return strings.ToLower(u.Host)
		}
	}
	// Hostname strips IPv6 brackets, so JoinHostPort adds exactly one pair back.
	return net.JoinHostPort(strings.ToLower(u.Hostname()), port)
}

// matchOriginWildcard reports whether host matches pattern, which may contain
// "*". "*." matches zero or more leading labels (each ending in a dot), and
// any other "*" matches any sequence of characters. Everything else is matched
// literally. A pattern without "*" never matches here.
func matchOriginWildcard(pattern, host string) bool {
	if !strings.Contains(pattern, "*") {
		return false
	}
	// Escape regexp metacharacters (notably ".") so they match literally,
	// then turn the escaped wildcards back into their regexp equivalents.
	expr := regexp.QuoteMeta(pattern)
	expr = strings.ReplaceAll(expr, `\*\.`, `(.*\.)?`)
	expr = strings.ReplaceAll(expr, `\*`, `.*`)
	matched, err := regexp.MatchString("^"+expr+"$", host)
	return err == nil && matched
}
// handlerOrigin wraps an http.Handler. When the request's Origin header is
// permitted by allowOrigins, it sets the Access-Control-Allow-Origin and
// Access-Control-Allow-Credentials response headers before delegating to h.
type handlerOrigin struct {
	h            http.Handler
	allowOrigins []string
}

// ServeHTTP implements http.Handler.
func (h *handlerOrigin) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	origin, ok := isOriginAllowed(r.Header.Get("Origin"), h.allowOrigins)
	if ok {
		w.Header().Set("Access-Control-Allow-Origin", origin)
		w.Header().Set("Access-Control-Allow-Credentials", "true")
	}

	// Unless the allow list is empty or grants "*", the CORS headers above
	// depend on the request's Origin. Tell shared caches so that they do not
	// replay one origin's response to another.
	if len(h.allowOrigins) != 0 && origin != "*" {
		w.Header().Add("Vary", "Origin")
	}

	h.h.ServeHTTP(w, r)
}