136 lines
2.9 KiB
Go
136 lines
2.9 KiB
Go
package shadowsocks
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"net"
|
|
"time"
|
|
)
|
|
|
|
const (
	// maxPacketSize is the largest encrypted UDP datagram (IV + ciphertext)
	// this connection type is sized to handle. Increase it if legitimate
	// traffic exceeds this limit.
	maxPacketSize = 4096
)

// Sentinel errors reported when reading from a SecurePacketConn.
var (
	// errPacketTooSmall means the datagram is too short to hold the IV.
	errPacketTooSmall = fmt.Errorf("[udp]read error: cannot decrypt, received packet is smaller than ivLen")

	// errPacketTooLarge means the datagram exceeds maxPacketSize.
	errPacketTooLarge = fmt.Errorf("[udp]read error: received packet is larger than maxPacketSize(%d)", maxPacketSize)

	// errBufferTooSmall means the caller's buffer cannot hold the decrypted data.
	errBufferTooSmall = fmt.Errorf("[udp]read error: given buffer is too small to hold data")

	// errPacketOtaFailed means one-time auth was required but missing, or its
	// HMAC did not verify.
	errPacketOtaFailed = fmt.Errorf("[udp]read error: received packet has invalid ota")
)
|
|
|
|
type SecurePacketConn struct {
|
|
net.PacketConn
|
|
*Cipher
|
|
ota bool
|
|
}
|
|
|
|
func NewSecurePacketConn(c net.PacketConn, cipher *Cipher, ota bool) *SecurePacketConn {
|
|
return &SecurePacketConn{
|
|
PacketConn: c,
|
|
Cipher: cipher,
|
|
ota: ota,
|
|
}
|
|
}
|
|
|
|
func (c *SecurePacketConn) Close() error {
|
|
return c.PacketConn.Close()
|
|
}
|
|
|
|
func (c *SecurePacketConn) ReadFrom(b []byte) (n int, src net.Addr, err error) {
|
|
ota := false
|
|
cipher := c.Copy()
|
|
buf := make([]byte, 4096)
|
|
n, src, err = c.PacketConn.ReadFrom(buf)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if n < c.info.ivLen {
|
|
return 0, nil, errPacketTooSmall
|
|
}
|
|
|
|
if len(b) < n-c.info.ivLen {
|
|
err = errBufferTooSmall // just a warning
|
|
}
|
|
|
|
iv := make([]byte, c.info.ivLen)
|
|
copy(iv, buf[:c.info.ivLen])
|
|
|
|
if err = cipher.initDecrypt(iv); err != nil {
|
|
return
|
|
}
|
|
|
|
cipher.decrypt(b[0:], buf[c.info.ivLen:n])
|
|
n -= c.info.ivLen
|
|
if b[idType]&OneTimeAuthMask > 0 {
|
|
ota = true
|
|
}
|
|
|
|
if c.ota && !ota {
|
|
return 0, src, errPacketOtaFailed
|
|
}
|
|
|
|
if ota {
|
|
key := cipher.key
|
|
actualHmacSha1Buf := HmacSha1(append(iv, key...), b[:n-lenHmacSha1])
|
|
if !bytes.Equal(b[n-lenHmacSha1:n], actualHmacSha1Buf) {
|
|
Debug.Printf("verify one time auth failed, iv=%v key=%v data=%v", iv, key, b)
|
|
return 0, src, errPacketOtaFailed
|
|
}
|
|
n -= lenHmacSha1
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (c *SecurePacketConn) WriteTo(b []byte, dst net.Addr) (n int, err error) {
|
|
cipher := c.Copy()
|
|
iv, err := cipher.initEncrypt()
|
|
if err != nil {
|
|
return
|
|
}
|
|
packetLen := len(b) + len(iv)
|
|
|
|
if c.ota {
|
|
b[idType] |= OneTimeAuthMask
|
|
packetLen += lenHmacSha1
|
|
key := cipher.key
|
|
actualHmacSha1Buf := HmacSha1(append(iv, key...), b)
|
|
b = append(b, actualHmacSha1Buf...)
|
|
}
|
|
|
|
cipherData := make([]byte, packetLen)
|
|
copy(cipherData, iv)
|
|
|
|
cipher.encrypt(cipherData[len(iv):], b)
|
|
n, err = c.PacketConn.WriteTo(cipherData, dst)
|
|
if c.ota {
|
|
n -= lenHmacSha1
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c *SecurePacketConn) LocalAddr() net.Addr {
|
|
return c.PacketConn.LocalAddr()
|
|
}
|
|
|
|
func (c *SecurePacketConn) SetDeadline(t time.Time) error {
|
|
return c.PacketConn.SetDeadline(t)
|
|
}
|
|
|
|
func (c *SecurePacketConn) SetReadDeadline(t time.Time) error {
|
|
return c.PacketConn.SetReadDeadline(t)
|
|
}
|
|
|
|
func (c *SecurePacketConn) SetWriteDeadline(t time.Time) error {
|
|
return c.PacketConn.SetWriteDeadline(t)
|
|
}
|
|
|
|
func (c *SecurePacketConn) IsOta() bool {
|
|
return c.ota
|
|
}
|
|
|
|
func (c *SecurePacketConn) ForceOTA() net.PacketConn {
|
|
return NewSecurePacketConn(c.PacketConn, c.Cipher.Copy(), true)
|
|
}
|