add ChainOptions for Chain.Dial

This commit is contained in:
ginuerzh 2018-07-08 10:41:56 +08:00
parent c242286a06
commit c1f4325b19
12 changed files with 200 additions and 100 deletions

View File

@ -19,8 +19,6 @@ var (
type Chain struct { type Chain struct {
isRoute bool isRoute bool
Retries int Retries int
Hosts *Hosts
Resolver Resolver
nodeGroups []*NodeGroup nodeGroups []*NodeGroup
} }
@ -102,17 +100,22 @@ func (c *Chain) IsEmpty() bool {
// Dial connects to the target address addr through the chain. // Dial connects to the target address addr through the chain.
// If the chain is empty, it will use the net.Dial directly. // If the chain is empty, it will use the net.Dial directly.
func (c *Chain) Dial(addr string) (conn net.Conn, err error) { func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error) {
var retries int options := &ChainOptions{}
if c != nil { for _, opt := range opts {
opt(options)
}
retries := 1
if c != nil && c.Retries > 0 {
retries = c.Retries retries = c.Retries
} }
if retries == 0 { if options.Retries > 0 {
retries = 1 retries = options.Retries
} }
for i := 0; i < retries; i++ { for i := 0; i < retries; i++ {
conn, err = c.dial(addr) conn, err = c.dialWithOptions(addr, options)
if err == nil { if err == nil {
break break
} }
@ -120,16 +123,19 @@ func (c *Chain) Dial(addr string) (conn net.Conn, err error) {
return return
} }
func (c *Chain) dial(addr string) (net.Conn, error) { func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, error) {
if options == nil {
options = &ChainOptions{}
}
route, err := c.selectRouteFor(addr) route, err := c.selectRouteFor(addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
addr = c.resolve(addr) addr = c.resolve(addr, options.Resolver, options.Hosts)
if route.IsEmpty() { if route.IsEmpty() {
return net.DialTimeout("tcp", addr, DialTimeout) return net.DialTimeout("tcp", addr, options.Timeout)
} }
conn, err := route.getConn() conn, err := route.getConn()
@ -145,17 +151,17 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
return cc, nil return cc, nil
} }
func (c *Chain) resolve(addr string) string { func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return addr return addr
} }
if ip := c.Hosts.Lookup(host); ip != nil { if ip := hosts.Lookup(host); ip != nil {
return net.JoinHostPort(ip.String(), port) return net.JoinHostPort(ip.String(), port)
} }
if c.Resolver != nil { if resolver != nil {
ips, err := c.Resolver.Resolve(host) ips, err := resolver.Resolve(host)
if err != nil { if err != nil {
log.Logf("[resolver] %s: %v", host, err) log.Logf("[resolver] %s: %v", host, err)
} }
@ -168,8 +174,21 @@ func (c *Chain) resolve(addr string) string {
// Conn obtains a handshaked connection to the last node of the chain. // Conn obtains a handshaked connection to the last node of the chain.
// If the chain is empty, it returns an ErrEmptyChain error. // If the chain is empty, it returns an ErrEmptyChain error.
func (c *Chain) Conn() (conn net.Conn, err error) { func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
for i := 0; i < c.Retries; i++ { options := &ChainOptions{}
for _, opt := range opts {
opt(options)
}
retries := 1
if c != nil && c.Retries > 0 {
retries = c.Retries
}
if options.Retries > 0 {
retries = options.Retries
}
for i := 0; i < retries; i++ {
var route *Chain var route *Chain
route, err = c.selectRoute() route, err = c.selectRoute()
if err != nil { if err != nil {
@ -177,6 +196,7 @@ func (c *Chain) Conn() (conn net.Conn, err error) {
} }
conn, err = route.getConn() conn, err = route.getConn()
if err != nil { if err != nil {
log.Log(err)
continue continue
} }
@ -185,6 +205,7 @@ func (c *Chain) Conn() (conn net.Conn, err error) {
return return
} }
// getConn obtains a connection to the last node of the chain.
func (c *Chain) getConn() (conn net.Conn, err error) { func (c *Chain) getConn() (conn net.Conn, err error) {
if c.IsEmpty() { if c.IsEmpty() {
err = ErrEmptyChain err = ErrEmptyChain
@ -232,7 +253,7 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
} }
func (c *Chain) selectRoute() (route *Chain, err error) { func (c *Chain) selectRoute() (route *Chain, err error) {
if c.isRoute { if c.IsEmpty() || c.isRoute {
return c, nil return c, nil
} }
@ -256,7 +277,6 @@ func (c *Chain) selectRoute() (route *Chain, err error) {
route.AddNode(node) route.AddNode(node)
} }
route.Retries = c.Retries route.Retries = c.Retries
route.Resolver = c.Resolver
if Debug { if Debug {
log.Log("select route:", buf.String()) log.Log("select route:", buf.String())
@ -299,9 +319,7 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) {
route.AddNode(node) route.AddNode(node)
} }
route.Retries = c.Retries route.Retries = c.Retries
route.Resolver = c.Resolver
if Debug { if Debug {
buf.WriteString(addr) buf.WriteString(addr)
@ -312,7 +330,7 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) {
// ChainOptions holds options for Chain. // ChainOptions holds options for Chain.
type ChainOptions struct { type ChainOptions struct {
Retry int Retries int
Timeout time.Duration Timeout time.Duration
Hosts *Hosts Hosts *Hosts
Resolver Resolver Resolver Resolver
@ -322,9 +340,9 @@ type ChainOptions struct {
type ChainOption func(opts *ChainOptions) type ChainOption func(opts *ChainOptions)
// RetryChainOption specifies the times of retry used by Chain.Dial. // RetryChainOption specifies the times of retry used by Chain.Dial.
func RetryChainOption(retry int) ChainOption { func RetryChainOption(retries int) ChainOption {
return func(opts *ChainOptions) { return func(opts *ChainOptions) {
opts.Retry = retry opts.Retries = retries
} }
} }

View File

@ -6,8 +6,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"net" "net"
"net/http" // _ "net/http/pprof"
_ "net/http/pprof"
"os" "os"
"runtime" "runtime"
"time" "time"
@ -59,9 +58,9 @@ func init() {
} }
func main() { func main() {
go func() { // go func() {
log.Log(http.ListenAndServe("localhost:6060", nil)) // log.Log(http.ListenAndServe("localhost:6060", nil))
}() // }()
// NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate. // NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate.
config, err := tlsConfig(defaultCertFile, defaultKeyFile) config, err := tlsConfig(defaultCertFile, defaultKeyFile)
if err != nil { if err != nil {
@ -95,12 +94,7 @@ type route struct {
func (r *route) initChain() (*gost.Chain, error) { func (r *route) initChain() (*gost.Chain, error) {
chain := gost.NewChain() chain := gost.NewChain()
chain.Retries = r.Retries chain.Retries = r.Retries
if chain.Retries == 0 {
chain.Retries = 1
}
gid := 1 // group ID gid := 1 // group ID
for _, ns := range r.ChainNodes { for _, ns := range r.ChainNodes {
@ -454,18 +448,6 @@ func (r *route) serve() error {
return err return err
} }
var whitelist, blacklist *gost.Permissions
if node.Values.Get("whitelist") != "" {
if whitelist, err = gost.ParsePermissions(node.Get("whitelist")); err != nil {
return err
}
}
if node.Values.Get("blacklist") != "" {
if blacklist, err = gost.ParsePermissions(node.Get("blacklist")); err != nil {
return err
}
}
var handler gost.Handler var handler gost.Handler
switch node.Protocol { switch node.Protocol {
case "http2": case "http2":
@ -502,6 +484,27 @@ func (r *route) serve() error {
handler = gost.AutoHandler() handler = gost.AutoHandler()
} }
} }
var whitelist, blacklist *gost.Permissions
if node.Values.Get("whitelist") != "" {
if whitelist, err = gost.ParsePermissions(node.Get("whitelist")); err != nil {
return err
}
}
if node.Values.Get("blacklist") != "" {
if blacklist, err = gost.ParsePermissions(node.Get("blacklist")); err != nil {
return err
}
}
var hosts *gost.Hosts
if f, _ := os.Open(node.Get("hosts")); f != nil {
hosts, err = gost.ParseHosts(f)
if err != nil {
log.Logf("[hosts] %s: %v", f.Name(), err)
}
}
handler.Init( handler.Init(
gost.AddrHandlerOption(node.Addr), gost.AddrHandlerOption(node.Addr),
gost.ChainHandlerOption(chain), gost.ChainHandlerOption(chain),
@ -511,20 +514,12 @@ func (r *route) serve() error {
gost.BlacklistHandlerOption(blacklist), gost.BlacklistHandlerOption(blacklist),
gost.BypassHandlerOption(parseBypass(node.Get("bypass"))), gost.BypassHandlerOption(parseBypass(node.Get("bypass"))),
gost.StrategyHandlerOption(parseStrategy(node.Get("strategy"))), gost.StrategyHandlerOption(parseStrategy(node.Get("strategy"))),
gost.ResolverHandlerOption(parseResolver(node.Get("dns"))),
gost.HostsHandlerOption(hosts),
gost.RetryHandlerOption(node.GetInt("retry")),
gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second),
) )
chain.Resolver = parseResolver(node.Get("dns"))
if gost.Debug {
log.Logf("[resolver]\n%v", chain.Resolver)
}
if f, _ := os.Open(node.Get("hosts")); f != nil {
chain.Hosts, err = gost.ParseHosts(f)
if err != nil {
log.Logf("[hosts] %s: %v", f.Name(), err)
}
}
srv := &gost.Server{Listener: ln} srv := &gost.Server{Listener: ln}
go srv.Serve(handler) go srv.Serve(handler)
} }

View File

@ -86,7 +86,10 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
} }
log.Logf("[tcp] %s - %s", conn.RemoteAddr(), node.Addr) log.Logf("[tcp] %s - %s", conn.RemoteAddr(), node.Addr)
cc, err := h.options.Chain.Dial(node.Addr) cc, err := h.options.Chain.Dial(node.Addr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
)
if err != nil { if err != nil {
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err) log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
node.MarkDead() node.MarkDead()

View File

@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"net" "net"
"net/url" "net/url"
"time"
"github.com/ginuerzh/gosocks4" "github.com/ginuerzh/gosocks4"
"github.com/ginuerzh/gosocks5" "github.com/ginuerzh/gosocks5"
@ -25,8 +26,12 @@ type HandlerOptions struct {
TLSConfig *tls.Config TLSConfig *tls.Config
Whitelist *Permissions Whitelist *Permissions
Blacklist *Permissions Blacklist *Permissions
Bypass *Bypass
Strategy Strategy Strategy Strategy
Bypass *Bypass
Retries int
Timeout time.Duration
Resolver Resolver
Hosts *Hosts
} }
// HandlerOption allows a common way to set handler options. // HandlerOption allows a common way to set handler options.
@ -88,6 +93,34 @@ func StrategyHandlerOption(strategy Strategy) HandlerOption {
} }
} }
// RetryHandlerOption sets the retry option of HandlerOptions.
func RetryHandlerOption(retries int) HandlerOption {
return func(opts *HandlerOptions) {
opts.Retries = retries
}
}
// TimeoutHandlerOption sets the timeout option of HandlerOptions.
func TimeoutHandlerOption(timeout time.Duration) HandlerOption {
return func(opts *HandlerOptions) {
opts.Timeout = timeout
}
}
// ResolverHandlerOption sets the resolver option of HandlerOptions.
func ResolverHandlerOption(resolver Resolver) HandlerOption {
return func(opts *HandlerOptions) {
opts.Resolver = resolver
}
}
// HostsHandlerOption sets the Hosts option of HandlerOptions.
func HostsHandlerOption(hosts *Hosts) HandlerOption {
return func(opts *HandlerOptions) {
opts.Hosts = hosts
}
}
type autoHandler struct { type autoHandler struct {
options *HandlerOptions options *HandlerOptions
} }

74
http.go
View File

@ -166,24 +166,50 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
req.Header.Del("Proxy-Authorization") req.Header.Del("Proxy-Authorization")
// req.Header.Del("Proxy-Connection") // req.Header.Del("Proxy-Connection")
route, err := h.options.Chain.selectRouteFor(req.Host)
if err != nil {
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err)
return
}
// forward http request
lastNode := route.LastNode()
if req.Method != http.MethodConnect && lastNode.Protocol == "http" {
h.forwardRequest(conn, req, route)
return
}
host := req.Host host := req.Host
if _, port, _ := net.SplitHostPort(host); port == "" { if _, port, _ := net.SplitHostPort(host); port == "" {
host = net.JoinHostPort(req.Host, "80") host = net.JoinHostPort(req.Host, "80")
} }
cc, err := route.Dial(host) retries := 1
if h.options.Chain != nil && h.options.Chain.Retries > 0 {
retries = h.options.Chain.Retries
}
if h.options.Retries > 0 {
retries = h.options.Retries
}
var err error
var cc net.Conn
var route *Chain
for i := 0; i < retries; i++ {
route, err = h.options.Chain.selectRouteFor(req.Host)
if err != nil {
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err)
continue
}
// forward http request
lastNode := route.LastNode()
if req.Method != http.MethodConnect && lastNode.Protocol == "http" {
err = h.forwardRequest(conn, req, route)
if err == nil {
return
}
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err)
continue
}
cc, err = route.Dial(host,
RetryChainOption(1),
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
)
if err == nil {
break
}
}
if err != nil { if err != nil {
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), host, err) log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), host, err)
@ -218,23 +244,17 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
log.Logf("[http] %s >-< %s", cc.LocalAddr(), host) log.Logf("[http] %s >-< %s", cc.LocalAddr(), host)
} }
func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request, route *Chain) { func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request, route *Chain) error {
if route.IsEmpty() { if route.IsEmpty() {
return return nil
} }
lastNode := route.LastNode() lastNode := route.LastNode()
cc, err := route.Conn() cc, err := route.Conn(
RetryChainOption(1), // we control the retry manually.
)
if err != nil { if err != nil {
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), lastNode.Addr, err) return err
b := []byte("HTTP/1.1 503 Service unavailable\r\n" +
"Proxy-Agent: gost/" + Version + "\r\n\r\n")
if Debug {
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), lastNode.Addr, string(b))
}
conn.Write(b)
return
} }
defer cc.Close() defer cc.Close()
@ -253,14 +273,14 @@ func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request, route *Ch
} }
if err = req.WriteProxy(cc); err != nil { if err = req.WriteProxy(cc); err != nil {
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err)
return return nil
} }
cc.SetWriteDeadline(time.Time{}) cc.SetWriteDeadline(time.Time{})
log.Logf("[http] %s <-> %s", conn.RemoteAddr(), req.Host) log.Logf("[http] %s <-> %s", conn.RemoteAddr(), req.Host)
transport(conn, cc) transport(conn, cc)
log.Logf("[http] %s >-< %s", conn.RemoteAddr(), req.Host) log.Logf("[http] %s >-< %s", conn.RemoteAddr(), req.Host)
return return nil
} }
func basicProxyAuth(proxyAuth string) (username, password string, ok bool) { func basicProxyAuth(proxyAuth string) (username, password string, ok bool) {

View File

@ -321,7 +321,12 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
r.Header.Del("Proxy-Authorization") r.Header.Del("Proxy-Authorization")
r.Header.Del("Proxy-Connection") r.Header.Del("Proxy-Connection")
cc, err := h.options.Chain.Dial(target) cc, err := h.options.Chain.Dial(target,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
)
if err != nil { if err != nil {
log.Logf("[http2] %s -> %s : %s", r.RemoteAddr, target, err) log.Logf("[http2] %s -> %s : %s", r.RemoteAddr, target, err)
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)

View File

@ -49,7 +49,10 @@ func (h *tcpRedirectHandler) Handle(c net.Conn) {
log.Logf("[red-tcp] %s -> %s", srcAddr, dstAddr) log.Logf("[red-tcp] %s -> %s", srcAddr, dstAddr)
cc, err := h.options.Chain.Dial(dstAddr.String()) cc, err := h.options.Chain.Dial(dstAddr.String(),
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
)
if err != nil { if err != nil {
log.Logf("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) log.Logf("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err)
return return

View File

@ -48,12 +48,6 @@ func (ns NameServer) String() string {
return fmt.Sprintf("%s/%s %s", addr, prot, host) return fmt.Sprintf("%s/%s %s", addr, prot, host)
} }
type nameServers struct {
Servers []NameServer
Timeout time.Duration
TTL time.Duration
}
type resolverCacheItem struct { type resolverCacheItem struct {
IPs []net.IP IPs []net.IP
ts int64 ts int64

8
sni.go
View File

@ -77,6 +77,7 @@ func (h *sniHandler) Handle(conn net.Conn) {
req.URL.Scheme = "http" // make sure that the URL is absolute req.URL.Scheme = "http" // make sure that the URL is absolute
} }
handler := &httpHandler{options: h.options} handler := &httpHandler{options: h.options}
handler.Init()
handler.handleRequest(conn, req) handler.handleRequest(conn, req)
return return
} }
@ -98,7 +99,12 @@ func (h *sniHandler) Handle(conn net.Conn) {
return return
} }
cc, err := h.options.Chain.Dial(addr) cc, err := h.options.Chain.Dial(addr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
)
if err != nil { if err != nil {
log.Logf("[sni] %s -> %s : %s", conn.RemoteAddr(), addr, err) log.Logf("[sni] %s -> %s : %s", conn.RemoteAddr(), addr, err)
return return

View File

@ -435,7 +435,12 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) {
return return
} }
cc, err := h.options.Chain.Dial(addr) cc, err := h.options.Chain.Dial(addr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
)
if err != nil { if err != nil {
log.Logf("[socks5-connect] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) log.Logf("[socks5-connect] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err)
rep := gosocks5.NewReply(gosocks5.HostUnreachable, nil) rep := gosocks5.NewReply(gosocks5.HostUnreachable, nil)
@ -1181,7 +1186,10 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) {
return return
} }
cc, err := h.options.Chain.Dial(addr) cc, err := h.options.Chain.Dial(addr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
)
if err != nil { if err != nil {
log.Logf("[socks4-connect] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) log.Logf("[socks4-connect] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err)
rep := gosocks4.NewReply(gosocks4.Failed, nil) rep := gosocks4.NewReply(gosocks4.Failed, nil)

7
ss.go
View File

@ -152,7 +152,12 @@ func (h *shadowHandler) Handle(conn net.Conn) {
return return
} }
cc, err := h.options.Chain.Dial(addr) cc, err := h.options.Chain.Dial(addr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
)
if err != nil { if err != nil {
log.Logf("[ss] %s -> %s : %s", conn.RemoteAddr(), addr, err) log.Logf("[ss] %s -> %s : %s", conn.RemoteAddr(), addr, err)
return return

12
ssh.go
View File

@ -513,7 +513,17 @@ func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr
return return
} }
conn, err := h.options.Chain.Dial(raddr) if h.options.Bypass.Contains(raddr) {
log.Logf("[ssh-tcp] [bypass] %s", raddr)
return
}
conn, err := h.options.Chain.Dial(raddr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
)
if err != nil { if err != nil {
log.Logf("[ssh-tcp] %s - %s : %s", h.options.Addr, raddr, err) log.Logf("[ssh-tcp] %s - %s : %s", h.options.Addr, raddr, err)
return return