package hls import ( "fmt" "io" "net/http" "os" "path/filepath" "testing" "time" "github.com/bluenviron/gohlslib/v2" "github.com/bluenviron/gohlslib/v2/pkg/codecs" "github.com/bluenviron/gortsplib/v5/pkg/description" "github.com/bluenviron/mediacommon/v2/pkg/codecs/mpeg4audio" "github.com/bluenviron/mediamtx/internal/auth" "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/defs" "github.com/bluenviron/mediamtx/internal/externalcmd" "github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/stream" "github.com/bluenviron/mediamtx/internal/test" "github.com/bluenviron/mediamtx/internal/unit" "github.com/stretchr/testify/require" ) type dummyPathManager struct { setHLSServerImpl func() []defs.Path findPathConfImpl func(req defs.PathFindPathConfReq) (*conf.Path, error) addReaderImpl func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) } func (pm *dummyPathManager) SetHLSServer(*Server) []defs.Path { if pm.setHLSServerImpl != nil { return pm.setHLSServerImpl() } return nil } func (pm *dummyPathManager) FindPathConf(req defs.PathFindPathConfReq) (*conf.Path, error) { return pm.findPathConfImpl(req) } func (pm *dummyPathManager) AddReader(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { return pm.addReaderImpl(req) } type dummyPath struct{} func (pa *dummyPath) Name() string { return "teststream" } func (pa *dummyPath) SafeConf() *conf.Path { return &conf.Path{} } func (pa *dummyPath) ExternalCmdEnv() externalcmd.Environment { return nil } func (pa *dummyPath) RemovePublisher(_ defs.PathRemovePublisherReq) { } func (pa *dummyPath) RemoveReader(_ defs.PathRemoveReaderReq) { } func TestServerPreflightRequest(t *testing.T) { s := &Server{ Address: "127.0.0.1:8888", AllowOrigins: []string{"*"}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), PathManager: &dummyPathManager{}, Parent: test.NilLogger, } err := s.Initialize() require.NoError(t, err) defer s.Close() tr := &http.Transport{} defer tr.CloseIdleConnections() hc := &http.Client{Transport: tr} req, err := http.NewRequest(http.MethodOptions, "http://localhost:8888", nil) require.NoError(t, err) req.Header.Add("Access-Control-Request-Method", "GET") res, err := hc.Do(req) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusNoContent, res.StatusCode) byts, err := io.ReadAll(res.Body) require.NoError(t, err) require.Equal(t, "*", res.Header.Get("Access-Control-Allow-Origin")) require.Equal(t, "true", res.Header.Get("Access-Control-Allow-Credentials")) require.Equal(t, "OPTIONS, GET", res.Header.Get("Access-Control-Allow-Methods")) require.Equal(t, "Authorization, Range", res.Header.Get("Access-Control-Allow-Headers")) require.Equal(t, byts, []byte{}) } func TestServerNotFound(t *testing.T) { for _, ca := range []string{ "always remux off", "always remux on", } { t.Run(ca, func(t *testing.T) { pm := &dummyPathManager{ findPathConfImpl: func(req defs.PathFindPathConfReq) (*conf.Path, error) { require.Equal(t, "nonexisting", req.AccessRequest.Name) return &conf.Path{}, nil }, addReaderImpl: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { require.Equal(t, "nonexisting", req.AccessRequest.Name) return nil, nil, fmt.Errorf("not found") }, } s := &Server{ Address: "127.0.0.1:8888", Encryption: false, ServerKey: "", ServerCert: "", AlwaysRemux: ca == "always remux on", Variant: conf.HLSVariant(gohlslib.MuxerVariantMPEGTS), SegmentCount: 7, SegmentDuration: conf.Duration(1 * time.Second), PartDuration: conf.Duration(200 * time.Millisecond), SegmentMaxSize: 50 * 1024 * 1024, TrustedProxies: conf.IPNetworks{}, Directory: "", ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), PathManager: pm, Parent: test.NilLogger, } err := s.Initialize() require.NoError(t, err) defer s.Close() tr := &http.Transport{} defer tr.CloseIdleConnections() hc := &http.Client{Transport: tr} func() { var req *http.Request req, err = http.NewRequest(http.MethodGet, "http://myuser:mypass@127.0.0.1:8888/nonexisting/", nil) require.NoError(t, err) var res *http.Response res, err = hc.Do(req) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) }() func() { var req *http.Request req, err = http.NewRequest(http.MethodGet, "http://myuser:mypass@127.0.0.1:8888/nonexisting/index.m3u8", nil) require.NoError(t, err) var res *http.Response res, err = hc.Do(req) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusNotFound, res.StatusCode) }() }) } } func TestServerRead(t *testing.T) { for _, ca := range []string{ "always remux off", "always remux on", } { t.Run(ca, func(t *testing.T) { desc := &description.Session{Medias: []*description.Media{ test.MediaH264, test.MediaMPEG4Audio, }} strm := &stream.Stream{ WriteQueueSize: 512, RTPMaxPayloadSize: 1450, Desc: desc, GenerateRTPPackets: true, Parent: test.NilLogger, } err := strm.Initialize() require.NoError(t, err) pm := &dummyPathManager{ findPathConfImpl: func(req defs.PathFindPathConfReq) (*conf.Path, error) { require.Equal(t, "teststream", req.AccessRequest.Name) require.Equal(t, "param=value", req.AccessRequest.Query) require.Equal(t, "myuser", req.AccessRequest.Credentials.User) require.Equal(t, "mypass", req.AccessRequest.Credentials.Pass) return &conf.Path{}, nil }, addReaderImpl: func(req defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { require.Equal(t, "teststream", req.AccessRequest.Name) if ca == "always remux off" { require.Equal(t, "param=value", req.AccessRequest.Query) } else { require.Equal(t, "", req.AccessRequest.Query) } return &dummyPath{}, strm, nil }, } switch ca { case "always remux off": s := &Server{ Address: "127.0.0.1:8888", AlwaysRemux: false, Variant: conf.HLSVariant(gohlslib.MuxerVariantMPEGTS), SegmentCount: 7, SegmentDuration: conf.Duration(1 * time.Second), PartDuration: conf.Duration(200 * time.Millisecond), SegmentMaxSize: 50 * 1024 * 1024, TrustedProxies: conf.IPNetworks{}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), PathManager: pm, Parent: test.NilLogger, } err = s.Initialize() require.NoError(t, err) defer s.Close() c := &gohlslib.Client{ URI: "http://myuser:mypass@127.0.0.1:8888/teststream/index.m3u8?param=value", } recv1 := make(chan struct{}) recv2 := make(chan struct{}) c.OnTracks = func(tracks []*gohlslib.Track) error { //nolint:dupl require.Equal(t, []*gohlslib.Track{ { Codec: &codecs.H264{}, ClockRate: 90000, }, { Codec: &codecs.MPEG4Audio{ Config: mpeg4audio.AudioSpecificConfig{ Type: 2, ChannelCount: 2, SampleRate: 44100, }, }, ClockRate: 90000, }, }, tracks) c.OnDataH26x(tracks[0], func(pts, dts int64, au [][]byte) { require.Equal(t, int64(0), pts) require.Equal(t, int64(0), dts) require.Equal(t, [][]byte{ test.FormatH264.SPS, test.FormatH264.PPS, {5, 1}, }, au) close(recv1) }) c.OnDataMPEG4Audio(tracks[1], func(pts int64, aus [][]byte) { require.Equal(t, int64(0), pts) require.Equal(t, [][]byte{{1, 2}}, aus) close(recv2) }) return nil } err = c.Start() require.NoError(t, err) defer c.Close() time.Sleep(100 * time.Millisecond) for i := range 4 { strm.WriteUnit(test.MediaH264, test.FormatH264, &unit.Unit{ NTP: time.Time{}, PTS: int64(i) * 90000, Payload: unit.PayloadH264{ {5, 1}, // IDR }, }) strm.WriteUnit(test.MediaMPEG4Audio, test.FormatMPEG4Audio, &unit.Unit{ NTP: time.Time{}, PTS: int64(i) * 44100, Payload: unit.PayloadMPEG4Audio{{1, 2}}, }) } <-recv1 <-recv2 case "always remux on": s := &Server{ Address: "127.0.0.1:8888", AlwaysRemux: true, Variant: conf.HLSVariant(gohlslib.MuxerVariantMPEGTS), SegmentCount: 7, SegmentDuration: conf.Duration(1 * time.Second), PartDuration: conf.Duration(200 * time.Millisecond), SegmentMaxSize: 50 * 1024 * 1024, TrustedProxies: conf.IPNetworks{}, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), PathManager: pm, Parent: test.NilLogger, } err = s.Initialize() require.NoError(t, err) defer s.Close() s.PathReady(&dummyPath{}) time.Sleep(500 * time.Millisecond) for i := range 4 { strm.WriteUnit(test.MediaH264, test.FormatH264, &unit.Unit{ NTP: time.Time{}, PTS: int64(i) * 90000, Payload: unit.PayloadH264{ {5, 1}, // IDR }, }) strm.WriteUnit(test.MediaMPEG4Audio, test.FormatMPEG4Audio, &unit.Unit{ NTP: time.Time{}, PTS: int64(i) * 44100, Payload: unit.PayloadMPEG4Audio{{1, 2}}, }) } time.Sleep(100 * time.Millisecond) c := &gohlslib.Client{ URI: "http://myuser:mypass@127.0.0.1:8888/teststream/index.m3u8?param=value", } recv1 := make(chan struct{}) recv2 := make(chan struct{}) c.OnTracks = func(tracks []*gohlslib.Track) error { //nolint:dupl require.Equal(t, []*gohlslib.Track{ { Codec: &codecs.H264{}, ClockRate: 90000, }, { Codec: &codecs.MPEG4Audio{ Config: mpeg4audio.AudioSpecificConfig{ Type: 2, ChannelCount: 2, SampleRate: 44100, }, }, ClockRate: 90000, }, }, tracks) c.OnDataH26x(tracks[0], func(pts, dts int64, au [][]byte) { require.Equal(t, int64(0), pts) require.Equal(t, int64(0), dts) require.Equal(t, [][]byte{ test.FormatH264.SPS, test.FormatH264.PPS, {5, 1}, }, au) close(recv1) }) c.OnDataMPEG4Audio(tracks[1], func(pts int64, aus [][]byte) { require.Equal(t, int64(0), pts) require.Equal(t, [][]byte{{1, 2}}, aus) close(recv2) }) return nil } err = c.Start() require.NoError(t, err) defer c.Close() <-recv1 <-recv2 } }) } } func TestServerDirectory(t *testing.T) { dir, err := os.MkdirTemp("", "mediamtx-playback") require.NoError(t, err) defer os.RemoveAll(dir) desc := &description.Session{Medias: []*description.Media{test.MediaH264}} strm := &stream.Stream{ WriteQueueSize: 512, RTPMaxPayloadSize: 1450, Desc: desc, GenerateRTPPackets: true, Parent: test.NilLogger, } err = strm.Initialize() require.NoError(t, err) pm := &dummyPathManager{ addReaderImpl: func(_ defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { return &dummyPath{}, strm, nil }, } s := &Server{ Address: "127.0.0.1:8888", Encryption: false, ServerKey: "", ServerCert: "", AlwaysRemux: true, Variant: conf.HLSVariant(gohlslib.MuxerVariantMPEGTS), SegmentCount: 7, SegmentDuration: conf.Duration(1 * time.Second), PartDuration: conf.Duration(200 * time.Millisecond), SegmentMaxSize: 50 * 1024 * 1024, TrustedProxies: conf.IPNetworks{}, Directory: filepath.Join(dir, "mydir"), ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), PathManager: pm, Parent: test.NilLogger, } err = s.Initialize() require.NoError(t, err) defer s.Close() s.PathReady(&dummyPath{}) time.Sleep(100 * time.Millisecond) _, err = os.Stat(filepath.Join(dir, "mydir", "teststream")) require.NoError(t, err) } func TestServerDynamicAlwaysRemux(t *testing.T) { desc := &description.Session{Medias: []*description.Media{test.MediaH264}} strm := &stream.Stream{ WriteQueueSize: 512, RTPMaxPayloadSize: 1450, Desc: desc, GenerateRTPPackets: true, Parent: test.NilLogger, } err := strm.Initialize() require.NoError(t, err) done := make(chan struct{}) pm := &dummyPathManager{ setHLSServerImpl: func() []defs.Path { return []defs.Path{&dummyPath{}} }, addReaderImpl: func(_ defs.PathAddReaderReq) (defs.Path, *stream.Stream, error) { close(done) return &dummyPath{}, strm, nil }, } s := &Server{ Address: "127.0.0.1:8888", Encryption: false, ServerKey: "", ServerCert: "", AlwaysRemux: true, Variant: conf.HLSVariant(gohlslib.MuxerVariantMPEGTS), SegmentCount: 7, SegmentDuration: conf.Duration(1 * time.Second), PartDuration: conf.Duration(200 * time.Millisecond), SegmentMaxSize: 50 * 1024 * 1024, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), PathManager: pm, Parent: test.NilLogger, } err = s.Initialize() require.NoError(t, err) defer s.Close() <-done } func TestAuthError(t *testing.T) { n := 0 s := &Server{ Address: "127.0.0.1:8888", Encryption: false, ServerKey: "", ServerCert: "", AlwaysRemux: true, Variant: conf.HLSVariant(gohlslib.MuxerVariantMPEGTS), SegmentCount: 7, SegmentDuration: conf.Duration(1 * time.Second), PartDuration: conf.Duration(200 * time.Millisecond), SegmentMaxSize: 50 * 1024 * 1024, ReadTimeout: conf.Duration(10 * time.Second), WriteTimeout: conf.Duration(10 * time.Second), PathManager: &dummyPathManager{ findPathConfImpl: func(req defs.PathFindPathConfReq) (*conf.Path, error) { if req.AccessRequest.Credentials.User == "" && req.AccessRequest.Credentials.Pass == "" { return nil, &auth.Error{AskCredentials: true} } return nil, &auth.Error{Wrapped: fmt.Errorf("auth error")} }, }, Parent: test.Logger(func(l logger.Level, s string, i ...any) { if l == logger.Info { if n == 1 { require.Regexp(t, "failed to authenticate: auth error$", fmt.Sprintf(s, i...)) } n++ } }), } err := s.Initialize() require.NoError(t, err) defer s.Close() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8888/stream/index.m3u8", nil) require.NoError(t, err) res, err := http.DefaultClient.Do(req) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusUnauthorized, res.StatusCode) require.Equal(t, `Basic realm="mediamtx"`, res.Header.Get("WWW-Authenticate")) req, err = http.NewRequest(http.MethodGet, "http://myuser:mypass@127.0.0.1:8888/stream/index.m3u8", nil) require.NoError(t, err) start := time.Now() res, err = http.DefaultClient.Do(req) require.NoError(t, err) defer res.Body.Close() require.Greater(t, time.Since(start), 2*time.Second) require.Equal(t, http.StatusUnauthorized, res.StatusCode) require.Equal(t, 2, n) }