diff --git a/internal/rtmp/base/messagewriter.go b/internal/rtmp/base/messagewriter.go index ad4736ec..5eb2464a 100644 --- a/internal/rtmp/base/messagewriter.go +++ b/internal/rtmp/base/messagewriter.go @@ -4,62 +4,46 @@ import ( "io" ) -// MessageWriter is a message writer. -type MessageWriter struct { - w io.Writer - chunkMaxBodyLen int - lastMessageStreamIDPerChunkStreamID map[byte]uint32 +type messageWriterChunkStream struct { + mw *MessageWriter + lastMessageStreamID *uint32 } -// NewMessageWriter instantiates a MessageWriter. -func NewMessageWriter(w io.Writer) *MessageWriter { - return &MessageWriter{ - w: w, - chunkMaxBodyLen: 128, - lastMessageStreamIDPerChunkStreamID: make(map[byte]uint32), - } -} - -// SetChunkSize sets the chunk size. -func (mw *MessageWriter) SetChunkSize(v int) { - mw.chunkMaxBodyLen = v -} - -// Write writes a Message. -func (mw *MessageWriter) Write(msg *Message) error { +func (wc *messageWriterChunkStream) write(msg *Message) error { bodyLen := len(msg.Body) pos := 0 - first := true + firstChunk := true for { chunkBodyLen := bodyLen - pos - if chunkBodyLen > mw.chunkMaxBodyLen { - chunkBodyLen = mw.chunkMaxBodyLen + if chunkBodyLen > wc.mw.chunkSize { + chunkBodyLen = wc.mw.chunkSize } - if first { - first = false + if firstChunk { + firstChunk = false - if v, ok := mw.lastMessageStreamIDPerChunkStreamID[msg.ChunkStreamID]; !ok || v != msg.MessageStreamID { + if wc.lastMessageStreamID == nil || *wc.lastMessageStreamID != msg.MessageStreamID { err := Chunk0{ ChunkStreamID: msg.ChunkStreamID, Type: msg.Type, MessageStreamID: msg.MessageStreamID, BodyLen: uint32(bodyLen), Body: msg.Body[pos : pos+chunkBodyLen], - }.Write(mw.w) + }.Write(wc.mw.w) if err != nil { return err } - mw.lastMessageStreamIDPerChunkStreamID[msg.ChunkStreamID] = msg.MessageStreamID + v := msg.MessageStreamID + wc.lastMessageStreamID = &v } else { err := Chunk1{ ChunkStreamID: msg.ChunkStreamID, Type: msg.Type, BodyLen: uint32(bodyLen), Body: msg.Body[pos : pos+chunkBodyLen], - }.Write(mw.w) + }.Write(wc.mw.w) if err != nil { return err } @@ -68,7 +52,7 @@ func (mw *MessageWriter) Write(msg *Message) error { err := Chunk3{ ChunkStreamID: msg.ChunkStreamID, Body: msg.Body[pos : pos+chunkBodyLen], - }.Write(mw.w) + }.Write(wc.mw.w) if err != nil { return err } @@ -81,3 +65,35 @@ func (mw *MessageWriter) Write(msg *Message) error { } } } + +// MessageWriter is a message writer. +type MessageWriter struct { + w io.Writer + chunkSize int + chunkStreams map[byte]*messageWriterChunkStream +} + +// NewMessageWriter instantiates a MessageWriter. +func NewMessageWriter(w io.Writer) *MessageWriter { + return &MessageWriter{ + w: w, + chunkSize: 128, + chunkStreams: make(map[byte]*messageWriterChunkStream), + } +} + +// SetChunkSize sets the maximum chunk size. +func (mw *MessageWriter) SetChunkSize(v int) { + mw.chunkSize = v +} + +// Write writes a Message. +func (mw *MessageWriter) Write(msg *Message) error { + cs, ok := mw.chunkStreams[msg.ChunkStreamID] + if !ok { + cs = &messageWriterChunkStream{mw: mw} + mw.chunkStreams[msg.ChunkStreamID] = cs + } + + return cs.write(msg) +}