285 lines
5.1 KiB
Go
285 lines
5.1 KiB
Go
package stun
|
|
|
|
import (
|
|
"errors"
|
|
"math/rand"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
var DefaultConfig = &Config{
|
|
RetransmissionTimeout: 500 * time.Millisecond,
|
|
TransactionTimeout: 39500 * time.Millisecond,
|
|
Software: "pixelbender/go-stun",
|
|
}
|
|
|
|
type Handler interface {
|
|
ServeSTUN(msg *Message, tr Transport)
|
|
}
|
|
|
|
type HandlerFunc func(msg *Message, tr Transport)
|
|
|
|
func (h HandlerFunc) ServeSTUN(msg *Message, tr Transport) {
|
|
h(msg, tr)
|
|
}
|
|
|
|
type Config struct {
|
|
// AuthMethod returns a key for MESSAGE-INTEGRITY attribute
|
|
AuthMethod AuthMethod
|
|
// Retransmission timeout, default is 500 milliseconds
|
|
RetransmissionTimeout time.Duration
|
|
// Transaction timeout, default is 39.5 seconds
|
|
TransactionTimeout time.Duration
|
|
// Fingerprint, if true all outgoing messages contain FINGERPRINT attribute
|
|
Fingerprint bool
|
|
// Software is a SOFTWARE attribute value for outgoing messages, if not empty
|
|
Software string
|
|
// Logf, if set all sent and received messages printed using Logf
|
|
Logf func(format string, args ...interface{})
|
|
}
|
|
|
|
func (c *Config) attrs() []Attr {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
var a []Attr
|
|
if c.Software != "" {
|
|
a = append(a, String(AttrSoftware, c.Software))
|
|
}
|
|
if c.Fingerprint {
|
|
a = append(a, Fingerprint)
|
|
}
|
|
return a
|
|
}
|
|
|
|
func (c *Config) Clone() *Config {
|
|
r := *c
|
|
return &r
|
|
}
|
|
|
|
type Agent struct {
|
|
config *Config
|
|
Handler Handler
|
|
m mux
|
|
}
|
|
|
|
func NewAgent(config *Config) *Agent {
|
|
if config == nil {
|
|
config = DefaultConfig
|
|
}
|
|
return &Agent{
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
func (a *Agent) Send(msg *Message, tr Transport) (err error) {
|
|
msg = &Message{
|
|
msg.Type,
|
|
msg.Transaction,
|
|
append(a.config.attrs(), msg.Attributes...),
|
|
}
|
|
if log := a.config.Logf; log != nil {
|
|
log("%v → %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg)
|
|
}
|
|
b := msg.Marshal(getBuffer()[:0])
|
|
_, err = tr.Write(b)
|
|
putBuffer(b)
|
|
return
|
|
}
|
|
|
|
func (a *Agent) ServeConn(c net.Conn, stop chan struct{}) error {
|
|
if c, ok := c.(net.PacketConn); ok {
|
|
return a.ServePacket(c, stop)
|
|
}
|
|
var (
|
|
b = getBuffer()
|
|
p int
|
|
)
|
|
defer putBuffer(b)
|
|
for {
|
|
select {
|
|
case <-stop:
|
|
return nil
|
|
default:
|
|
}
|
|
if p >= len(b) {
|
|
return errBufferOverflow
|
|
}
|
|
n, err := c.Read(b[p:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p += n
|
|
n = 0
|
|
for n < p {
|
|
r, err := a.ServeTransport(b[n:p], c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
n += r
|
|
}
|
|
if n > 0 {
|
|
if n < p {
|
|
p = copy(b, b[n:p])
|
|
} else {
|
|
p = 0
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *Agent) ServePacket(c net.PacketConn, stop chan struct{}) error {
|
|
b := getBuffer()
|
|
defer putBuffer(b)
|
|
// don't close the connection since we're going to reuse it
|
|
// defer c.Close()
|
|
|
|
for {
|
|
select {
|
|
case <-stop:
|
|
return nil
|
|
default:
|
|
}
|
|
n, addr, err := c.ReadFrom(b)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if n > 0 {
|
|
a.ServeTransport(b[:n], &packetConn{c, addr})
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *Agent) ServeTransport(b []byte, tr Transport) (n int, err error) {
|
|
msg := &Message{}
|
|
n, err = msg.Unmarshal(b)
|
|
if err != nil {
|
|
return
|
|
}
|
|
a.ServeSTUN(msg, tr)
|
|
return
|
|
}
|
|
|
|
func (a *Agent) ServeSTUN(msg *Message, tr Transport) {
|
|
if log := a.config.Logf; log != nil {
|
|
log("%v ← %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg)
|
|
}
|
|
if a.m.serve(msg, tr) {
|
|
return
|
|
}
|
|
if h := a.Handler; h != nil {
|
|
go h.ServeSTUN(msg, tr)
|
|
}
|
|
}
|
|
|
|
func (a *Agent) RoundTrip(req *Message, to Transport) (res *Message, from Transport, err error) {
|
|
var (
|
|
start = time.Now()
|
|
rto = a.config.RetransmissionTimeout
|
|
udp = to.LocalAddr().Network() == "udp"
|
|
tx = a.m.newTx()
|
|
)
|
|
defer a.m.closeTx(tx)
|
|
req = &Message{req.Type, tx.id, req.Attributes}
|
|
if err = a.Send(req, to); err != nil {
|
|
return
|
|
}
|
|
for {
|
|
d := a.config.TransactionTimeout - time.Since(start)
|
|
if d < 0 {
|
|
err = errTimeout
|
|
return
|
|
}
|
|
if udp && d > rto {
|
|
d = rto
|
|
}
|
|
res, from, err = tx.Receive(d)
|
|
if udp && err == errTimeout && d == rto {
|
|
rto <<= 1
|
|
a.Send(req, to)
|
|
continue
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
type mux struct {
|
|
sync.RWMutex
|
|
t map[string]*transaction
|
|
}
|
|
|
|
func (m *mux) serve(msg *Message, tr Transport) bool {
|
|
m.RLock()
|
|
tx, ok := m.t[string(msg.Transaction)]
|
|
m.RUnlock()
|
|
if ok {
|
|
tx.msg, tx.from = msg, tr
|
|
tx.Done()
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *mux) newTx() *transaction {
|
|
tx := &transaction{id: NewTransaction()}
|
|
m.Lock()
|
|
if m.t == nil {
|
|
m.t = make(map[string]*transaction)
|
|
} else {
|
|
for m.t[string(tx.id)] != nil {
|
|
rand.Read(tx.id[4:])
|
|
}
|
|
}
|
|
m.t[string(tx.id)] = tx
|
|
m.Unlock()
|
|
return tx
|
|
}
|
|
|
|
func (m *mux) closeTx(tx *transaction) {
|
|
m.Lock()
|
|
delete(m.t, string(tx.id))
|
|
m.Unlock()
|
|
}
|
|
|
|
func (m *mux) Close() {
|
|
m.Lock()
|
|
defer m.Unlock()
|
|
for _, it := range m.t {
|
|
it.Close()
|
|
}
|
|
m.t = nil
|
|
}
|
|
|
|
type transaction struct {
|
|
sync.WaitGroup
|
|
id []byte
|
|
from Transport
|
|
msg *Message
|
|
err error
|
|
}
|
|
|
|
func (tx *transaction) Receive(d time.Duration) (msg *Message, from Transport, err error) {
|
|
tx.Add(1)
|
|
t := time.AfterFunc(d, tx.timeout)
|
|
tx.Wait()
|
|
t.Stop()
|
|
if err = tx.err; err != nil {
|
|
return
|
|
}
|
|
return tx.msg, tx.from, nil
|
|
}
|
|
|
|
func (tx *transaction) timeout() {
|
|
tx.err = errTimeout
|
|
tx.Done()
|
|
}
|
|
|
|
func (tx *transaction) Close() {
|
|
tx.err = errCanceled
|
|
tx.Done()
|
|
}
|
|
|
|
// Errors reported by client transactions.
var (
	// errCanceled is returned when a pending transaction is closed.
	errCanceled = errors.New("stun: transaction canceled")
	// errTimeout is returned when no response arrives within the deadline.
	errTimeout = errors.New("stun: transaction timeout")
)
|