gost_software/vendor/github.com/isofew/go-stun/stun/agent.go
2018-05-17 13:14:45 +01:00

285 lines
5.1 KiB
Go

package stun
import (
"errors"
"math/rand"
"net"
"sync"
"time"
)
var DefaultConfig = &Config{
RetransmissionTimeout: 500 * time.Millisecond,
TransactionTimeout: 39500 * time.Millisecond,
Software: "pixelbender/go-stun",
}
type Handler interface {
ServeSTUN(msg *Message, tr Transport)
}
type HandlerFunc func(msg *Message, tr Transport)
func (h HandlerFunc) ServeSTUN(msg *Message, tr Transport) {
h(msg, tr)
}
type Config struct {
// AuthMethod returns a key for MESSAGE-INTEGRITY attribute
AuthMethod AuthMethod
// Retransmission timeout, default is 500 milliseconds
RetransmissionTimeout time.Duration
// Transaction timeout, default is 39.5 seconds
TransactionTimeout time.Duration
// Fingerprint, if true all outgoing messages contain FINGERPRINT attribute
Fingerprint bool
// Software is a SOFTWARE attribute value for outgoing messages, if not empty
Software string
// Logf, if set all sent and received messages printed using Logf
Logf func(format string, args ...interface{})
}
func (c *Config) attrs() []Attr {
if c == nil {
return nil
}
var a []Attr
if c.Software != "" {
a = append(a, String(AttrSoftware, c.Software))
}
if c.Fingerprint {
a = append(a, Fingerprint)
}
return a
}
func (c *Config) Clone() *Config {
r := *c
return &r
}
type Agent struct {
config *Config
Handler Handler
m mux
}
func NewAgent(config *Config) *Agent {
if config == nil {
config = DefaultConfig
}
return &Agent{
config: config,
}
}
func (a *Agent) Send(msg *Message, tr Transport) (err error) {
msg = &Message{
msg.Type,
msg.Transaction,
append(a.config.attrs(), msg.Attributes...),
}
if log := a.config.Logf; log != nil {
log("%v → %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg)
}
b := msg.Marshal(getBuffer()[:0])
_, err = tr.Write(b)
putBuffer(b)
return
}
func (a *Agent) ServeConn(c net.Conn, stop chan struct{}) error {
if c, ok := c.(net.PacketConn); ok {
return a.ServePacket(c, stop)
}
var (
b = getBuffer()
p int
)
defer putBuffer(b)
for {
select {
case <-stop:
return nil
default:
}
if p >= len(b) {
return errBufferOverflow
}
n, err := c.Read(b[p:])
if err != nil {
return err
}
p += n
n = 0
for n < p {
r, err := a.ServeTransport(b[n:p], c)
if err != nil {
return err
}
n += r
}
if n > 0 {
if n < p {
p = copy(b, b[n:p])
} else {
p = 0
}
}
}
}
func (a *Agent) ServePacket(c net.PacketConn, stop chan struct{}) error {
b := getBuffer()
defer putBuffer(b)
// don't close the connection since we're going to reuse it
// defer c.Close()
for {
select {
case <-stop:
return nil
default:
}
n, addr, err := c.ReadFrom(b)
if err != nil {
return err
}
if n > 0 {
a.ServeTransport(b[:n], &packetConn{c, addr})
}
}
}
func (a *Agent) ServeTransport(b []byte, tr Transport) (n int, err error) {
msg := &Message{}
n, err = msg.Unmarshal(b)
if err != nil {
return
}
a.ServeSTUN(msg, tr)
return
}
func (a *Agent) ServeSTUN(msg *Message, tr Transport) {
if log := a.config.Logf; log != nil {
log("%v ← %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg)
}
if a.m.serve(msg, tr) {
return
}
if h := a.Handler; h != nil {
go h.ServeSTUN(msg, tr)
}
}
func (a *Agent) RoundTrip(req *Message, to Transport) (res *Message, from Transport, err error) {
var (
start = time.Now()
rto = a.config.RetransmissionTimeout
udp = to.LocalAddr().Network() == "udp"
tx = a.m.newTx()
)
defer a.m.closeTx(tx)
req = &Message{req.Type, tx.id, req.Attributes}
if err = a.Send(req, to); err != nil {
return
}
for {
d := a.config.TransactionTimeout - time.Since(start)
if d < 0 {
err = errTimeout
return
}
if udp && d > rto {
d = rto
}
res, from, err = tx.Receive(d)
if udp && err == errTimeout && d == rto {
rto <<= 1
a.Send(req, to)
continue
}
return
}
}
type mux struct {
sync.RWMutex
t map[string]*transaction
}
func (m *mux) serve(msg *Message, tr Transport) bool {
m.RLock()
tx, ok := m.t[string(msg.Transaction)]
m.RUnlock()
if ok {
tx.msg, tx.from = msg, tr
tx.Done()
return true
}
return false
}
func (m *mux) newTx() *transaction {
tx := &transaction{id: NewTransaction()}
m.Lock()
if m.t == nil {
m.t = make(map[string]*transaction)
} else {
for m.t[string(tx.id)] != nil {
rand.Read(tx.id[4:])
}
}
m.t[string(tx.id)] = tx
m.Unlock()
return tx
}
func (m *mux) closeTx(tx *transaction) {
m.Lock()
delete(m.t, string(tx.id))
m.Unlock()
}
func (m *mux) Close() {
m.Lock()
defer m.Unlock()
for _, it := range m.t {
it.Close()
}
m.t = nil
}
type transaction struct {
sync.WaitGroup
id []byte
from Transport
msg *Message
err error
}
func (tx *transaction) Receive(d time.Duration) (msg *Message, from Transport, err error) {
tx.Add(1)
t := time.AfterFunc(d, tx.timeout)
tx.Wait()
t.Stop()
if err = tx.err; err != nil {
return
}
return tx.msg, tx.from, nil
}
func (tx *transaction) timeout() {
tx.err = errTimeout
tx.Done()
}
func (tx *transaction) Close() {
tx.err = errCanceled
tx.Done()
}
var errCanceled = errors.New("stun: transaction canceled")
var errTimeout = errors.New("stun: transaction timeout")