167 lines
3.3 KiB
Go
167 lines
3.3 KiB
Go
package smux
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// Stream implements io.ReadWriteCloser
|
|
type Stream struct {
|
|
id uint32
|
|
rstflag int32
|
|
sess *Session
|
|
buffer bytes.Buffer
|
|
bufferLock sync.Mutex
|
|
frameSize int
|
|
chReadEvent chan struct{} // notify a read event
|
|
die chan struct{} // flag the stream has closed
|
|
dieLock sync.Mutex
|
|
}
|
|
|
|
// newStream initiates a Stream struct
|
|
func newStream(id uint32, frameSize int, sess *Session) *Stream {
|
|
s := new(Stream)
|
|
s.id = id
|
|
s.chReadEvent = make(chan struct{}, 1)
|
|
s.frameSize = frameSize
|
|
s.sess = sess
|
|
s.die = make(chan struct{})
|
|
return s
|
|
}
|
|
|
|
// Read implements io.ReadWriteCloser
|
|
func (s *Stream) Read(b []byte) (n int, err error) {
|
|
READ:
|
|
select {
|
|
case <-s.die:
|
|
return 0, errors.New(errBrokenPipe)
|
|
default:
|
|
}
|
|
|
|
s.bufferLock.Lock()
|
|
n, err = s.buffer.Read(b)
|
|
s.bufferLock.Unlock()
|
|
|
|
if n > 0 {
|
|
s.sess.returnTokens(n)
|
|
return n, nil
|
|
} else if atomic.LoadInt32(&s.rstflag) == 1 {
|
|
_ = s.Close()
|
|
return 0, errors.New(errConnReset)
|
|
}
|
|
|
|
select {
|
|
case <-s.chReadEvent:
|
|
goto READ
|
|
case <-s.die:
|
|
return 0, errors.New(errBrokenPipe)
|
|
}
|
|
}
|
|
|
|
// Write implements io.ReadWriteCloser
|
|
func (s *Stream) Write(b []byte) (n int, err error) {
|
|
select {
|
|
case <-s.die:
|
|
return 0, errors.New(errBrokenPipe)
|
|
default:
|
|
}
|
|
|
|
frames := s.split(b, cmdPSH, s.id)
|
|
// preallocate buffer
|
|
buffer := make([]byte, len(frames)*headerSize+len(b))
|
|
bts := buffer
|
|
|
|
// combine frames into a large blob
|
|
for k := range frames {
|
|
bts[0] = version
|
|
bts[1] = frames[k].cmd
|
|
binary.LittleEndian.PutUint16(bts[2:], uint16(len(frames[k].data)))
|
|
binary.LittleEndian.PutUint32(bts[4:], frames[k].sid)
|
|
copy(bts[headerSize:], frames[k].data)
|
|
bts = bts[len(frames[k].data)+headerSize:]
|
|
}
|
|
|
|
if _, err = s.sess.writeBinary(buffer); err != nil {
|
|
return 0, err
|
|
}
|
|
return len(b), nil
|
|
}
|
|
|
|
// Close implements io.ReadWriteCloser
|
|
func (s *Stream) Close() error {
|
|
s.dieLock.Lock()
|
|
defer s.dieLock.Unlock()
|
|
|
|
select {
|
|
case <-s.die:
|
|
return errors.New(errBrokenPipe)
|
|
default:
|
|
close(s.die)
|
|
s.sess.streamClosed(s.id)
|
|
_, err := s.sess.writeFrame(newFrame(cmdRST, s.id))
|
|
return err
|
|
}
|
|
}
|
|
|
|
// session closes the stream
|
|
func (s *Stream) sessionClose() {
|
|
s.dieLock.Lock()
|
|
defer s.dieLock.Unlock()
|
|
|
|
select {
|
|
case <-s.die:
|
|
default:
|
|
close(s.die)
|
|
}
|
|
}
|
|
|
|
// pushBytes a slice into buffer
|
|
func (s *Stream) pushBytes(p []byte) {
|
|
s.bufferLock.Lock()
|
|
s.buffer.Write(p)
|
|
s.bufferLock.Unlock()
|
|
}
|
|
|
|
// recycleTokens transform remaining bytes to tokens(will truncate buffer)
|
|
func (s *Stream) recycleTokens() (n int) {
|
|
s.bufferLock.Lock()
|
|
n = s.buffer.Len()
|
|
s.buffer.Reset()
|
|
s.bufferLock.Unlock()
|
|
return
|
|
}
|
|
|
|
// split large byte buffer into smaller frames, reference only
|
|
func (s *Stream) split(bts []byte, cmd byte, sid uint32) []Frame {
|
|
var frames []Frame
|
|
for len(bts) > s.frameSize {
|
|
frame := newFrame(cmd, sid)
|
|
frame.data = bts[:s.frameSize]
|
|
bts = bts[s.frameSize:]
|
|
frames = append(frames, frame)
|
|
}
|
|
if len(bts) > 0 {
|
|
frame := newFrame(cmd, sid)
|
|
frame.data = bts
|
|
frames = append(frames, frame)
|
|
}
|
|
return frames
|
|
}
|
|
|
|
// notify read event
|
|
func (s *Stream) notifyReadEvent() {
|
|
select {
|
|
case s.chReadEvent <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
|
|
// mark this stream has been reset
|
|
func (s *Stream) markRST() {
|
|
atomic.StoreInt32(&s.rstflag, 1)
|
|
}
|