support publishing, reading, proxying with SRT (#2068)

This commit is contained in:
Alessandro Ros 2023-07-31 21:20:09 +02:00 committed by GitHub
parent d696a782f7
commit b4e3033ea3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 2184 additions and 213 deletions

View file

@ -159,6 +159,10 @@ type Conf struct {
WebRTCICEUDPMuxAddress string `json:"webrtcICEUDPMuxAddress"`
WebRTCICETCPMuxAddress string `json:"webrtcICETCPMuxAddress"`
// SRT
SRT bool `json:"srt"`
SRTAddress string `json:"srtAddress"`
// paths
Paths map[string]*PathConf `json:"paths"`
}
@ -336,6 +340,10 @@ func (conf *Conf) UnmarshalJSON(b []byte) error {
conf.WebRTCAllowOrigin = "*"
conf.WebRTCICEServers2 = []WebRTCICEServer{{URL: "stun:stun.l.google.com:19302"}}
// SRT
conf.SRT = true
conf.SRTAddress = ":8890"
type alias Conf
d := json.NewDecoder(bytes.NewReader(b))
d.DisallowUnknownFields()

View file

@ -210,6 +210,16 @@ func (pconf *PathConf) check(conf *Conf, name string) error {
return fmt.Errorf("'%s' is not a valid IP", host)
}
case strings.HasPrefix(pconf.Source, "srt://"):
if pconf.Regexp != nil {
return fmt.Errorf("a path with a regular expression (or path 'all') cannot have a SRT source. use another path")
}
_, err := gourl.Parse(pconf.Source)
if err != nil {
return fmt.Errorf("'%s' is not a valid HLS URL", pconf.Source)
}
case pconf.Source == "redirect":
if pconf.SourceRedirect == "" {
return fmt.Errorf("source redirect must be filled")
@ -337,6 +347,7 @@ func (pconf PathConf) HasStaticSource() bool {
strings.HasPrefix(pconf.Source, "http://") ||
strings.HasPrefix(pconf.Source, "https://") ||
strings.HasPrefix(pconf.Source, "udp://") ||
strings.HasPrefix(pconf.Source, "srt://") ||
pconf.Source == "rpiCamera"
}

View file

@ -180,6 +180,12 @@ type apiWebRTCManager interface {
apiSessionsKick(uuid.UUID) error
}
type apiSRTServer interface {
apiConnsList() (*apiSRTConnsList, error)
apiConnsGet(uuid.UUID) (*apiSRTConn, error)
apiConnsKick(uuid.UUID) error
}
type apiParent interface {
logger.Writer
apiConfigSet(conf *conf.Conf)
@ -194,6 +200,7 @@ type api struct {
rtmpsServer apiRTMPServer
hlsManager apiHLSManager
webRTCManager apiWebRTCManager
srtServer apiSRTServer
parent apiParent
httpServer *httpserv.WrappedServer
@ -211,6 +218,7 @@ func newAPI(
rtmpsServer apiRTMPServer,
hlsManager apiHLSManager,
webRTCManager apiWebRTCManager,
srtServer apiSRTServer,
parent apiParent,
) (*api, error) {
a := &api{
@ -222,6 +230,7 @@ func newAPI(
rtmpsServer: rtmpsServer,
hlsManager: hlsManager,
webRTCManager: webRTCManager,
srtServer: srtServer,
parent: parent,
}
@ -280,6 +289,12 @@ func newAPI(
group.POST("/v2/webrtcsessions/kick/:id", a.onWebRTCSessionsKick)
}
if !interfaceIsEmpty(a.srtServer) {
group.GET("/v2/srtconns/list", a.onSRTConnsList)
group.GET("/v2/srtconns/get/:id", a.onSRTConnsGet)
group.POST("/v2/srtconns/kick/:id", a.onSRTConnsKick)
}
network, address := restrictNetwork("tcp", address)
var err error
@ -853,6 +868,56 @@ func (a *api) onWebRTCSessionsKick(ctx *gin.Context) {
ctx.Status(http.StatusOK)
}
func (a *api) onSRTConnsList(ctx *gin.Context) {
data, err := a.srtServer.apiConnsList()
if err != nil {
ctx.AbortWithStatus(http.StatusInternalServerError)
return
}
data.ItemCount = len(data.Items)
pageCount, err := paginate(&data.Items, ctx.Query("itemsPerPage"), ctx.Query("page"))
if err != nil {
ctx.AbortWithStatus(http.StatusBadRequest)
return
}
data.PageCount = pageCount
ctx.JSON(http.StatusOK, data)
}
func (a *api) onSRTConnsGet(ctx *gin.Context) {
uuid, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.AbortWithStatus(http.StatusBadRequest)
return
}
data, err := a.srtServer.apiConnsGet(uuid)
if err != nil {
abortWithError(ctx, err)
return
}
ctx.JSON(http.StatusOK, data)
}
func (a *api) onSRTConnsKick(ctx *gin.Context) {
uuid, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.AbortWithStatus(http.StatusBadRequest)
return
}
err = a.srtServer.apiConnsKick(uuid)
if err != nil {
abortWithError(ctx, err)
return
}
ctx.Status(http.StatusOK)
}
// confReload is called by core.
func (a *api) confReload(conf *conf.Conf) {
a.mutex.Lock()

View file

@ -104,3 +104,19 @@ type apiWebRTCSessionsList struct {
PageCount int `json:"pageCount"`
Items []*apiWebRTCSession `json:"items"`
}
type apiSRTConn struct {
ID uuid.UUID `json:"id"`
Created time.Time `json:"created"`
RemoteAddr string `json:"remoteAddr"`
State string `json:"state"`
Path string `json:"path"`
BytesReceived uint64 `json:"bytesReceived"`
BytesSent uint64 `json:"bytesSent"`
}
type apiSRTConnsList struct {
ItemCount int `json:"itemCount"`
PageCount int `json:"pageCount"`
Items []*apiSRTConn `json:"items"`
}

View file

@ -1,6 +1,7 @@
package core
import (
"bufio"
"bytes"
"crypto/tls"
"encoding/json"
@ -17,6 +18,8 @@ import (
"github.com/bluenviron/gortsplib/v3/pkg/formats"
"github.com/bluenviron/gortsplib/v3/pkg/media"
"github.com/bluenviron/mediacommon/pkg/codecs/mpeg4audio"
"github.com/bluenviron/mediacommon/pkg/formats/mpegts"
"github.com/datarhei/gosrt"
"github.com/google/uuid"
"github.com/pion/rtp"
"github.com/stretchr/testify/require"
@ -509,6 +512,7 @@ func TestAPIProtocolList(t *testing.T) {
"rtmps",
"hls",
"webrtc",
"srt",
} {
t.Run(ca, func(t *testing.T) {
conf := "api: yes\n"
@ -663,10 +667,33 @@ func TestAPIProtocolList(t *testing.T) {
})
<-c.incomingTrack
case "srt":
conf := srt.DefaultConfig()
conf.StreamId = "publish:mypath"
conn, err := srt.Dial("srt", "localhost:8890", conf)
require.NoError(t, err)
defer conn.Close()
track := &mpegts.Track{
PID: 256,
Codec: &mpegts.CodecH264{},
}
bw := bufio.NewWriter(conn)
w := mpegts.NewWriter(bw, []*mpegts.Track{track})
require.NoError(t, err)
err = w.WriteH26x(track, 0, 0, true, [][]byte{{1}})
require.NoError(t, err)
bw.Flush()
time.Sleep(500 * time.Millisecond)
}
switch ca {
case "rtsp conns", "rtsp sessions", "rtsps conns", "rtsps sessions", "rtmp", "rtmps":
case "rtsp conns", "rtsp sessions", "rtsps conns", "rtsps sessions", "rtmp", "rtmps", "srt":
var pa string
switch ca {
case "rtsp conns":
@ -686,6 +713,9 @@ func TestAPIProtocolList(t *testing.T) {
case "rtmps":
pa = "rtmpsconns"
case "srt":
pa = "srtconns"
}
type item struct {
@ -763,6 +793,7 @@ func TestAPIProtocolGet(t *testing.T) {
"rtmps",
"hls",
"webrtc",
"srt",
} {
t.Run(ca, func(t *testing.T) {
conf := "api: yes\n"
@ -917,10 +948,33 @@ func TestAPIProtocolGet(t *testing.T) {
})
<-c.incomingTrack
case "srt":
conf := srt.DefaultConfig()
conf.StreamId = "publish:mypath"
conn, err := srt.Dial("srt", "localhost:8890", conf)
require.NoError(t, err)
defer conn.Close()
track := &mpegts.Track{
PID: 256,
Codec: &mpegts.CodecH264{},
}
bw := bufio.NewWriter(conn)
w := mpegts.NewWriter(bw, []*mpegts.Track{track})
require.NoError(t, err)
err = w.WriteH26x(track, 0, 0, true, [][]byte{{1}})
require.NoError(t, err)
bw.Flush()
time.Sleep(500 * time.Millisecond)
}
switch ca {
case "rtsp conns", "rtsp sessions", "rtsps conns", "rtsps sessions", "rtmp", "rtmps":
case "rtsp conns", "rtsp sessions", "rtsps conns", "rtsps sessions", "rtmp", "rtmps", "srt":
var pa string
switch ca {
case "rtsp conns":
@ -940,6 +994,9 @@ func TestAPIProtocolGet(t *testing.T) {
case "rtmps":
pa = "rtmpsconns"
case "srt":
pa = "srtconns"
}
type item struct {
@ -1020,6 +1077,7 @@ func TestAPIProtocolGetNotFound(t *testing.T) {
"rtmps",
"hls",
"webrtc",
"srt",
} {
t.Run(ca, func(t *testing.T) {
conf := "api: yes\n"
@ -1071,6 +1129,9 @@ func TestAPIProtocolGetNotFound(t *testing.T) {
case "webrtc":
pa = "webrtcsessions"
case "srt":
pa = "srtconns"
}
func() {
@ -1100,6 +1161,7 @@ func TestAPIProtocolKick(t *testing.T) {
"rtsps",
"rtmp",
"webrtc",
"srt",
} {
t.Run(ca, func(t *testing.T) {
conf := "api: yes\n"
@ -1158,6 +1220,29 @@ func TestAPIProtocolKick(t *testing.T) {
case "webrtc":
c := newWebRTCTestClient(t, hc, "http://localhost:8889/mypath/whip", true)
defer c.close()
case "srt":
conf := srt.DefaultConfig()
conf.StreamId = "publish:mypath"
conn, err := srt.Dial("srt", "localhost:8890", conf)
require.NoError(t, err)
defer conn.Close()
track := &mpegts.Track{
PID: 256,
Codec: &mpegts.CodecH264{},
}
bw := bufio.NewWriter(conn)
w := mpegts.NewWriter(bw, []*mpegts.Track{track})
require.NoError(t, err)
err = w.WriteH26x(track, 0, 0, true, [][]byte{{1}})
require.NoError(t, err)
bw.Flush()
// time.Sleep(500 * time.Millisecond)
}
var pa string
@ -1173,6 +1258,9 @@ func TestAPIProtocolKick(t *testing.T) {
case "webrtc":
pa = "webrtcsessions"
case "srt":
pa = "srtconns"
}
var out1 struct {
@ -1209,6 +1297,7 @@ func TestAPIProtocolKickNotFound(t *testing.T) {
"rtsps",
"rtmp",
"webrtc",
"srt",
} {
t.Run(ca, func(t *testing.T) {
conf := "api: yes\n"
@ -1242,6 +1331,9 @@ func TestAPIProtocolKickNotFound(t *testing.T) {
case "webrtc":
pa = "webrtcsessions"
case "srt":
pa = "srtconns"
}
func() {

View file

@ -50,6 +50,7 @@ const (
authProtocolRTMP authProtocol = "rtmp"
authProtocolHLS authProtocol = "hls"
authProtocolWebRTC authProtocol = "webrtc"
authProtocolSRT authProtocol = "srt"
)
type authCredentials struct {

View file

@ -44,6 +44,7 @@ type Core struct {
rtmpsServer *rtmpServer
hlsManager *hlsManager
webRTCManager *webRTCManager
srtServer *srtServer
api *api
confWatcher *confwatcher.ConfWatcher
@ -432,6 +433,23 @@ func (p *Core) createResources(initial bool) error {
}
}
if p.conf.SRT {
if p.srtServer == nil {
p.srtServer, err = newSRTServer(
p.conf.SRTAddress,
p.conf.ReadTimeout,
p.conf.WriteTimeout,
p.conf.ReadBufferCount,
p.conf.UDPMaxPayloadSize,
p.pathManager,
p,
)
if err != nil {
return err
}
}
}
if p.conf.API {
if p.api == nil {
p.api, err = newAPI(
@ -445,6 +463,7 @@ func (p *Core) createResources(initial bool) error {
p.rtmpsServer,
p.hlsManager,
p.webRTCManager,
p.srtServer,
p,
)
if err != nil {
@ -595,6 +614,15 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) {
newConf.WebRTCICEUDPMuxAddress != p.conf.WebRTCICEUDPMuxAddress ||
newConf.WebRTCICETCPMuxAddress != p.conf.WebRTCICETCPMuxAddress
closeSRTServer := newConf == nil ||
newConf.SRT != p.conf.SRT ||
newConf.SRTAddress != p.conf.SRTAddress ||
newConf.ReadTimeout != p.conf.ReadTimeout ||
newConf.WriteTimeout != p.conf.WriteTimeout ||
newConf.ReadBufferCount != p.conf.ReadBufferCount ||
newConf.UDPMaxPayloadSize != p.conf.UDPMaxPayloadSize ||
closePathManager
closeAPI := newConf == nil ||
newConf.API != p.conf.API ||
newConf.APIAddress != p.conf.APIAddress ||
@ -604,7 +632,8 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) {
closeRTSPSServer ||
closeRTMPServer ||
closeHLSManager ||
closeWebRTCManager
closeWebRTCManager ||
closeSRTServer
if newConf == nil && p.confWatcher != nil {
p.confWatcher.Close()
@ -620,6 +649,11 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) {
}
}
if closeSRTServer && p.srtServer != nil {
p.srtServer.close()
p.srtServer = nil
}
if closeWebRTCManager && p.webRTCManager != nil {
p.webRTCManager.close()
p.webRTCManager = nil

View file

@ -8,18 +8,33 @@ import (
"net/http"
"testing"
"github.com/asticode/go-astits"
"github.com/bluenviron/gortsplib/v3"
"github.com/bluenviron/gortsplib/v3/pkg/formats"
"github.com/bluenviron/gortsplib/v3/pkg/media"
"github.com/bluenviron/gortsplib/v3/pkg/url"
"github.com/bluenviron/mediacommon/pkg/codecs/h264"
"github.com/bluenviron/mediacommon/pkg/codecs/mpeg4audio"
"github.com/bluenviron/mediacommon/pkg/formats/mpegts"
"github.com/gin-gonic/gin"
"github.com/pion/rtp"
"github.com/stretchr/testify/require"
)
var track1 = &mpegts.Track{
PID: 256,
Codec: &mpegts.CodecH264{},
}
var track2 = &mpegts.Track{
PID: 257,
Codec: &mpegts.CodecMPEG4Audio{
Config: mpeg4audio.Config{
Type: 2,
SampleRate: 44100,
ChannelCount: 2,
},
},
}
type testHLSManager struct {
s *http.Server
@ -71,131 +86,29 @@ segment2.ts
func (ts *testHLSManager) onSegment1(ctx *gin.Context) {
ctx.Writer.Header().Set("Content-Type", `video/MP2T`)
mux := astits.NewMuxer(context.Background(), ctx.Writer)
mux.AddElementaryStream(astits.PMTElementaryStream{
ElementaryPID: 256,
StreamType: astits.StreamTypeH264Video,
})
w := mpegts.NewWriter(ctx.Writer, []*mpegts.Track{track1, track2})
mux.AddElementaryStream(astits.PMTElementaryStream{
ElementaryPID: 257,
StreamType: astits.StreamTypeAACAudio,
})
mux.SetPCRPID(256)
mux.WriteTables()
pkts := mpeg4audio.ADTSPackets{
{
Type: 2,
SampleRate: 44100,
ChannelCount: 2,
AU: []byte{0x01, 0x02, 0x03, 0x04},
},
}
enc, _ := pkts.Marshal()
mux.WriteData(&astits.MuxerData{
PID: 257,
PES: &astits.PESData{
Header: &astits.PESHeader{
OptionalHeader: &astits.PESOptionalHeader{
MarkerBits: 2,
PTSDTSIndicator: astits.PTSDTSIndicatorOnlyPTS,
PTS: &astits.ClockReference{Base: int64(1 * 90000)},
},
StreamID: 192,
},
Data: enc,
},
})
w.WriteMPEG4Audio(track2, 1*90000, [][]byte{{1, 2, 3, 4}})
}
func (ts *testHLSManager) onSegment2(ctx *gin.Context) {
<-ts.clientConnected
ctx.Writer.Header().Set("Content-Type", `video/MP2T`)
mux := astits.NewMuxer(context.Background(), ctx.Writer)
mux.AddElementaryStream(astits.PMTElementaryStream{
ElementaryPID: 256,
StreamType: astits.StreamTypeH264Video,
})
w := mpegts.NewWriter(ctx.Writer, []*mpegts.Track{track1, track2})
mux.AddElementaryStream(astits.PMTElementaryStream{
ElementaryPID: 257,
StreamType: astits.StreamTypeAACAudio,
})
mux.SetPCRPID(256)
mux.WriteTables()
enc, _ := h264.AnnexBMarshal([][]byte{
w.WriteH26x(track1, 2*90000, 2*90000, true, [][]byte{
{7, 1, 2, 3}, // SPS
{8}, // PPS
})
mux.WriteData(&astits.MuxerData{
PID: 256,
PES: &astits.PESData{
Header: &astits.PESHeader{
OptionalHeader: &astits.PESOptionalHeader{
MarkerBits: 2,
PTSDTSIndicator: astits.PTSDTSIndicatorOnlyPTS,
PTS: &astits.ClockReference{Base: int64(2 * 90000)},
},
StreamID: 224, // = video
},
Data: enc,
},
})
w.WriteMPEG4Audio(track2, 2*90000, [][]byte{{1, 2, 3, 4}})
pkts := mpeg4audio.ADTSPackets{
{
Type: 2,
SampleRate: 44100,
ChannelCount: 2,
AU: []byte{0x01, 0x02, 0x03, 0x04},
},
}
enc, _ = pkts.Marshal()
mux.WriteData(&astits.MuxerData{
PID: 257,
PES: &astits.PESData{
Header: &astits.PESHeader{
OptionalHeader: &astits.PESOptionalHeader{
MarkerBits: 2,
PTSDTSIndicator: astits.PTSDTSIndicatorOnlyPTS,
PTS: &astits.ClockReference{Base: int64(1 * 90000)},
},
StreamID: 192,
},
Data: enc,
},
})
enc, _ = h264.AnnexBMarshal([][]byte{
w.WriteH26x(track1, 2*90000, 2*90000, true, [][]byte{
{5}, // IDR
})
mux.WriteData(&astits.MuxerData{
PID: 256,
PES: &astits.PESData{
Header: &astits.PESHeader{
OptionalHeader: &astits.PESOptionalHeader{
MarkerBits: 2,
PTSDTSIndicator: astits.PTSDTSIndicatorOnlyPTS,
PTS: &astits.ClockReference{Base: int64(2 * 90000)},
},
StreamID: 224, // = video
},
Data: enc,
},
})
}
func TestHLSSource(t *testing.T) {

View file

@ -41,8 +41,7 @@ func pathNameAndQuery(inURL *url.URL) (string, url.Values, string) {
type rtmpConnState int
const (
rtmpConnStateIdle rtmpConnState = iota //nolint:deadcode,varcheck
rtmpConnStateRead
rtmpConnStateRead rtmpConnState = iota + 1
rtmpConnStatePublish
)
@ -756,8 +755,10 @@ func (c *rtmpConn) apiItem() *apiRTMPConn {
case rtmpConnStatePublish:
return "publish"
default:
return "idle"
}
return "idle"
}(),
Path: c.pathName,
BytesReceived: bytesReceived,

View file

@ -86,6 +86,11 @@ func newSourceStatic(
readTimeout,
s)
case strings.HasPrefix(cnf.Source, "srt://"):
s.impl = newSRTSource(
readTimeout,
s)
case cnf.Source == "rpiCamera":
s.impl = newRPICameraSource(
s)

810
internal/core/srt_conn.go Normal file
View file

@ -0,0 +1,810 @@
package core
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/bluenviron/gortsplib/v3/pkg/formats"
"github.com/bluenviron/gortsplib/v3/pkg/media"
"github.com/bluenviron/gortsplib/v3/pkg/ringbuffer"
"github.com/bluenviron/mediacommon/pkg/codecs/h264"
"github.com/bluenviron/mediacommon/pkg/codecs/h265"
"github.com/bluenviron/mediacommon/pkg/formats/mpegts"
"github.com/datarhei/gosrt"
"github.com/google/uuid"
"github.com/bluenviron/mediamtx/internal/conf"
"github.com/bluenviron/mediamtx/internal/formatprocessor"
"github.com/bluenviron/mediamtx/internal/logger"
"github.com/bluenviron/mediamtx/internal/stream"
)
func durationGoToMPEGTS(v time.Duration) int64 {
return int64(v.Seconds() * 90000)
}
func h265RandomAccessPresent(au [][]byte) bool {
for _, nalu := range au {
typ := h265.NALUType((nalu[0] >> 1) & 0b111111)
switch typ {
case h265.NALUType_IDR_W_RADL, h265.NALUType_IDR_N_LP, h265.NALUType_CRA_NUT:
return true
}
}
return false
}
type srtConnState int
const (
srtConnStateRead srtConnState = iota + 1
srtConnStatePublish
)
type srtConnPathManager interface {
addReader(req pathAddReaderReq) pathAddReaderRes
addPublisher(req pathAddPublisherReq) pathAddPublisherRes
}
type srtConnParent interface {
logger.Writer
closeConn(*srtConn)
}
type srtConn struct {
readTimeout conf.StringDuration
writeTimeout conf.StringDuration
readBufferCount int
udpMaxPayloadSize int
connReq srt.ConnRequest
wg *sync.WaitGroup
pathManager srtConnPathManager
parent srtConnParent
ctx context.Context
ctxCancel func()
created time.Time
uuid uuid.UUID
mutex sync.RWMutex
state srtConnState
pathName string
conn srt.Conn
chNew chan srtNewConnReq
chSetConn chan srt.Conn
}
func newSRTConn(
parentCtx context.Context,
readTimeout conf.StringDuration,
writeTimeout conf.StringDuration,
readBufferCount int,
udpMaxPayloadSize int,
connReq srt.ConnRequest,
wg *sync.WaitGroup,
pathManager srtConnPathManager,
parent srtConnParent,
) *srtConn {
ctx, ctxCancel := context.WithCancel(parentCtx)
c := &srtConn{
readTimeout: readTimeout,
writeTimeout: writeTimeout,
readBufferCount: readBufferCount,
udpMaxPayloadSize: udpMaxPayloadSize,
connReq: connReq,
wg: wg,
pathManager: pathManager,
parent: parent,
ctx: ctx,
ctxCancel: ctxCancel,
created: time.Now(),
uuid: uuid.New(),
chNew: make(chan srtNewConnReq),
chSetConn: make(chan srt.Conn),
}
c.Log(logger.Info, "opened")
c.wg.Add(1)
go c.run()
return c
}
func (c *srtConn) close() {
c.ctxCancel()
}
func (c *srtConn) Log(level logger.Level, format string, args ...interface{}) {
c.parent.Log(level, "[conn %v] "+format, append([]interface{}{c.connReq.RemoteAddr()}, args...)...)
}
func (c *srtConn) ip() net.IP {
return c.connReq.RemoteAddr().(*net.UDPAddr).IP
}
func (c *srtConn) run() {
defer c.wg.Done()
err := c.runInner()
c.ctxCancel()
c.parent.closeConn(c)
c.Log(logger.Info, "closed (%v)", err)
}
func (c *srtConn) runInner() error {
var req srtNewConnReq
select {
case req = <-c.chNew:
case <-c.ctx.Done():
return errors.New("terminated")
}
answerSent, err := c.runInner2(req)
if !answerSent {
req.res <- nil
}
return err
}
func (c *srtConn) runInner2(req srtNewConnReq) (bool, error) {
parts := strings.Split(req.connReq.StreamId(), ":")
if (len(parts) != 2 && len(parts) != 4) || (parts[0] != "read" && parts[0] != "publish") {
return false, fmt.Errorf("invalid streamid '%s':"+
" it must be 'action:pathname' or 'action:pathname:user:pass', "+
"where action is either read or publish, pathname is the path name, user and pass are the credentials",
req.connReq.StreamId())
}
pathName := parts[1]
user := ""
pass := ""
if len(parts) == 4 {
user, pass = parts[2], parts[3]
}
if parts[0] == "publish" {
return c.runPublish(req, pathName, user, pass)
}
return c.runRead(req, pathName, user, pass)
}
func (c *srtConn) runPublish(req srtNewConnReq, pathName string, user string, pass string) (bool, error) {
res := c.pathManager.addPublisher(pathAddPublisherReq{
author: c,
pathName: pathName,
credentials: authCredentials{
ip: c.ip(),
user: user,
pass: pass,
proto: authProtocolSRT,
id: &c.uuid,
},
})
if res.err != nil {
if terr, ok := res.err.(*errAuthentication); ok {
// TODO: re-enable. Currently this freezes the listener.
// wait some seconds to stop brute force attacks
// <-time.After(srtPauseAfterAuthError)
return false, terr
}
return false, res.err
}
defer res.path.removePublisher(pathRemovePublisherReq{author: c})
sconn, err := c.exchangeRequestWithConn(req)
if err != nil {
return true, err
}
c.mutex.Lock()
c.state = srtConnStatePublish
c.pathName = pathName
c.conn = sconn
c.mutex.Unlock()
readerErr := make(chan error)
go func() {
readerErr <- c.runPublishReader(sconn, res.path)
}()
select {
case err := <-readerErr:
sconn.Close()
return true, err
case <-c.ctx.Done():
sconn.Close()
<-readerErr
return true, errors.New("terminated")
}
}
func (c *srtConn) runPublishReader(sconn srt.Conn, path *path) error {
sconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
r, err := mpegts.NewReader(mpegts.NewBufferedReader(sconn))
if err != nil {
return err
}
var medias media.Medias
var stream *stream.Stream
var td *mpegts.TimeDecoder
decodeTime := func(t int64) time.Duration {
if td == nil {
td = mpegts.NewTimeDecoder(t)
}
return td.Decode(t)
}
for _, track := range r.Tracks() { //nolint:dupl
var medi *media.Media
switch tcodec := track.Codec.(type) {
case *mpegts.CodecH264:
medi = &media.Media{
Type: media.TypeVideo,
Formats: []formats.Format{&formats.H264{
PayloadTyp: 96,
PacketizationMode: 1,
}},
}
r.OnDataH26x(track, func(pts int64, _ int64, au [][]byte) error {
stream.WriteUnit(medi, medi.Formats[0], &formatprocessor.UnitH264{
BaseUnit: formatprocessor.BaseUnit{
NTP: time.Now(),
},
PTS: decodeTime(pts),
AU: au,
})
return nil
})
case *mpegts.CodecH265:
medi = &media.Media{
Type: media.TypeVideo,
Formats: []formats.Format{&formats.H265{
PayloadTyp: 96,
}},
}
r.OnDataH26x(track, func(pts int64, _ int64, au [][]byte) error {
stream.WriteUnit(medi, medi.Formats[0], &formatprocessor.UnitH265{
BaseUnit: formatprocessor.BaseUnit{
NTP: time.Now(),
},
PTS: decodeTime(pts),
AU: au,
})
return nil
})
case *mpegts.CodecMPEG4Audio:
medi = &media.Media{
Type: media.TypeAudio,
Formats: []formats.Format{&formats.MPEG4Audio{
PayloadTyp: 96,
SizeLength: 13,
IndexLength: 3,
IndexDeltaLength: 3,
Config: &tcodec.Config,
}},
}
r.OnDataMPEG4Audio(track, func(pts int64, _ int64, aus [][]byte) error {
stream.WriteUnit(medi, medi.Formats[0], &formatprocessor.UnitMPEG4AudioGeneric{
BaseUnit: formatprocessor.BaseUnit{
NTP: time.Now(),
},
PTS: decodeTime(pts),
AUs: aus,
})
return nil
})
case *mpegts.CodecOpus:
medi = &media.Media{
Type: media.TypeAudio,
Formats: []formats.Format{&formats.Opus{
PayloadTyp: 96,
IsStereo: (tcodec.ChannelCount == 2),
}},
}
r.OnDataOpus(track, func(pts int64, _ int64, packets [][]byte) error {
stream.WriteUnit(medi, medi.Formats[0], &formatprocessor.UnitOpus{
BaseUnit: formatprocessor.BaseUnit{
NTP: time.Now(),
},
PTS: decodeTime(pts),
Packets: packets,
})
return nil
})
}
medias = append(medias, medi)
}
rres := path.startPublisher(pathStartPublisherReq{
author: c,
medias: medias,
generateRTPPackets: true,
})
if rres.err != nil {
return rres.err
}
c.Log(logger.Info, "is publishing to path '%s', %s",
path.name,
sourceMediaInfo(medias))
stream = rres.stream
for {
err := r.Read()
if err != nil {
return err
}
}
}
func (c *srtConn) runRead(req srtNewConnReq, pathName string, user string, pass string) (bool, error) {
res := c.pathManager.addReader(pathAddReaderReq{
author: c,
pathName: pathName,
credentials: authCredentials{
ip: c.ip(),
user: user,
pass: pass,
proto: authProtocolSRT,
id: &c.uuid,
},
})
if res.err != nil {
if terr, ok := res.err.(*errAuthentication); ok {
// TODO: re-enable. Currently this freezes the listener.
// wait some seconds to stop brute force attacks
// <-time.After(srtPauseAfterAuthError)
return false, terr
}
return false, res.err
}
defer res.path.removeReader(pathRemoveReaderReq{author: c})
sconn, err := c.exchangeRequestWithConn(req)
if err != nil {
return true, err
}
defer sconn.Close()
c.mutex.Lock()
c.state = srtConnStateRead
c.pathName = pathName
c.conn = sconn
c.mutex.Unlock()
ringBuffer, _ := ringbuffer.New(uint64(c.readBufferCount))
go func() {
<-c.ctx.Done()
ringBuffer.Close()
}()
var w *mpegts.Writer
nextPID := uint16(256)
var tracks []*mpegts.Track
var medias media.Medias
bw := bufio.NewWriterSize(sconn, srtMaxPayloadSize(c.udpMaxPayloadSize))
leadingTrackChosen := false
leadingTrackInitialized := false
var leadingTrackStartDTS time.Duration
for _, medi := range res.stream.Medias() {
for _, format := range medi.Formats {
switch format := format.(type) {
case *formats.H265:
track := &mpegts.Track{
PID: nextPID,
Codec: &mpegts.CodecH265{},
}
tracks = append(tracks, track)
medias = append(medias, medi)
nextPID++
var startPTS time.Duration
startPTSFilled := false
var isLeadingTrack bool
if !leadingTrackChosen {
isLeadingTrack = true
} else {
isLeadingTrack = false
}
randomAccessReceived := false
dtsExtractor := h265.NewDTSExtractor()
res.stream.AddReader(c, medi, format, func(unit formatprocessor.Unit) {
ringBuffer.Push(func() error {
tunit := unit.(*formatprocessor.UnitH265)
if tunit.AU == nil {
return nil
}
if !startPTSFilled {
startPTS = tunit.PTS
startPTSFilled = true
}
randomAccessPresent := h265RandomAccessPresent(tunit.AU)
if !randomAccessReceived {
if !randomAccessPresent {
return nil
}
randomAccessReceived = true
}
pts := tunit.PTS - startPTS
dts, err := dtsExtractor.Extract(tunit.AU, pts)
if err != nil {
return err
}
if !leadingTrackInitialized {
if isLeadingTrack {
leadingTrackStartDTS = dts
leadingTrackInitialized = true
} else {
return nil
}
}
dts -= leadingTrackStartDTS
pts -= leadingTrackStartDTS
sconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err = w.WriteH26x(track, durationGoToMPEGTS(pts), durationGoToMPEGTS(dts), randomAccessPresent, tunit.AU)
if err != nil {
return err
}
return bw.Flush()
})
})
case *formats.H264:
track := &mpegts.Track{
PID: nextPID,
Codec: &mpegts.CodecH264{},
}
tracks = append(tracks, track)
medias = append(medias, medi)
nextPID++
var startPTS time.Duration
startPTSFilled := false
var isLeadingTrack bool
if !leadingTrackChosen {
isLeadingTrack = true
} else {
isLeadingTrack = false
}
firstIDRReceived := false
dtsExtractor := h264.NewDTSExtractor()
res.stream.AddReader(c, medi, format, func(unit formatprocessor.Unit) {
ringBuffer.Push(func() error {
tunit := unit.(*formatprocessor.UnitH264)
if tunit.AU == nil {
return nil
}
if !startPTSFilled {
startPTS = tunit.PTS
startPTSFilled = true
}
idrPresent := h264.IDRPresent(tunit.AU)
if !firstIDRReceived {
if !idrPresent {
return nil
}
firstIDRReceived = true
}
pts := tunit.PTS - startPTS
dts, err := dtsExtractor.Extract(tunit.AU, pts)
if err != nil {
return err
}
if !leadingTrackInitialized {
if isLeadingTrack {
leadingTrackStartDTS = dts
leadingTrackInitialized = true
} else {
return nil
}
}
dts -= leadingTrackStartDTS
pts -= leadingTrackStartDTS
sconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err = w.WriteH26x(track, durationGoToMPEGTS(pts), durationGoToMPEGTS(dts), idrPresent, tunit.AU)
if err != nil {
return err
}
return bw.Flush()
})
})
case *formats.MPEG4AudioGeneric:
track := &mpegts.Track{
PID: nextPID,
Codec: &mpegts.CodecMPEG4Audio{
Config: *format.Config,
},
}
tracks = append(tracks, track)
medias = append(medias, medi)
nextPID++
var startPTS time.Duration
startPTSFilled := false
res.stream.AddReader(c, medi, format, func(unit formatprocessor.Unit) {
ringBuffer.Push(func() error {
tunit := unit.(*formatprocessor.UnitMPEG4AudioGeneric)
if tunit.AUs == nil {
return nil
}
if !startPTSFilled {
startPTS = tunit.PTS
startPTSFilled = true
}
if leadingTrackChosen && !leadingTrackInitialized {
return nil
}
pts := tunit.PTS
pts -= startPTS
pts -= leadingTrackStartDTS
sconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err = w.WriteMPEG4Audio(track, durationGoToMPEGTS(pts), tunit.AUs)
if err != nil {
return err
}
return bw.Flush()
})
})
case *formats.MPEG4AudioLATM:
if format.Config != nil &&
len(format.Config.Programs) == 1 &&
len(format.Config.Programs[0].Layers) == 1 {
track := &mpegts.Track{
PID: nextPID,
Codec: &mpegts.CodecMPEG4Audio{
Config: *format.Config.Programs[0].Layers[0].AudioSpecificConfig,
},
}
tracks = append(tracks, track)
medias = append(medias, medi)
nextPID++
var startPTS time.Duration
startPTSFilled := false
res.stream.AddReader(c, medi, format, func(unit formatprocessor.Unit) {
ringBuffer.Push(func() error {
tunit := unit.(*formatprocessor.UnitMPEG4AudioLATM)
if tunit.AU == nil {
return nil
}
if !startPTSFilled {
startPTS = tunit.PTS
startPTSFilled = true
}
if leadingTrackChosen && !leadingTrackInitialized {
return nil
}
pts := tunit.PTS
pts -= startPTS
pts -= leadingTrackStartDTS
sconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err = w.WriteMPEG4Audio(track, durationGoToMPEGTS(pts), [][]byte{tunit.AU})
if err != nil {
return err
}
return bw.Flush()
})
})
}
case *formats.Opus:
track := &mpegts.Track{
PID: nextPID,
Codec: &mpegts.CodecOpus{
ChannelCount: func() int {
if format.IsStereo {
return 2
}
return 1
}(),
},
}
tracks = append(tracks, track)
medias = append(medias, medi)
nextPID++
var startPTS time.Duration
startPTSFilled := false
res.stream.AddReader(c, medi, format, func(unit formatprocessor.Unit) {
ringBuffer.Push(func() error {
tunit := unit.(*formatprocessor.UnitOpus)
if tunit.Packets == nil {
return nil
}
if !startPTSFilled {
startPTS = tunit.PTS
startPTSFilled = true
}
if leadingTrackChosen && !leadingTrackInitialized {
return nil
}
pts := tunit.PTS
pts -= startPTS
pts -= leadingTrackStartDTS
sconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err = w.WriteOpus(track, durationGoToMPEGTS(pts), tunit.Packets)
if err != nil {
return err
}
return bw.Flush()
})
})
}
}
}
if len(tracks) == 0 {
return true, fmt.Errorf(
"the stream doesn't contain any supported codec, which are currently H265, H264, Opus, MPEG4-Audio")
}
c.Log(logger.Info, "is reading from path '%s', %s",
res.path.name, sourceMediaInfo(medias))
w = mpegts.NewWriter(bw, tracks)
// disable read deadline
sconn.SetReadDeadline(time.Time{})
for {
item, ok := ringBuffer.Pull()
if !ok {
return true, fmt.Errorf("terminated")
}
err := item.(func() error)()
if err != nil {
return true, err
}
}
}
func (c *srtConn) exchangeRequestWithConn(req srtNewConnReq) (srt.Conn, error) {
req.res <- c
select {
case sconn := <-c.chSetConn:
return sconn, nil
case <-c.ctx.Done():
return nil, errors.New("terminated")
}
}
// new is called by srtListener through srtServer.
func (c *srtConn) new(req srtNewConnReq) *srtConn {
select {
case c.chNew <- req:
return <-req.res
case <-c.ctx.Done():
return nil
}
}
// setConn is called by srtListener .
func (c *srtConn) setConn(sconn srt.Conn) {
select {
case c.chSetConn <- sconn:
case <-c.ctx.Done():
}
}
// apiReaderDescribe implements reader.
func (c *srtConn) apiReaderDescribe() pathAPISourceOrReader {
return pathAPISourceOrReader{
Type: "srtConn",
ID: c.uuid.String(),
}
}
// apiSourceDescribe implements source.
func (c *srtConn) apiSourceDescribe() pathAPISourceOrReader {
return c.apiReaderDescribe()
}
func (c *srtConn) apiItem() *apiSRTConn {
c.mutex.RLock()
defer c.mutex.RUnlock()
bytesReceived := uint64(0)
bytesSent := uint64(0)
if c.conn != nil {
var s srt.Statistics
c.conn.Stats(&s)
bytesReceived = s.Accumulated.ByteRecv
bytesSent = s.Accumulated.ByteSent
}
return &apiSRTConn{
ID: c.uuid,
Created: c.created,
RemoteAddr: c.connReq.RemoteAddr().String(),
State: func() string {
switch c.state {
case srtConnStateRead:
return "read"
case srtConnStatePublish:
return "publish"
default:
return "idle"
}
}(),
Path: c.pathName,
BytesReceived: bytesReceived,
BytesSent: bytesSent,
}
}

View file

@ -0,0 +1,60 @@
package core
import (
"sync"
"github.com/datarhei/gosrt"
)
type srtListener struct {
ln srt.Listener
wg *sync.WaitGroup
parent *srtServer
}
func newSRTListener(
ln srt.Listener,
wg *sync.WaitGroup,
parent *srtServer,
) *srtListener {
l := &srtListener{
ln: ln,
wg: wg,
parent: parent,
}
l.wg.Add(1)
go l.run()
return l
}
func (l *srtListener) run() {
defer l.wg.Done()
err := func() error {
for {
var sconn *srtConn
conn, _, err := l.ln.Accept(func(req srt.ConnRequest) srt.ConnType {
sconn = l.parent.newConnRequest(req)
if sconn == nil {
return srt.REJECT
}
// currently it's the same to return SUBSCRIBE or PUBLISH
return srt.SUBSCRIBE
})
if err != nil {
return err
}
if conn == nil {
continue
}
sconn.setConn(conn)
}
}()
l.parent.acceptError(err)
}

308
internal/core/srt_server.go Normal file
View file

@ -0,0 +1,308 @@
package core
import (
"context"
"fmt"
"sort"
"sync"
"time"
"github.com/datarhei/gosrt"
"github.com/google/uuid"
"github.com/bluenviron/mediamtx/internal/conf"
"github.com/bluenviron/mediamtx/internal/logger"
)
func srtMaxPayloadSize(u int) int {
return ((u - 16) / 188) * 188 // 16 = SRT header, 188 = MPEG-TS packet
}
type srtNewConnReq struct {
connReq srt.ConnRequest
res chan *srtConn
}
type srtServerAPIConnsListRes struct {
data *apiSRTConnsList
err error
}
type srtServerAPIConnsListReq struct {
res chan srtServerAPIConnsListRes
}
type srtServerAPIConnsGetRes struct {
data *apiSRTConn
err error
}
type srtServerAPIConnsGetReq struct {
uuid uuid.UUID
res chan srtServerAPIConnsGetRes
}
type srtServerAPIConnsKickRes struct {
err error
}
type srtServerAPIConnsKickReq struct {
uuid uuid.UUID
res chan srtServerAPIConnsKickRes
}
type srtServerParent interface {
logger.Writer
}
type srtServer struct {
readTimeout conf.StringDuration
writeTimeout conf.StringDuration
readBufferCount int
udpMaxPayloadSize int
pathManager *pathManager
parent srtServerParent
ctx context.Context
ctxCancel func()
wg sync.WaitGroup
ln srt.Listener
conns map[*srtConn]struct{}
// in
chNewConnRequest chan srtNewConnReq
chAcceptErr chan error
chCloseConn chan *srtConn
chAPIConnsList chan srtServerAPIConnsListReq
chAPIConnsGet chan srtServerAPIConnsGetReq
chAPIConnsKick chan srtServerAPIConnsKickReq
}
func newSRTServer(
address string,
readTimeout conf.StringDuration,
writeTimeout conf.StringDuration,
readBufferCount int,
udpMaxPayloadSize int,
pathManager *pathManager,
parent srtServerParent,
) (*srtServer, error) {
conf := srt.DefaultConfig()
conf.ConnectionTimeout = time.Duration(readTimeout)
conf.PayloadSize = uint32(srtMaxPayloadSize(udpMaxPayloadSize))
ln, err := srt.Listen("srt", address, conf)
if err != nil {
return nil, err
}
ctx, ctxCancel := context.WithCancel(context.Background())
s := &srtServer{
readTimeout: readTimeout,
writeTimeout: writeTimeout,
readBufferCount: readBufferCount,
udpMaxPayloadSize: udpMaxPayloadSize,
pathManager: pathManager,
parent: parent,
ctx: ctx,
ctxCancel: ctxCancel,
ln: ln,
conns: make(map[*srtConn]struct{}),
chNewConnRequest: make(chan srtNewConnReq),
chAcceptErr: make(chan error),
chCloseConn: make(chan *srtConn),
chAPIConnsList: make(chan srtServerAPIConnsListReq),
chAPIConnsGet: make(chan srtServerAPIConnsGetReq),
chAPIConnsKick: make(chan srtServerAPIConnsKickReq),
}
s.Log(logger.Info, "listener opened on "+address+" (UDP)")
newSRTListener(
s.ln,
&s.wg,
s,
)
s.wg.Add(1)
go s.run()
return s, nil
}
// Log is the main logging function.
func (s *srtServer) Log(level logger.Level, format string, args ...interface{}) {
s.parent.Log(level, "[SRT] "+format, append([]interface{}{}, args...)...)
}
func (s *srtServer) close() {
s.Log(logger.Info, "listener is closing")
s.ctxCancel()
s.wg.Wait()
}
func (s *srtServer) run() {
defer s.wg.Done()
outer:
for {
select {
case err := <-s.chAcceptErr:
s.Log(logger.Error, "%s", err)
break outer
case req := <-s.chNewConnRequest:
c := newSRTConn(
s.ctx,
s.readTimeout,
s.writeTimeout,
s.readBufferCount,
s.udpMaxPayloadSize,
req.connReq,
&s.wg,
s.pathManager,
s)
s.conns[c] = struct{}{}
req.res <- c
case c := <-s.chCloseConn:
delete(s.conns, c)
case req := <-s.chAPIConnsList:
data := &apiSRTConnsList{
Items: []*apiSRTConn{},
}
for c := range s.conns {
data.Items = append(data.Items, c.apiItem())
}
sort.Slice(data.Items, func(i, j int) bool {
return data.Items[i].Created.Before(data.Items[j].Created)
})
req.res <- srtServerAPIConnsListRes{data: data}
case req := <-s.chAPIConnsGet:
c := s.findConnByUUID(req.uuid)
if c == nil {
req.res <- srtServerAPIConnsGetRes{err: errAPINotFound}
continue
}
req.res <- srtServerAPIConnsGetRes{data: c.apiItem()}
case req := <-s.chAPIConnsKick:
c := s.findConnByUUID(req.uuid)
if c == nil {
req.res <- srtServerAPIConnsKickRes{err: errAPINotFound}
continue
}
delete(s.conns, c)
c.close()
req.res <- srtServerAPIConnsKickRes{}
case <-s.ctx.Done():
break outer
}
}
s.ctxCancel()
s.ln.Close()
}
func (s *srtServer) findConnByUUID(uuid uuid.UUID) *srtConn {
for sx := range s.conns {
if sx.uuid == uuid {
return sx
}
}
return nil
}
// newConnRequest is called by srtListener.
func (s *srtServer) newConnRequest(connReq srt.ConnRequest) *srtConn {
req := srtNewConnReq{
connReq: connReq,
res: make(chan *srtConn),
}
select {
case s.chNewConnRequest <- req:
c := <-req.res
return c.new(req)
case <-s.ctx.Done():
return nil
}
}
// acceptError is called by srtListener.
func (s *srtServer) acceptError(err error) {
select {
case s.chAcceptErr <- err:
case <-s.ctx.Done():
}
}
// closeConn is called by srtConn.
func (s *srtServer) closeConn(c *srtConn) {
select {
case s.chCloseConn <- c:
case <-s.ctx.Done():
}
}
// apiConnsList is called by api.
func (s *srtServer) apiConnsList() (*apiSRTConnsList, error) {
req := srtServerAPIConnsListReq{
res: make(chan srtServerAPIConnsListRes),
}
select {
case s.chAPIConnsList <- req:
res := <-req.res
return res.data, res.err
case <-s.ctx.Done():
return nil, fmt.Errorf("terminated")
}
}
// apiConnsGet is called by api.
func (s *srtServer) apiConnsGet(uuid uuid.UUID) (*apiSRTConn, error) {
req := srtServerAPIConnsGetReq{
uuid: uuid,
res: make(chan srtServerAPIConnsGetRes),
}
select {
case s.chAPIConnsGet <- req:
res := <-req.res
return res.data, res.err
case <-s.ctx.Done():
return nil, fmt.Errorf("terminated")
}
}
// apiConnsKick is called by api.
func (s *srtServer) apiConnsKick(uuid uuid.UUID) error {
req := srtServerAPIConnsKickReq{
uuid: uuid,
res: make(chan srtServerAPIConnsKickRes),
}
select {
case s.chAPIConnsKick <- req:
res := <-req.res
return res.err
case <-s.ctx.Done():
return fmt.Errorf("terminated")
}
}

View file

@ -0,0 +1,115 @@
package core
import (
"bufio"
"testing"
"time"
"github.com/bluenviron/mediacommon/pkg/formats/mpegts"
"github.com/datarhei/gosrt"
"github.com/stretchr/testify/require"
)
func TestSRTServer(t *testing.T) {
p, ok := newInstance("paths:\n" +
" all:\n")
require.Equal(t, true, ok)
defer p.Close()
conf := srt.DefaultConfig()
address, err := conf.UnmarshalURL("srt://localhost:8890?streamid=publish:mypath")
require.NoError(t, err)
err = conf.Validate()
require.NoError(t, err)
publisher, err := srt.Dial("srt", address, conf)
require.NoError(t, err)
defer publisher.Close()
track := &mpegts.Track{
PID: 256,
Codec: &mpegts.CodecH264{},
}
bw := bufio.NewWriter(publisher)
w := mpegts.NewWriter(bw, []*mpegts.Track{track})
require.NoError(t, err)
err = w.WriteH26x(track, 0, 0, true, [][]byte{
{ // SPS
0x67, 0x42, 0xc0, 0x28, 0xd9, 0x00, 0x78, 0x02,
0x27, 0xe5, 0x84, 0x00, 0x00, 0x03, 0x00, 0x04,
0x00, 0x00, 0x03, 0x00, 0xf0, 0x3c, 0x60, 0xc9,
0x20,
},
{ // PPS
0x08, 0x06, 0x07, 0x08,
},
{ // IDR
0x05, 1,
},
})
require.NoError(t, err)
bw.Flush()
time.Sleep(500 * time.Millisecond)
conf = srt.DefaultConfig()
address, err = conf.UnmarshalURL("srt://localhost:8890?streamid=read:mypath")
require.NoError(t, err)
err = conf.Validate()
require.NoError(t, err)
reader, err := srt.Dial("srt", address, conf)
require.NoError(t, err)
defer reader.Close()
err = w.WriteH26x(track, 2*90000, 1*90000, true, [][]byte{
{ // IDR
0x05, 2,
},
})
require.NoError(t, err)
bw.Flush()
r, err := mpegts.NewReader(reader)
require.NoError(t, err)
require.Equal(t, []*mpegts.Track{{
PID: 256,
Codec: &mpegts.CodecH264{},
}}, r.Tracks())
received := false
r.OnDataH26x(r.Tracks()[0], func(pts int64, dts int64, au [][]byte) error {
require.Equal(t, int64(0), pts)
require.Equal(t, int64(0), dts)
require.Equal(t, [][]byte{
{ // SPS
0x67, 0x42, 0xc0, 0x28, 0xd9, 0x00, 0x78, 0x02,
0x27, 0xe5, 0x84, 0x00, 0x00, 0x03, 0x00, 0x04,
0x00, 0x00, 0x03, 0x00, 0xf0, 0x3c, 0x60, 0xc9,
0x20,
},
{ // PPS
0x08, 0x06, 0x07, 0x08,
},
{ // IDR
0x05, 1,
},
}, au)
received = true
return nil
})
for {
err = r.Read()
require.NoError(t, err)
if received {
break
}
}
}

221
internal/core/srt_source.go Normal file
View file

@ -0,0 +1,221 @@
package core
import (
"context"
"time"
"github.com/bluenviron/gortsplib/v3/pkg/formats"
"github.com/bluenviron/gortsplib/v3/pkg/media"
"github.com/bluenviron/mediacommon/pkg/formats/mpegts"
"github.com/datarhei/gosrt"
"github.com/bluenviron/mediamtx/internal/conf"
"github.com/bluenviron/mediamtx/internal/formatprocessor"
"github.com/bluenviron/mediamtx/internal/logger"
"github.com/bluenviron/mediamtx/internal/stream"
)
type srtSourceParent interface {
logger.Writer
sourceStaticImplSetReady(req pathSourceStaticSetReadyReq) pathSourceStaticSetReadyRes
sourceStaticImplSetNotReady(req pathSourceStaticSetNotReadyReq)
}
type srtSource struct {
readTimeout conf.StringDuration
parent srtSourceParent
}
func newSRTSource(
readTimeout conf.StringDuration,
parent srtSourceParent,
) *srtSource {
s := &srtSource{
readTimeout: readTimeout,
parent: parent,
}
return s
}
func (s *srtSource) Log(level logger.Level, format string, args ...interface{}) {
s.parent.Log(level, "[srt source] "+format, args...)
}
// run implements sourceStaticImpl.
func (s *srtSource) run(ctx context.Context, cnf *conf.PathConf, reloadConf chan *conf.PathConf) error {
s.Log(logger.Debug, "connecting")
conf := srt.DefaultConfig()
address, err := conf.UnmarshalURL(cnf.Source)
if err != nil {
return err
}
err = conf.Validate()
if err != nil {
return err
}
sconn, err := srt.Dial("srt", address, conf)
if err != nil {
return err
}
readDone := make(chan error)
go func() {
readDone <- s.runReader(sconn)
}()
for {
select {
case err := <-readDone:
sconn.Close()
return err
case <-reloadConf:
case <-ctx.Done():
sconn.Close()
<-readDone
return nil
}
}
}
func (s *srtSource) runReader(sconn srt.Conn) error {
sconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
r, err := mpegts.NewReader(mpegts.NewBufferedReader(sconn))
if err != nil {
return err
}
var medias media.Medias
var stream *stream.Stream
var td *mpegts.TimeDecoder
decodeTime := func(t int64) time.Duration {
if td == nil {
td = mpegts.NewTimeDecoder(t)
}
return td.Decode(t)
}
for _, track := range r.Tracks() { //nolint:dupl
var medi *media.Media
switch tcodec := track.Codec.(type) {
case *mpegts.CodecH264:
medi = &media.Media{
Type: media.TypeVideo,
Formats: []formats.Format{&formats.H264{
PayloadTyp: 96,
PacketizationMode: 1,
}},
}
r.OnDataH26x(track, func(pts int64, _ int64, au [][]byte) error {
stream.WriteUnit(medi, medi.Formats[0], &formatprocessor.UnitH264{
BaseUnit: formatprocessor.BaseUnit{
NTP: time.Now(),
},
PTS: decodeTime(pts),
AU: au,
})
return nil
})
case *mpegts.CodecH265:
medi = &media.Media{
Type: media.TypeVideo,
Formats: []formats.Format{&formats.H265{
PayloadTyp: 96,
}},
}
r.OnDataH26x(track, func(pts int64, _ int64, au [][]byte) error {
stream.WriteUnit(medi, medi.Formats[0], &formatprocessor.UnitH265{
BaseUnit: formatprocessor.BaseUnit{
NTP: time.Now(),
},
PTS: decodeTime(pts),
AU: au,
})
return nil
})
case *mpegts.CodecMPEG4Audio:
medi = &media.Media{
Type: media.TypeAudio,
Formats: []formats.Format{&formats.MPEG4Audio{
PayloadTyp: 96,
SizeLength: 13,
IndexLength: 3,
IndexDeltaLength: 3,
Config: &tcodec.Config,
}},
}
r.OnDataMPEG4Audio(track, func(pts int64, _ int64, aus [][]byte) error {
stream.WriteUnit(medi, medi.Formats[0], &formatprocessor.UnitMPEG4AudioGeneric{
BaseUnit: formatprocessor.BaseUnit{
NTP: time.Now(),
},
PTS: decodeTime(pts),
AUs: aus,
})
return nil
})
case *mpegts.CodecOpus:
medi = &media.Media{
Type: media.TypeAudio,
Formats: []formats.Format{&formats.Opus{
PayloadTyp: 96,
IsStereo: (tcodec.ChannelCount == 2),
}},
}
r.OnDataOpus(track, func(pts int64, _ int64, packets [][]byte) error {
stream.WriteUnit(medi, medi.Formats[0], &formatprocessor.UnitOpus{
BaseUnit: formatprocessor.BaseUnit{
NTP: time.Now(),
},
PTS: decodeTime(pts),
Packets: packets,
})
return nil
})
}
medias = append(medias, medi)
}
res := s.parent.sourceStaticImplSetReady(pathSourceStaticSetReadyReq{
medias: medias,
generateRTPPackets: true,
})
if res.err != nil {
return res.err
}
s.Log(logger.Info, "ready: %s", sourceMediaInfo(medias))
stream = res.stream
for {
sconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
err := r.Read()
if err != nil {
return err
}
}
}
// apiSourceDescribe implements sourceStaticImpl.
func (*srtSource) apiSourceDescribe() pathAPISourceOrReader {
return pathAPISourceOrReader{
Type: "srtSource",
ID: "",
}
}

View file

@ -0,0 +1,98 @@
package core
import (
"bufio"
"testing"
"github.com/bluenviron/gortsplib/v3"
"github.com/bluenviron/gortsplib/v3/pkg/url"
"github.com/bluenviron/mediacommon/pkg/formats/mpegts"
"github.com/datarhei/gosrt"
"github.com/pion/rtp"
"github.com/stretchr/testify/require"
)
func TestSRTSource(t *testing.T) {
ln, err := srt.Listen("srt", "localhost:9999", srt.DefaultConfig())
require.NoError(t, err)
defer ln.Close()
connected := make(chan struct{})
received := make(chan struct{})
done := make(chan struct{})
go func() {
conn, _, err := ln.Accept(func(req srt.ConnRequest) srt.ConnType {
require.Equal(t, "sidname", req.StreamId())
err := req.SetPassphrase("ttest1234567")
if err != nil {
return srt.REJECT
}
return srt.SUBSCRIBE
})
require.NoError(t, err)
require.NotNil(t, conn)
defer conn.Close()
track := &mpegts.Track{
PID: 256,
Codec: &mpegts.CodecH264{},
}
bw := bufio.NewWriter(conn)
w := mpegts.NewWriter(bw, []*mpegts.Track{track})
require.NoError(t, err)
err = w.WriteH26x(track, 0, 0, true, [][]byte{
{ // IDR
0x05, 1,
},
})
require.NoError(t, err)
bw.Flush()
<-connected
err = w.WriteH26x(track, 0, 0, true, [][]byte{{5, 2}})
require.NoError(t, err)
bw.Flush()
<-done
}()
p, ok := newInstance("paths:\n" +
" proxied:\n" +
" source: srt://localhost:9999?streamid=sidname&passphrase=ttest1234567\n" +
" sourceOnDemand: yes\n")
require.Equal(t, true, ok)
defer p.Close()
c := gortsplib.Client{}
u, err := url.Parse("rtsp://127.0.0.1:8554/proxied")
require.NoError(t, err)
err = c.Start(u.Scheme, u.Host)
require.NoError(t, err)
defer c.Close()
medias, baseURL, _, err := c.Describe(u)
require.NoError(t, err)
err = c.SetupAll(medias, baseURL)
require.NoError(t, err)
c.OnPacketRTP(medias[0], medias[0].Formats[0], func(pkt *rtp.Packet) {
require.Equal(t, []byte{5, 1}, pkt.Payload)
close(received)
})
_, err = c.Play(nil)
require.NoError(t, err)
close(connected)
<-received
close(done)
}

View file

@ -93,15 +93,6 @@ type webRTCManagerAPISessionsListReq struct {
res chan webRTCManagerAPISessionsListRes
}
type webRTCManagerAPISessionsKickRes struct {
err error
}
type webRTCManagerAPISessionsKickReq struct {
uuid uuid.UUID
res chan webRTCManagerAPISessionsKickRes
}
type webRTCManagerAPISessionsGetRes struct {
data *apiWebRTCSession
err error
@ -112,6 +103,15 @@ type webRTCManagerAPISessionsGetReq struct {
res chan webRTCManagerAPISessionsGetRes
}
type webRTCManagerAPISessionsKickRes struct {
err error
}
type webRTCManagerAPISessionsKickReq struct {
uuid uuid.UUID
res chan webRTCManagerAPISessionsKickRes
}
type webRTCNewSessionRes struct {
sx *webRTCSession
answer []byte