352 lines
6.4 KiB
Go
352 lines
6.4 KiB
Go
package stun
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"math/rand"
|
|
"net"
|
|
"sort"
|
|
"strconv"
|
|
)
|
|
|
|
// Message kinds. Message.Kind masks a message type with 0x110, which yields
// one of these four values; MethodName uses them to pick the name suffix.
const (
	// KindRequest has neither kind bit set.
	KindRequest uint16 = 0x0000
	// KindIndication has only the low kind bit (0x010) set.
	KindIndication uint16 = 0x0010
	// KindResponse has only the high kind bit (0x100) set.
	KindResponse uint16 = 0x0100
	// KindError has both kind bits set.
	KindError uint16 = 0x0110
)
|
|
|
|
// Message represents a STUN message.
|
|
type Message struct {
|
|
Type uint16
|
|
Transaction []byte
|
|
Attributes []Attr
|
|
}
|
|
|
|
func (m *Message) Marshal(p []byte) []byte {
|
|
pos := len(p)
|
|
r, b := grow(p, 20)
|
|
be.PutUint16(b, m.Type)
|
|
|
|
if m.Transaction != nil {
|
|
copy(b[4:], m.Transaction)
|
|
} else {
|
|
copy(b[4:], magicCookie)
|
|
rand.Read(b[8:20])
|
|
}
|
|
|
|
sort.Sort(byPosition(m.Attributes))
|
|
for _, attr := range m.Attributes {
|
|
r = m.marshalAttr(r, attr, pos)
|
|
}
|
|
|
|
be.PutUint16(r[pos+2:], uint16(len(r)-pos-20))
|
|
return r
|
|
}
|
|
|
|
func (m *Message) marshalAttr(p []byte, attr Attr, pos int) []byte {
|
|
h := len(p)
|
|
r, b := grow(p, 4)
|
|
be.PutUint16(b, attr.Type())
|
|
|
|
switch v := attr.(type) {
|
|
case *addr:
|
|
r = v.MarshalAddr(r, r[pos+4:])
|
|
case *integrity:
|
|
r = v.MarshalSum(r, r[pos:])
|
|
case *fingerprint:
|
|
r = v.MarshalSum(r, r[pos:])
|
|
default:
|
|
r = v.Marshal(r)
|
|
}
|
|
n := len(r) - h - 4
|
|
be.PutUint16(r[h+2:], uint16(n))
|
|
|
|
if pad := n & 3; pad != 0 {
|
|
r, b = grow(r, 4-pad)
|
|
for i := range b {
|
|
b[i] = 0
|
|
}
|
|
}
|
|
return r
|
|
}
|
|
|
|
func (m *Message) Unmarshal(b []byte) (n int, err error) {
|
|
if len(b) < 20 {
|
|
err = io.EOF
|
|
return
|
|
}
|
|
l := int(be.Uint16(b[2:])) + 20
|
|
if len(b) < l {
|
|
err = io.EOF
|
|
return
|
|
}
|
|
pos, p := 20, make([]byte, l)
|
|
copy(p, b[:l])
|
|
|
|
m.Type = be.Uint16(p)
|
|
m.Transaction = p[4:20]
|
|
|
|
for pos < len(p) {
|
|
s, attr, err := m.unmarshalAttr(p, pos)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
pos += s
|
|
if attr != nil {
|
|
m.Attributes = append(m.Attributes, attr)
|
|
}
|
|
}
|
|
|
|
return l, nil
|
|
}
|
|
|
|
func (m *Message) unmarshalAttr(p []byte, pos int) (n int, attr Attr, err error) {
|
|
b := p[pos:]
|
|
if len(b) < 4 {
|
|
err = errFormat
|
|
return
|
|
}
|
|
typ := be.Uint16(b)
|
|
attr, n = newAttr(typ), int(be.Uint16(b[2:]))+4
|
|
if len(b) < n {
|
|
err = errFormat
|
|
return
|
|
}
|
|
|
|
b = b[4:n]
|
|
if attr != nil {
|
|
switch v := attr.(type) {
|
|
case *addr:
|
|
err = v.UnmarshalAddr(b, m.Transaction)
|
|
case *integrity:
|
|
err = v.UnmarshalSum(b, p[:pos+n])
|
|
case *fingerprint:
|
|
err = v.UnmarshalSum(b, p[:pos+n])
|
|
default:
|
|
err = attr.Unmarshal(b)
|
|
}
|
|
} else if typ < 0x8000 {
|
|
err = errFormat
|
|
}
|
|
if err != nil {
|
|
err = &errAttribute{err, typ}
|
|
return
|
|
}
|
|
if pad := n & 3; pad != 0 {
|
|
n += 4 - pad
|
|
if len(p) < pos+n {
|
|
err = errFormat
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (m *Message) Kind() uint16 {
|
|
return m.Type & 0x110
|
|
}
|
|
|
|
func (m *Message) Method() uint16 {
|
|
return m.Type &^ 0x110
|
|
}
|
|
|
|
func (m *Message) Add(attr Attr) {
|
|
m.Attributes = append(m.Attributes, attr)
|
|
}
|
|
|
|
func (m *Message) Set(attr Attr) {
|
|
m.Del(attr.Type())
|
|
m.Add(attr)
|
|
}
|
|
|
|
func (m *Message) Del(typ uint16) {
|
|
n := 0
|
|
for _, a := range m.Attributes {
|
|
if a.Type() != typ {
|
|
m.Attributes[n] = a
|
|
n++
|
|
}
|
|
}
|
|
m.Attributes = m.Attributes[:n]
|
|
}
|
|
|
|
func (m *Message) Get(typ uint16) (attr Attr) {
|
|
for _, attr = range m.Attributes {
|
|
if attr.Type() == typ {
|
|
return
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Message) Has(typ uint16) bool {
|
|
for _, attr := range m.Attributes {
|
|
if attr.Type() == typ {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *Message) GetString(typ uint16) string {
|
|
if str, ok := m.Get(typ).(fmt.Stringer); ok {
|
|
return str.String()
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (m *Message) GetAddr(network string, typ ...uint16) net.Addr {
|
|
for _, t := range typ {
|
|
if addr, ok := m.Get(t).(*addr); ok {
|
|
return addr.Addr(network)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Message) GetInt(typ uint16) (v uint64, ok bool) {
|
|
attr := m.Get(typ)
|
|
if r, ok := attr.(*number); ok {
|
|
return r.v, true
|
|
}
|
|
return
|
|
}
|
|
|
|
func (m *Message) GetBytes(typ uint16) []byte {
|
|
if attr, ok := m.Get(typ).(*raw); ok {
|
|
return attr.data
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Message) GetError() *Error {
|
|
if err, ok := m.Get(AttrErrorCode).(*Error); ok {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Message) CheckIntegrity(key []byte) bool {
|
|
if attr, ok := m.Get(AttrMessageIntegrity).(*integrity); ok {
|
|
return attr.Check(key)
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *Message) CheckFingerprint() bool {
|
|
if attr, ok := m.Get(AttrFingerprint).(*fingerprint); ok {
|
|
return attr.Check()
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *Message) String() string {
|
|
sort.Sort(byPosition(m.Attributes))
|
|
|
|
// TODO: use sprintf
|
|
|
|
b := &bytes.Buffer{}
|
|
b.WriteString(MethodName(m.Type))
|
|
b.WriteByte('{')
|
|
tx := m.Transaction
|
|
if tx == nil {
|
|
b.WriteString("nil")
|
|
} else if bytes.Equal(magicCookie, tx[:4]) {
|
|
b.WriteString(hex.EncodeToString(tx[4:]))
|
|
} else {
|
|
b.WriteString(hex.EncodeToString(tx))
|
|
}
|
|
for _, attr := range m.Attributes {
|
|
b.WriteString(", ")
|
|
b.WriteString(AttrName(attr.Type()))
|
|
switch v := attr.(type) {
|
|
case *raw:
|
|
b.WriteString(": \"")
|
|
b.Write(v.data)
|
|
b.WriteByte('"')
|
|
case *str:
|
|
b.WriteString(": \"")
|
|
b.WriteString(v.data)
|
|
b.WriteByte('"')
|
|
case flag, *integrity, *fingerprint:
|
|
default:
|
|
b.WriteString(fmt.Sprintf(": %v", attr))
|
|
}
|
|
}
|
|
b.WriteByte('}')
|
|
return b.String()
|
|
}
|
|
|
|
func MethodName(typ uint16) string {
|
|
if r, ok := methodNames[typ&^0x110]; ok {
|
|
switch typ & 0x110 {
|
|
case KindRequest:
|
|
return r + "Request"
|
|
case KindIndication:
|
|
return r + "Indication"
|
|
case KindResponse:
|
|
return r + "Response"
|
|
case KindError:
|
|
return r + "Error"
|
|
}
|
|
}
|
|
return "0x" + strconv.FormatUint(uint64(typ), 16)
|
|
}
|
|
|
|
func UnmarshalMessage(b []byte) (*Message, error) {
|
|
m := &Message{}
|
|
if _, err := m.Unmarshal(b); err != nil {
|
|
return nil, err
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
// magicCookie is the fixed 4-byte prefix of generated transaction IDs.
var magicCookie = []byte{0x21, 0x12, 0xa4, 0x42}

// alphanum is the dictionary of the 62 distinct ASCII digits and letters.
// Every character must appear exactly once: dict.rand picks entries
// uniformly, so a duplicated entry would be drawn twice as often.
var alphanum = dict("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

// dict is a set of bytes from which random strings are drawn.
type dict []byte

// rand returns a string of n bytes, each picked uniformly from d using
// math/rand. d must not be empty when n > 0 (rand.Intn panics for 0).
func (d dict) rand(n int) string {
	m, b := len(d), make([]byte, n)
	for i := range b {
		b[i] = d[rand.Intn(m)]
	}
	return string(b)
}
|
|
|
|
func NewTransaction() []byte {
|
|
id := make([]byte, 16)
|
|
copy(id, magicCookie)
|
|
rand.Read(id[4:]) // TODO: configure random source
|
|
return id
|
|
}
|
|
|
|
type byPosition []Attr
|
|
|
|
func (s byPosition) Len() int { return len(s) }
|
|
func (s byPosition) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
func (s byPosition) Less(i, j int) bool {
|
|
a, b := s[i].Type(), s[j].Type()
|
|
switch b {
|
|
case a:
|
|
return i < j
|
|
case AttrMessageIntegrity:
|
|
return a != AttrFingerprint
|
|
case AttrFingerprint:
|
|
return true
|
|
default:
|
|
return i < j
|
|
}
|
|
}
|
|
|
|
type errAttribute struct {
|
|
error
|
|
typ uint16
|
|
}
|
|
|
|
func (err errAttribute) Error() string {
|
|
return "attribute " + AttrName(err.typ) + ": " + err.error.Error()
|
|
}
|