mediamtx/internal/core/rtsp_source.go
aler9 3606472e82 generate RTP packets after H264 remuxing
Previously, RTP packets coming from sources other than RTSP (that
is, RTMP and HLS) were generated before the H264 remuxing, which
led to invalid streams, especially when sourceOnDemand is true
and the stream has invalid or dynamic SPS/PPS.
2022-08-14 13:03:04 +02:00

185 lines
4.1 KiB
Go

package core
import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"fmt"
"strings"
"time"
"github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/url"
"github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/logger"
)
// rtspSourceParent is the set of callbacks that a rtspSource invokes on
// its owner.
type rtspSourceParent interface {
	// log receives the messages emitted through rtspSource.Log.
	log(level logger.Level, format string, args ...interface{})

	// sourceStaticImplSetReady is called by run once every track has been
	// set up; the result carries an error and the stream to write to.
	sourceStaticImplSetReady(req pathSourceStaticSetReadyReq) pathSourceStaticSetReadyRes

	// sourceStaticImplSetNotReady is called by run when a session that
	// was previously marked as ready terminates.
	sourceStaticImplSetNotReady(req pathSourceStaticSetNotReadyReq)
}
// rtspSource holds the settings needed by run to pull a stream from a
// remote RTSP server and hand it over to its parent.
type rtspSource struct {
	// ur is the source URL; run parses it before connecting.
	ur string
	// proto provides the Transport value passed to the RTSP client.
	proto conf.SourceProtocol
	// anyPortEnable is forwarded to the client's AnyPortEnable option.
	anyPortEnable bool
	// fingerprint, when not empty, is compared (case-insensitively) with
	// the hex-encoded SHA-256 of the server certificate.
	fingerprint string
	// readTimeout and writeTimeout are converted to time.Duration and
	// forwarded to the client.
	readTimeout, writeTimeout conf.StringDuration
	// readBufferCount is forwarded to the client's ReadBufferCount option.
	readBufferCount int
	// parent receives log messages and ready / not-ready notifications.
	parent rtspSourceParent
}
// newRTSPSource allocates a rtspSource and stores every argument in the
// field of the same name. No connection is opened here.
func newRTSPSource(
	ur string,
	proto conf.SourceProtocol,
	anyPortEnable bool,
	fingerprint string,
	readTimeout conf.StringDuration,
	writeTimeout conf.StringDuration,
	readBufferCount int,
	parent rtspSourceParent,
) *rtspSource {
	src := new(rtspSource)

	src.ur = ur
	src.proto = proto
	src.anyPortEnable = anyPortEnable
	src.fingerprint = fingerprint
	src.readTimeout = readTimeout
	src.writeTimeout = writeTimeout
	src.readBufferCount = readBufferCount
	src.parent = parent

	return src
}
// Log forwards a message to the parent's logger, prepending a fixed tag
// to the format string so that the origin of the message is visible.
func (s *rtspSource) Log(level logger.Level, format string, args ...interface{}) {
	const tag = "[rtsp source] "
	s.parent.log(level, tag+format, args...)
}
// run implements sourceStaticImpl.
//
// It connects to the configured URL, describes and sets up every track,
// marks the source as ready on the parent and then forwards each incoming
// RTP packet to the stream returned by the parent. It returns the error
// that terminated the session, or nil when ctx is cancelled.
func (s *rtspSource) run(ctx context.Context) error {
	s.Log(logger.Debug, "connecting")

	u, err := url.Parse(s.ur)
	if err != nil {
		return err
	}

	// A TLS configuration is provided only when a fingerprint is set; in
	// that case the standard chain verification is skipped and replaced
	// by a comparison against the SHA-256 of the first peer certificate.
	var tlsConfig *tls.Config
	if s.fingerprint != "" {
		verify := func(cs tls.ConnectionState) error {
			sum := sha256.Sum256(cs.PeerCertificates[0].Raw)
			got := hex.EncodeToString(sum[:])
			want := strings.ToLower(s.fingerprint)

			if got == want {
				return nil
			}

			return fmt.Errorf("server fingerprint do not match: expected %s, got %s",
				want, got)
		}

		tlsConfig = &tls.Config{
			InsecureSkipVerify: true,
			VerifyConnection:   verify,
		}
	}

	// Requests and responses are traced at debug level.
	logRequest := func(req *base.Request) {
		s.Log(logger.Debug, "c->s %v", req)
	}
	logResponse := func(res *base.Response) {
		s.Log(logger.Debug, "s->c %v", res)
	}

	c := &gortsplib.Client{
		Transport:       s.proto.Transport,
		TLSConfig:       tlsConfig,
		ReadTimeout:     time.Duration(s.readTimeout),
		WriteTimeout:    time.Duration(s.writeTimeout),
		ReadBufferCount: s.readBufferCount,
		AnyPortEnable:   s.anyPortEnable,
		OnRequest:       logRequest,
		OnResponse:      logResponse,
	}

	if err := c.Start(u.Scheme, u.Host); err != nil {
		return err
	}
	defer c.Close()

	// session performs the whole RTSP exchange and blocks until the
	// client stops.
	session := func() error {
		tracks, baseURL, _, err := c.Describe(u)
		if err != nil {
			return err
		}

		for _, track := range tracks {
			if _, err := c.Setup(true, track, baseURL, 0, 0); err != nil {
				return err
			}
		}

		res := s.parent.sourceStaticImplSetReady(pathSourceStaticSetReadyReq{
			tracks:             tracks,
			generateRTPPackets: false,
		})
		if res.err != nil {
			return res.err
		}

		s.Log(logger.Info, "ready")

		// From here on the parent must be told when the session ends.
		defer func() {
			s.parent.sourceStaticImplSetNotReady(pathSourceStaticSetNotReadyReq{})
		}()

		c.OnPacketRTP = func(pkt *gortsplib.ClientOnPacketRTPCtx) {
			d := &data{
				trackID:      pkt.TrackID,
				rtpPacket:    pkt.Packet,
				ptsEqualsDTS: pkt.PTSEqualsDTS,
			}

			// H264 fields are filled only when NALUs are attached to the
			// packet; the slice header is copied before being handed over.
			if pkt.H264NALUs != nil {
				d.pts = pkt.H264PTS
				d.h264NALUs = append([][]byte(nil), pkt.H264NALUs...)
			}

			res.stream.writeData(d)
		}

		if _, err := c.Play(nil); err != nil {
			return err
		}

		return c.Wait()
	}

	readErr := make(chan error)
	go func() {
		readErr <- session()
	}()

	select {
	case sessionErr := <-readErr:
		return sessionErr

	case <-ctx.Done():
		// Closing the client makes the session return; its result is
		// drained so that the goroutine can exit before run returns.
		c.Close()
		<-readErr
		return nil
	}
}
// apiSourceDescribe implements sourceStaticImpl.
//
// The returned value is an anonymous struct whose only field, serialized
// as "type" in JSON, is set to "rtspSource".
func (*rtspSource) apiSourceDescribe() interface{} {
	desc := struct {
		Type string `json:"type"`
	}{
		Type: "rtspSource",
	}
	return desc
}