add chain.DialContext

This commit is contained in:
ginuerzh 2020-02-08 15:02:04 +08:00
parent 425099a7ba
commit abe4043413
17 changed files with 426 additions and 387 deletions

View File

@ -1,6 +1,7 @@
package gost
import (
"context"
"errors"
"net"
"time"
@ -100,9 +101,14 @@ func (c *Chain) IsEmpty() bool {
return c == nil || len(c.nodeGroups) == 0
}
// Dial connects to the target address addr through the chain.
// If the chain is empty, it will use the net.Dial directly.
func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error) {
// Dial connects to the target TCP address addr through the chain.
// Deprecated: use DialContext instead.
func (c *Chain) Dial(address string, opts ...ChainOption) (conn net.Conn, err error) {
return c.DialContext(context.Background(), "tcp", address, opts...)
}
// DialContext connects to the address on the named network using the provided context.
func (c *Chain) DialContext(ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) {
options := &ChainOptions{}
for _, opt := range opts {
opt(options)
@ -117,7 +123,7 @@ func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error
}
for i := 0; i < retries; i++ {
conn, err = c.dialWithOptions(addr, options)
conn, err = c.dialWithOptions(ctx, network, address, options)
if err == nil {
break
}
@ -125,16 +131,19 @@ func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error
return
}
func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, error) {
func (c *Chain) dialWithOptions(ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) {
if options == nil {
options = &ChainOptions{}
}
route, err := c.selectRouteFor(addr)
route, err := c.selectRouteFor(address)
if err != nil {
return nil, err
}
ipAddr := c.resolve(addr, options.Resolver, options.Hosts)
ipAddr := address
if address != "" {
ipAddr = c.resolve(address, options.Resolver, options.Hosts)
}
timeout := options.Timeout
if timeout <= 0 {
@ -142,16 +151,27 @@ func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, e
}
if route.IsEmpty() {
return net.DialTimeout("tcp", ipAddr, timeout)
switch network {
case "udp", "udp4", "udp6":
if address == "" {
return net.ListenUDP(network, nil)
}
default:
}
d := &net.Dialer{
Timeout: timeout,
// LocalAddr: laddr, // TODO: optional local address
}
return d.DialContext(ctx, network, ipAddr)
}
conn, err := route.getConn()
conn, err := route.getConn(ctx)
if err != nil {
return nil, err
}
cOpts := append([]ConnectOption{AddrConnectOption(addr)}, route.LastNode().ConnectOptions...)
cc, err := route.LastNode().Client.Connect(conn, ipAddr, cOpts...)
cOpts := append([]ConnectOption{AddrConnectOption(address)}, route.LastNode().ConnectOptions...)
cc, err := route.LastNode().Client.ConnectContext(ctx, conn, network, ipAddr, cOpts...)
if err != nil {
conn.Close()
return nil, err
@ -187,6 +207,8 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
opt(options)
}
ctx := context.Background()
retries := 1
if c != nil && c.Retries > 0 {
retries = c.Retries
@ -201,7 +223,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
if err != nil {
continue
}
conn, err = route.getConn()
conn, err = route.getConn(ctx)
if err == nil {
break
}
@ -210,7 +232,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
}
// getConn obtains a connection to the last node of the chain.
func (c *Chain) getConn() (conn net.Conn, err error) {
func (c *Chain) getConn(ctx context.Context) (conn net.Conn, err error) {
if c.IsEmpty() {
err = ErrEmptyChain
return
@ -234,7 +256,7 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
preNode := node
for _, node := range nodes[1:] {
var cc net.Conn
cc, err = preNode.Client.Connect(cn, node.Addr, preNode.ConnectOptions...)
cc, err = preNode.Client.ConnectContext(ctx, cn, "tcp", node.Addr, preNode.ConnectOptions...)
if err != nil {
cn.Close()
node.MarkDead()

View File

@ -1,6 +1,7 @@
package gost
import (
"context"
"crypto/tls"
"net"
"net/url"
@ -14,23 +15,8 @@ import (
// Connector is responsible for connecting to the destination address through this proxy.
// Transporter performs a handshake with this proxy.
type Client struct {
Connector Connector
Transporter Transporter
}
// Dial connects to the target address.
func (c *Client) Dial(addr string, options ...DialOption) (net.Conn, error) {
return c.Transporter.Dial(addr, options...)
}
// Handshake performs a handshake with the proxy over connection conn.
func (c *Client) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
return c.Transporter.Handshake(conn, options...)
}
// Connect connects to the address addr via the proxy over connection conn.
func (c *Client) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
return c.Connector.Connect(conn, addr, options...)
Connector
Transporter
}
// DefaultClient is a standard HTTP proxy client.
@ -53,7 +39,36 @@ func Connect(conn net.Conn, addr string) (net.Conn, error) {
// Connector is responsible for connecting to the destination address.
type Connector interface {
Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error)
// Deprecated: use ConnectContext instead.
Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error)
ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error)
}
type autoConnector struct {
User *url.Userinfo
}
// AutoConnector is a Connector.
func AutoConnector(user *url.Userinfo) Connector {
return &autoConnector{
User: user,
}
}
func (c *autoConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *autoConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
var cnr Connector
switch network {
case "tcp", "tcp4", "tcp6":
cnr = &httpConnector{User: c.User}
default:
cnr = &socks5UDPTunConnector{User: c.User}
}
return cnr.ConnectContext(ctx, conn, network, address, options...)
}
// Transporter is responsible for handshaking with the proxy server.

View File

@ -227,10 +227,9 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
case "sni":
connector = gost.SNIConnector(node.Get("host"))
case "http":
fallthrough
default:
node.Protocol = "http" // default protocol is HTTP
connector = gost.HTTPConnector(node.User)
default:
connector = gost.AutoConnector(node.User)
}
timeout := node.GetInt("timeout")

View File

@ -1,6 +1,7 @@
package gost
import (
"context"
"errors"
"net"
"strings"
@ -22,7 +23,11 @@ func ForwardConnector() Connector {
return &forwardConnector{}
}
func (c *forwardConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *forwardConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return conn, nil
}
func (c *forwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
return conn, nil
}
@ -186,42 +191,12 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
return
}
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
cc, err := h.options.Chain.DialContext(context.Background(), "udp", node.Addr)
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
var cc net.Conn
if h.options.Chain.IsEmpty() {
cc, err = net.DialUDP("udp", nil, raddr)
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
} else if h.options.Chain.LastNode().Protocol == "ssu" {
cc, err = h.options.Chain.Dial(node.Addr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
)
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
} else {
var err error
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
if err != nil {
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
cc = &udpTunnelConn{Conn: cc, raddr: raddr}
}
defer cc.Close()
node.ResetDead()
@ -726,11 +701,11 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) {
lastNode := l.chain.LastNode()
if lastNode.Protocol == "socks5" {
var cc net.Conn
cc, err = getSOCKS5UDPTunnel(l.chain, l.addr)
cc, err = getSocks5UDPTunnel(l.chain, l.addr)
if err != nil {
log.Logf("[rudp] %s : %s", l.Addr(), err)
} else {
conn = &udpTunnelConn{Conn: cc}
conn = cc.(net.PacketConn)
}
} else {
var uc *net.UDPConn

View File

@ -20,7 +20,7 @@ import (
)
// Version is the gost version.
const Version = "2.10.0"
const Version = "2.10.1"
// Debug is a flag that enables the debug log.
var Debug bool

16
http.go
View File

@ -3,6 +3,7 @@ package gost
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"fmt"
"net"
@ -27,7 +28,16 @@ func HTTPConnector(user *url.Userinfo) Connector {
return &httpConnector{User: user}
}
func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *httpConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *httpConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
@ -47,8 +57,8 @@ func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOp
req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: addr},
Host: addr,
URL: &url.URL{Host: address},
Host: address,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),

View File

@ -3,6 +3,7 @@ package gost
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
@ -33,7 +34,16 @@ func HTTP2Connector(user *url.Userinfo) Connector {
return &http2Connector{User: user}
}
func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *http2Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *http2Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
@ -57,7 +67,7 @@ func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectO
ProtoMajor: 2,
ProtoMinor: 0,
Body: pr,
Host: addr,
Host: address,
ContentLength: -1,
}
req.Header.Set("User-Agent", ua)
@ -97,7 +107,7 @@ func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectO
closed: make(chan struct{}),
}
hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr)
hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", address)
hc.localAddr, _ = net.ResolveTCPAddr("tcp", cc.addr)
return hc, nil

View File

@ -3,6 +3,7 @@
package gost
import (
"context"
"errors"
"fmt"
"net"
@ -132,32 +133,14 @@ func (h *udpRedirectHandler) Handle(conn net.Conn) {
return
}
var cc net.Conn
var err error
if h.options.Chain.IsEmpty() {
cc, err = net.DialUDP("udp", nil, raddr)
if err != nil {
log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err)
return
}
} else if h.options.Chain.LastNode().Protocol == "ssu" {
cc, err = h.options.Chain.Dial(raddr.String(),
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
)
if err != nil {
log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err)
return
}
} else {
var err error
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
if err != nil {
log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err)
return
}
cc = &udpTunnelConn{Conn: cc, raddr: raddr}
cc, err := h.options.Chain.DialContext(context.Background(),
"udp", raddr.String(),
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
)
if err != nil {
log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err)
return
}
defer cc.Close()

View File

@ -606,31 +606,12 @@ func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger {
}
}
func (ex *dnsExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
if ex.options.chain.IsEmpty() {
d := &net.Dialer{
Timeout: ex.options.timeout,
}
return d.DialContext(ctx, network, address)
}
if ex.options.chain.LastNode().Protocol == "ssu" {
return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout))
}
raddr, err := net.ResolveUDPAddr(network, address)
if err != nil {
return
}
cc, err := getSOCKS5UDPTunnel(ex.options.chain, nil)
conn = &udpTunnelConn{Conn: cc, raddr: raddr}
return
}
func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
t := time.Now()
c, err := ex.dial(ctx, "udp", ex.addr)
c, err := ex.options.chain.DialContext(ctx,
"udp", ex.addr,
TimeoutChainOption(ex.options.timeout),
)
if err != nil {
return nil, err
}
@ -674,19 +655,12 @@ func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger {
}
}
func (ex *dnsTCPExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
if ex.options.chain.IsEmpty() {
d := &net.Dialer{
Timeout: ex.options.timeout,
}
return d.DialContext(ctx, network, address)
}
return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout))
}
func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
t := time.Now()
c, err := ex.dial(ctx, "tcp", ex.addr)
c, err := ex.options.chain.DialContext(ctx,
"tcp", ex.addr,
TimeoutChainOption(ex.options.timeout),
)
if err != nil {
return nil, err
}
@ -738,14 +712,10 @@ func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption
}
func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
if ex.options.chain.IsEmpty() {
d := &net.Dialer{
Timeout: ex.options.timeout,
}
conn, err = d.DialContext(ctx, network, address)
} else {
conn, err = ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout))
}
conn, err = ex.options.chain.DialContext(ctx,
network, address,
TimeoutChainOption(ex.options.timeout),
)
if err != nil {
return
}
@ -812,14 +782,11 @@ func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOp
return ex
}
func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
if ex.options.chain.IsEmpty() {
d := &net.Dialer{
Timeout: ex.options.timeout,
}
return d.DialContext(ctx, network, address)
}
return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout))
func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return ex.options.chain.DialContext(ctx,
network, address,
TimeoutChainOption(ex.options.timeout),
)
}
func (ex *dohExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {

View File

@ -1,6 +1,6 @@
name: gost
type: app
version: '2.10.0'
version: '2.10.1'
title: GO Simple Tunnel
summary: A simple security tunnel written in golang
description: |

14
sni.go
View File

@ -5,6 +5,7 @@ package gost
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"errors"
@ -29,8 +30,17 @@ func SNIConnector(host string) Connector {
return &sniConnector{host: host}
}
func (c *sniConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
return &sniClientConn{addr: addr, host: c.host, Conn: conn}, nil
func (c *sniConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *sniConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
return &sniClientConn{addr: address, host: c.host, Conn: conn}, nil
}
type sniHandler struct {

279
socks.go
View File

@ -2,6 +2,7 @@ package gost
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
@ -35,6 +36,10 @@ const (
CmdUDPTun uint8 = 0xF3
)
var (
_ net.PacketConn = (*socks5UDPTunnelConn)(nil)
)
type clientSelector struct {
methods []uint8
User *url.Userinfo
@ -201,7 +206,17 @@ func SOCKS5Connector(user *url.Userinfo) Connector {
return &socks5Connector{User: user}
}
func (c *socks5Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *socks5Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *socks5Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
cnr := &socks5UDPTunConnector{User: c.User}
return cnr.ConnectContext(ctx, conn, network, address, options...)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
@ -229,7 +244,7 @@ func (c *socks5Connector) Connect(conn net.Conn, addr string, options ...Connect
}
conn = cc
host, port, err := net.SplitHostPort(addr)
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
@ -273,7 +288,16 @@ func SOCKS5BindConnector(user *url.Userinfo) Connector {
return &socks5BindConnector{User: user}
}
func (c *socks5BindConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *socks5BindConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *socks5BindConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
@ -301,7 +325,7 @@ func (c *socks5BindConnector) Connect(conn net.Conn, addr string, options ...Con
}
conn = cc
laddr, err := net.ResolveTCPAddr("tcp", addr)
laddr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
log.Log(err)
return nil, err
@ -331,8 +355,8 @@ func (c *socks5BindConnector) Connect(conn net.Conn, addr string, options ...Con
}
if reply.Rep != gosocks5.Succeeded {
log.Logf("[socks5] bind on %s failure", addr)
return nil, fmt.Errorf("SOCKS5 bind on %s failure", addr)
log.Logf("[socks5] bind on %s failure", address)
return nil, fmt.Errorf("SOCKS5 bind on %s failure", address)
}
baddr, err := net.ResolveTCPAddr("tcp", reply.Addr.String())
if err != nil {
@ -350,8 +374,17 @@ func Socks5MuxBindConnector() Connector {
return &socks5MuxBindConnector{}
}
func (c *socks5MuxBindConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
// NOTE: the conn must be *muxBindClientConn.
func (c *socks5MuxBindConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *socks5MuxBindConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
accepter, ok := conn.(Accepter)
if !ok {
return nil, errors.New("wrong connection type")
@ -513,7 +546,16 @@ func SOCKS5UDPConnector(user *url.Userinfo) Connector {
return &socks5UDPConnector{User: user}
}
func (c *socks5UDPConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *socks5UDPConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "udp", address, options...)
}
func (c *socks5UDPConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "tcp", "tcp4", "tcp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
@ -541,7 +583,7 @@ func (c *socks5UDPConnector) Connect(conn net.Conn, addr string, options ...Conn
}
conn = cc
taddr, err := net.ResolveUDPAddr("udp", addr)
taddr, err := net.ResolveUDPAddr("udp", address)
if err != nil {
return nil, err
}
@ -596,71 +638,40 @@ func SOCKS5UDPTunConnector(user *url.Userinfo) Connector {
return &socks5UDPTunConnector{User: user}
}
func (c *socks5UDPTunConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *socks5UDPTunConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "udp", address, options...)
}
func (c *socks5UDPTunConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "tcp", "tcp4", "tcp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
}
user := opts.User
if user == nil {
user = c.User
}
timeout := opts.Timeout
if timeout <= 0 {
timeout = ConnectTimeout
}
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
user := opts.User
if user == nil {
user = c.User
}
cc, err := socks5Handshake(conn,
taddr, _ := net.ResolveUDPAddr("udp", address)
return newSocks5UDPTunnelConn(conn,
nil, taddr,
selectorSocks5HandshakeOption(opts.Selector),
userSocks5HandshakeOption(user),
noTLSSocks5HandshakeOption(opts.NoTLS),
)
if err != nil {
return nil, err
}
conn = cc
taddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
req := gosocks5.NewRequest(CmdUDPTun, &gosocks5.Addr{
Type: gosocks5.AddrIPv4,
})
if err := req.Write(conn); err != nil {
return nil, err
}
if Debug {
log.Log("[socks5] udp\n", req)
}
reply, err := gosocks5.ReadReply(conn)
if err != nil {
return nil, err
}
if Debug {
log.Log("[socks5] udp\n", reply)
}
if reply.Rep != gosocks5.Succeeded {
log.Logf("[socks5] udp relay failure")
return nil, fmt.Errorf("SOCKS5 udp relay failure")
}
baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String())
if err != nil {
return nil, err
}
log.Logf("[socks5] udp-tun associate on %s OK", baddr)
return &udpTunnelConn{Conn: conn, raddr: taddr}, nil
}
type socks4Connector struct{}
@ -670,7 +681,16 @@ func SOCKS4Connector() Connector {
return &socks4Connector{}
}
func (c *socks4Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *socks4Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *socks4Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
@ -684,7 +704,7 @@ func (c *socks4Connector) Connect(conn net.Conn, addr string, options ...Connect
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
taddr, err := net.ResolveTCPAddr("tcp4", addr)
taddr, err := net.ResolveTCPAddr("tcp4", address)
if err != nil {
return nil, err
}
@ -730,7 +750,16 @@ func SOCKS4AConnector() Connector {
return &socks4aConnector{}
}
func (c *socks4aConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *socks4aConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *socks4aConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
@ -744,7 +773,7 @@ func (c *socks4aConnector) Connect(conn net.Conn, addr string, options ...Connec
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
host, port, err := net.SplitHostPort(addr)
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
@ -1601,6 +1630,7 @@ func (h *socks5Handler) muxBindOn(conn net.Conn, addr string) {
}
}
// TODO: support ipv6 and domain
func toSocksAddr(addr net.Addr) *gosocks5.Addr {
host := "0.0.0.0"
port := 0
@ -1795,52 +1825,6 @@ func (h *socks4Handler) handleBind(conn net.Conn, req *gosocks4.Request) {
log.Logf("[socks4-bind] %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr())
}
func getSOCKS5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) {
conn, err := chain.Conn()
if err != nil {
return nil, err
}
conn.SetDeadline(time.Now().Add(HandshakeTimeout))
defer conn.SetDeadline(time.Time{})
node := chain.LastNode()
cc, err := socks5Handshake(conn,
userSocks5HandshakeOption(node.User),
noTLSSocks5HandshakeOption(node.GetBool("notls")),
)
if err != nil {
conn.Close()
return nil, err
}
conn = cc
req := gosocks5.NewRequest(CmdUDPTun, toSocksAddr(addr))
if err := req.Write(conn); err != nil {
conn.Close()
return nil, err
}
if Debug {
log.Log("[socks5]", req)
}
reply, err := gosocks5.ReadReply(conn)
if err != nil {
conn.Close()
return nil, err
}
if Debug {
log.Log("[socks5]", reply)
}
if reply.Rep != gosocks5.Succeeded {
conn.Close()
return nil, errors.New("UDP tunnel failure")
}
return conn, nil
}
type socks5HandshakeOptions struct {
selector gosocks5.Selector
user *url.Userinfo
@ -1896,21 +1880,74 @@ func socks5Handshake(conn net.Conn, opts ...socks5HandshakeOption) (net.Conn, er
return cc, nil
}
type udpTunnelConn struct {
raddr net.Addr
net.Conn
func getSocks5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) {
c, err := chain.Conn()
if err != nil {
return nil, err
}
node := chain.LastNode()
conn, err := newSocks5UDPTunnelConn(c,
addr, nil,
userSocks5HandshakeOption(node.User),
noTLSSocks5HandshakeOption(node.GetBool("notls")),
)
if err != nil {
c.Close()
}
return conn, nil
}
func (c *udpTunnelConn) Read(b []byte) (n int, err error) {
dgram, err := gosocks5.ReadUDPDatagram(c.Conn)
type socks5UDPTunnelConn struct {
net.Conn
taddr net.Addr
}
func newSocks5UDPTunnelConn(conn net.Conn, raddr, taddr net.Addr, opts ...socks5HandshakeOption) (net.Conn, error) {
cc, err := socks5Handshake(conn, opts...)
if err != nil {
return
return nil, err
}
n = copy(b, dgram.Data)
req := gosocks5.NewRequest(CmdUDPTun, toSocksAddr(raddr))
if err := req.Write(cc); err != nil {
return nil, err
}
if Debug {
log.Log("[socks5] udp-tun", req)
}
reply, err := gosocks5.ReadReply(cc)
if err != nil {
return nil, err
}
if Debug {
log.Log("[socks5] udp-tun", reply)
}
if reply.Rep != gosocks5.Succeeded {
return nil, errors.New("socks5 UDP tunnel failure")
}
baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String())
if err != nil {
return nil, err
}
log.Logf("[socks5] udp-tun associate on %s OK", baddr)
return &socks5UDPTunnelConn{
Conn: cc,
taddr: taddr,
}, nil
}
func (c *socks5UDPTunnelConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *udpTunnelConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
func (c *socks5UDPTunnelConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
dgram, err := gosocks5.ReadUDPDatagram(c.Conn)
if err != nil {
return
@ -1920,15 +1957,11 @@ func (c *udpTunnelConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
return
}
func (c *udpTunnelConn) Write(b []byte) (n int, err error) {
dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(c.raddr)), b)
if err = dgram.Write(c.Conn); err != nil {
return
}
return len(b), nil
func (c *socks5UDPTunnelConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.taddr)
}
func (c *udpTunnelConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
func (c *socks5UDPTunnelConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(addr)), b)
if err = dgram.Write(c.Conn); err != nil {
return

163
ss.go
View File

@ -2,6 +2,7 @@ package gost
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
@ -15,6 +16,15 @@ import (
ss "github.com/shadowsocks/shadowsocks-go/shadowsocks"
)
const (
maxSocksAddrLen = 259
)
var (
_ net.Conn = (*shadowConn)(nil)
_ net.PacketConn = (*shadowUDPPacketConn)(nil)
)
type shadowConnector struct {
cipher core.Cipher
}
@ -27,7 +37,16 @@ func ShadowConnector(info *url.Userinfo) Connector {
}
}
func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *shadowConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *shadowConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
@ -38,7 +57,7 @@ func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...Connect
timeout = ConnectTimeout
}
socksAddr, err := gosocks5.NewAddr(addr)
socksAddr, err := gosocks5.NewAddr(address)
if err != nil {
return nil, err
}
@ -183,7 +202,16 @@ func ShadowUDPConnector(info *url.Userinfo) Connector {
}
}
func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *shadowUDPConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "udp", address, options...)
}
func (c *shadowUDPConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "tcp", "tcp4", "tcp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
@ -197,13 +225,13 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
taddr, _ := net.ResolveUDPAddr(network, address)
if taddr == nil {
taddr = &net.UDPAddr{}
}
pc, ok := conn.(net.PacketConn)
if ok {
rawaddr, err := ss.RawAddr(addr)
if err != nil {
return nil, err
}
if c.cipher != nil {
pc = c.cipher.PacketConn(pc)
}
@ -211,22 +239,17 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn
return &shadowUDPPacketConn{
PacketConn: pc,
raddr: conn.RemoteAddr(),
header: rawaddr,
taddr: taddr,
}, nil
}
taddr, err := gosocks5.NewAddr(addr)
if err != nil {
return nil, err
}
if c.cipher != nil {
conn = c.cipher.StreamConn(conn)
}
return &shadowUDPStreamConn{
Conn: conn,
addr: taddr,
return &socks5UDPTunnelConn{
Conn: conn,
taddr: taddr,
}, nil
}
@ -258,23 +281,13 @@ func (h *shadowUDPHandler) Init(options ...HandlerOption) {
func (h *shadowUDPHandler) Handle(conn net.Conn) {
defer conn.Close()
var err error
var cc net.PacketConn
if h.options.Chain.IsEmpty() {
cc, err = net.ListenUDP("udp", nil)
if err != nil {
log.Logf("[ssu] %s - : %s", conn.LocalAddr(), err)
return
}
} else {
var c net.Conn
c, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
if err != nil {
log.Logf("[ssu] %s - : %s", conn.LocalAddr(), err)
return
}
cc = &udpTunnelConn{Conn: c}
c, err := h.options.Chain.DialContext(context.Background(), "udp", "")
if err != nil {
log.Logf("[ssu] %s: %s", conn.LocalAddr(), err)
return
}
cc = c.(net.PacketConn)
defer cc.Close()
pc, ok := conn.(net.PacketConn)
@ -466,24 +479,11 @@ func (c *shadowConn) Write(b []byte) (n int, err error) {
type shadowUDPPacketConn struct {
net.PacketConn
raddr net.Addr
header []byte
raddr net.Addr
taddr net.Addr
}
func (c *shadowUDPPacketConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
buf := bytes.Buffer{}
if _, err = buf.Write(c.header); err != nil {
return
}
if _, err = buf.Write(b); err != nil {
return
}
_, err = c.PacketConn.WriteTo(buf.Bytes(), c.raddr)
return
}
func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) {
func (c *shadowUDPPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
buf := mPool.Get().([]byte)
defer mPool.Put(buf)
@ -501,47 +501,46 @@ func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) {
return
}
n = copy(b, dgram.Data)
addr, err = net.ResolveUDPAddr("udp", dgram.Header.Addr.String())
return
}
func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *shadowUDPPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
sa, err := gosocks5.NewAddr(addr.String())
if err != nil {
return
}
var rawaddr [maxSocksAddrLen]byte
nn, err := sa.Encode(rawaddr[:])
if err != nil {
return
}
buf := mPool.Get().([]byte)
defer mPool.Put(buf)
copy(buf, rawaddr[:nn])
n = copy(buf[nn:], b)
_, err = c.PacketConn.WriteTo(buf[:n+nn], c.raddr)
return
}
func (c *shadowUDPPacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.taddr)
}
func (c *shadowUDPPacketConn) RemoteAddr() net.Addr {
return c.raddr
}
type shadowUDPStreamConn struct {
net.Conn
addr *gosocks5.Addr
}
func (c *shadowUDPStreamConn) Read(b []byte) (n int, err error) {
dgram, err := gosocks5.ReadUDPDatagram(c.Conn)
if err != nil {
return
}
n = copy(b, dgram.Data)
return
}
func (c *shadowUDPStreamConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(b)
addr = c.Conn.RemoteAddr()
return
}
func (c *shadowUDPStreamConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, c.addr), b)
buf := bytes.Buffer{}
dgram.Write(&buf)
_, err = c.Conn.Write(buf.Bytes())
return
}
func (c *shadowUDPStreamConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.Write(b)
}
type shadowCipher struct {
cipher *ss.Cipher
}

View File

@ -138,7 +138,7 @@ var ssProxyTests = []struct {
serverCipher *url.Userinfo
pass bool
}{
{nil, nil, false},
{nil, nil, true},
{&url.Userinfo{}, &url.Userinfo{}, true},
{url.User("abc"), url.User("abc"), true},
{url.UserPassword("abc", "def"), url.UserPassword("abc", "def"), true},

30
ssh.go
View File

@ -39,6 +39,15 @@ func SSHDirectForwardConnector() Connector {
}
func (c *sshDirectForwardConnector) Connect(conn net.Conn, raddr string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", raddr, options...)
}
func (c *sshDirectForwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, raddr string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
@ -73,7 +82,16 @@ func SSHRemoteForwardConnector() Connector {
return &sshRemoteForwardConnector{}
}
func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *sshRemoteForwardConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *sshRemoteForwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution.
if !ok {
return nil, errors.New("ssh: wrong connection type")
@ -87,10 +105,10 @@ func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options
if cc.session == nil || cc.session.client == nil {
return
}
if strings.HasPrefix(addr, ":") {
addr = "0.0.0.0" + addr
if strings.HasPrefix(address, ":") {
address = "0.0.0.0" + address
}
ln, err := cc.session.client.Listen("tcp", addr)
ln, err := cc.session.client.Listen("tcp", address)
if err != nil {
return
}
@ -99,7 +117,7 @@ func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options
for {
rc, err := ln.Accept()
if err != nil {
log.Logf("[ssh-rtcp] %s <-> %s accpet : %s", ln.Addr(), addr, err)
log.Logf("[ssh-rtcp] %s <-> %s accpet : %s", ln.Addr(), address, err)
return
}
// log.Log("[ssh-rtcp] accept", rc.LocalAddr(), rc.RemoteAddr())
@ -107,7 +125,7 @@ func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options
case cc.session.connChan <- rc:
default:
rc.Close()
log.Logf("[ssh-rtcp] %s - %s: connection queue is full", ln.Addr(), addr)
log.Logf("[ssh-rtcp] %s - %s: connection queue is full", ln.Addr(), address)
}
}
}()

View File

@ -1,6 +1,7 @@
package gost
import (
"context"
"errors"
"fmt"
"io"
@ -167,9 +168,11 @@ func (h *tunHandler) Handle(conn net.Conn) {
var pc net.PacketConn
// fake tcp mode will be ignored when the client specifies a chain.
if raddr != nil && !h.options.Chain.IsEmpty() {
var cc net.Conn
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
pc = &udpTunnelConn{Conn: cc, raddr: raddr}
cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String())
if err != nil {
return err
}
pc = cc.(net.PacketConn)
} else {
if h.options.TCPMode {
if raddr != nil {
@ -549,9 +552,11 @@ func (h *tapHandler) Handle(conn net.Conn) {
var pc net.PacketConn
// fake tcp mode will be ignored when the client specifies a chain.
if raddr != nil && !h.options.Chain.IsEmpty() {
var cc net.Conn
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
pc = &udpTunnelConn{Conn: cc, raddr: raddr}
cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String())
if err != nil {
return err
}
pc = cc.(net.PacketConn)
} else {
if h.options.TCPMode {
if raddr != nil {

21
udp.go
View File

@ -19,19 +19,17 @@ func UDPTransporter() Transporter {
}
func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) {
raddr, err := net.ResolveUDPAddr("udp", addr)
taddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", nil)
conn, err := net.DialUDP("udp", nil, taddr)
if err != nil {
return nil, err
}
return &udpClientConn{
UDPConn: conn,
raddr: raddr,
}, nil
}
@ -340,19 +338,14 @@ func (c *udpServerConn) SetWriteDeadline(t time.Time) error {
type udpClientConn struct {
*net.UDPConn
raddr net.Addr
}
func (c *udpClientConn) Write(b []byte) (int, error) {
if c.raddr != nil {
return c.WriteTo(b, c.raddr)
}
func (c *udpClientConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.UDPConn.Write(b)
}
func (c *udpClientConn) RemoteAddr() net.Addr {
if c.raddr != nil {
return c.raddr
}
return c.UDPConn.RemoteAddr()
func (c *udpClientConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(b)
addr = c.RemoteAddr()
return
}