diff --git a/chain.go b/chain.go index fc70a82..fd7bf45 100644 --- a/chain.go +++ b/chain.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "time" "github.com/go-log/log" ) @@ -308,3 +309,42 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { } return } + +// ChainOptions holds options for Chain. +type ChainOptions struct { + Retry int + Timeout time.Duration + Hosts *Hosts + Resolver Resolver +} + +// ChainOption allows a common way to set chain options. +type ChainOption func(opts *ChainOptions) + +// RetryChainOption specifies the times of retry used by Chain.Dial. +func RetryChainOption(retry int) ChainOption { + return func(opts *ChainOptions) { + opts.Retry = retry + } +} + +// TimeoutChainOption specifies the timeout used by Chain.Dial. +func TimeoutChainOption(timeout time.Duration) ChainOption { + return func(opts *ChainOptions) { + opts.Timeout = timeout + } +} + +// HostsChainOption specifies the hosts used by Chain.Dial. +func HostsChainOption(hosts *Hosts) ChainOption { + return func(opts *ChainOptions) { + opts.Hosts = hosts + } +} + +// ResolverChainOption specifies the Resolver used by Chain.Dial. +func ResolverChainOption(resolver Resolver) ChainOption { + return func(opts *ChainOptions) { + opts.Resolver = resolver + } +} diff --git a/client.go b/client.go index 8a74998..920dd30 100644 --- a/client.go +++ b/client.go @@ -90,7 +90,7 @@ func (tr *tcpTransporter) Multiplex() bool { return false } -// DialOptions describes the options for dialing. +// DialOptions describes the options for Transporter.Dial. type DialOptions struct { Timeout time.Duration Chain *Chain diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 779596c..fb26896 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -341,15 +341,12 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { } func (r *route) serve() error { - baseChain, err := r.initChain() + chain, err := r.initChain() if err != nil { return err } for _, ns := range r.ServeNodes { - chain := &gost.Chain{} - *chain = *baseChain - node, err := gost.ParseNode(ns) if err != nil { return err @@ -469,8 +466,43 @@ func (r *route) serve() error { } } - var handlerOptions []gost.HandlerOption - handlerOptions = append(handlerOptions, + var handler gost.Handler + switch node.Protocol { + case "http2": + handler = gost.HTTP2Handler() + case "socks", "socks5": + handler = gost.SOCKS5Handler() + case "socks4", "socks4a": + handler = gost.SOCKS4Handler() + case "ss": + handler = gost.ShadowHandler() + case "http": + handler = gost.HTTPHandler() + case "tcp": + handler = gost.TCPDirectForwardHandler(node.Remote) + case "rtcp": + handler = gost.TCPRemoteForwardHandler(node.Remote) + case "udp": + handler = gost.UDPDirectForwardHandler(node.Remote) + case "rudp": + handler = gost.UDPRemoteForwardHandler(node.Remote) + case "forward": + handler = gost.SSHForwardHandler() + case "redirect": + handler = gost.TCPRedirectHandler() + case "ssu": + handler = gost.ShadowUDPdHandler() + case "sni": + handler = gost.SNIHandler() + default: + // start from 2.5, if remote is not empty, then we assume that it is a forward tunnel. + if node.Remote != "" { + handler = gost.TCPDirectForwardHandler(node.Remote) + } else { + handler = gost.AutoHandler() + } + } + handler.Init( gost.AddrHandlerOption(node.Addr), gost.ChainHandlerOption(chain), gost.UsersHandlerOption(users...), @@ -480,44 +512,6 @@ func (r *route) serve() error { gost.BypassHandlerOption(parseBypass(node.Get("bypass"))), gost.StrategyHandlerOption(parseStrategy(node.Get("strategy"))), ) - var handler gost.Handler - switch node.Protocol { - case "http2": - handler = gost.HTTP2Handler(handlerOptions...) - case "socks", "socks5": - handler = gost.SOCKS5Handler(handlerOptions...) - case "socks4", "socks4a": - handler = gost.SOCKS4Handler(handlerOptions...) - case "ss": - handler = gost.ShadowHandler(handlerOptions...) - case "http": - handler = gost.HTTPHandler(handlerOptions...) - case "tcp": - handler = gost.TCPDirectForwardHandler(node.Remote, handlerOptions...) - case "rtcp": - handler = gost.TCPRemoteForwardHandler(node.Remote, handlerOptions...) - case "udp": - handler = gost.UDPDirectForwardHandler(node.Remote, handlerOptions...) - case "rudp": - handler = gost.UDPRemoteForwardHandler(node.Remote, handlerOptions...) - case "forward": - handler = gost.SSHForwardHandler(handlerOptions...) - case "redirect": - handler = gost.TCPRedirectHandler(handlerOptions...) - case "ssu": - handler = gost.ShadowUDPdHandler(handlerOptions...) - case "sni": - handler = gost.SNIHandler(handlerOptions...) - default: - // start from 2.5, if remote is not empty, then we assume that it is a forward tunnel. - if node.Remote != "" { - handler = gost.TCPDirectForwardHandler(node.Remote, handlerOptions...) - } else { - handler = gost.AutoHandler(handlerOptions...) - } - } - - srv := &gost.Server{Listener: ln} chain.Resolver = parseResolver(node.Get("dns")) if gost.Debug { @@ -531,6 +525,7 @@ func (r *route) serve() error { } } + srv := &gost.Server{Listener: ln} go srv.Serve(handler) } diff --git a/forward.go b/forward.go index b84ca15..3f014c6 100644 --- a/forward.go +++ b/forward.go @@ -37,21 +37,9 @@ type tcpDirectForwardHandler struct { // NOTE: as of 2.6, remote address can be a comma-separated address list. func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { h := &tcpDirectForwardHandler{ - raddr: raddr, - options: &HandlerOptions{}, + raddr: raddr, + group: NewNodeGroup(), } - for _, opt := range opts { - opt(h.options) - } - - group := NewNodeGroup() - group.SetSelector(&defaultSelector{}, - WithStrategy(h.options.Strategy), - WithFilter(&FailFilter{ - MaxFails: 1, - FailTimeout: 30 * time.Second, - }), - ) for i, addr := range strings.Split(raddr, ",") { if addr == "" { @@ -59,17 +47,35 @@ func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { } // We treat the remote target server as a node, so we can put them in a group, // and perform the node selection for load balancing. - group.AddNode(Node{ + h.group.AddNode(Node{ ID: i + 1, Addr: addr, Host: addr, }) } - h.group = group + h.Init(opts...) return h } +func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range options { + opt(h.options) + } + + h.group.SetSelector(&defaultSelector{}, + WithStrategy(h.options.Strategy), + WithFilter(&FailFilter{ + MaxFails: 1, + FailTimeout: 30 * time.Second, + }), + ) +} + func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { defer conn.Close() @@ -106,21 +112,9 @@ type udpDirectForwardHandler struct { // NOTE: as of 2.6, remote address can be a comma-separated address list. func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { h := &udpDirectForwardHandler{ - raddr: raddr, - options: &HandlerOptions{}, + raddr: raddr, + group: NewNodeGroup(), } - for _, opt := range opts { - opt(h.options) - } - - group := NewNodeGroup() - group.SetSelector(&defaultSelector{}, - WithStrategy(h.options.Strategy), - WithFilter(&FailFilter{ - MaxFails: 1, - FailTimeout: 30 * time.Second, - }), - ) for i, addr := range strings.Split(raddr, ",") { if addr == "" { @@ -128,17 +122,36 @@ func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { } // We treat the remote target server as a node, so we can put them in a group, // and perform the node selection for load balancing. - group.AddNode(Node{ + h.group.AddNode(Node{ ID: i + 1, Addr: addr, Host: addr, }) } - h.group = group + + h.Init(opts...) return h } +func (h *udpDirectForwardHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range options { + opt(h.options) + } + + h.group.SetSelector(&defaultSelector{}, + WithStrategy(h.options.Strategy), + WithFilter(&FailFilter{ + MaxFails: 1, + FailTimeout: 30 * time.Second, + }), + ) +} + func (h *udpDirectForwardHandler) Handle(conn net.Conn) { defer conn.Close() @@ -191,21 +204,9 @@ type tcpRemoteForwardHandler struct { // NOTE: as of 2.6, remote address can be a comma-separated address list. func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { h := &tcpRemoteForwardHandler{ - raddr: raddr, - options: &HandlerOptions{}, + raddr: raddr, + group: NewNodeGroup(), } - for _, opt := range opts { - opt(h.options) - } - - group := NewNodeGroup() - group.SetSelector(&defaultSelector{}, - WithStrategy(h.options.Strategy), - WithFilter(&FailFilter{ - MaxFails: 1, - FailTimeout: 30 * time.Second, - }), - ) for i, addr := range strings.Split(raddr, ",") { if addr == "" { @@ -213,17 +214,34 @@ func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { } // We treat the remote target server as a node, so we can put them in a group, // and perform the node selection for load balancing. - group.AddNode(Node{ + h.group.AddNode(Node{ ID: i + 1, Addr: addr, Host: addr, }) } - h.group = group + h.Init(opts...) return h } +func (h *tcpRemoteForwardHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + for _, opt := range options { + opt(h.options) + } + + h.group.SetSelector(&defaultSelector{}, + WithStrategy(h.options.Strategy), + WithFilter(&FailFilter{ + MaxFails: 1, + FailTimeout: 30 * time.Second, + }), + ) +} + func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { defer conn.Close() @@ -258,38 +276,45 @@ type udpRemoteForwardHandler struct { // NOTE: as of 2.6, remote address can be a comma-separated address list. func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { h := &udpRemoteForwardHandler{ - raddr: raddr, - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) + raddr: raddr, + group: NewNodeGroup(), } - group := NewNodeGroup() - group.SetSelector(&defaultSelector{}, - WithStrategy(h.options.Strategy), - WithFilter(&FailFilter{ - MaxFails: 1, - FailTimeout: 30 * time.Second, - }), - ) for i, addr := range strings.Split(raddr, ",") { if addr == "" { continue } // We treat the remote target server as a node, so we can put them in a group, // and perform the node selection for load balancing. - group.AddNode(Node{ + h.group.AddNode(Node{ ID: i + 1, Addr: addr, Host: addr, }) } - h.group = group + + h.Init(opts...) return h } +func (h *udpRemoteForwardHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range options { + opt(h.options) + } + h.group.SetSelector(&defaultSelector{}, + WithStrategy(h.options.Strategy), + WithFilter(&FailFilter{ + MaxFails: 1, + FailTimeout: 30 * time.Second, + }), + ) +} + func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { defer conn.Close() diff --git a/handler.go b/handler.go index 559c6c4..d5d2fe8 100644 --- a/handler.go +++ b/handler.go @@ -13,6 +13,7 @@ import ( // Handler is a proxy server handler type Handler interface { + Init(options ...HandlerOption) Handle(net.Conn) } @@ -88,17 +89,25 @@ func StrategyHandlerOption(strategy Strategy) HandlerOption { } type autoHandler struct { - options []HandlerOption + options *HandlerOptions } // AutoHandler creates a server Handler for auto proxy server. func AutoHandler(opts ...HandlerOption) Handler { - h := &autoHandler{ - options: opts, - } + h := &autoHandler{} + h.Init(opts...) return h } +func (h *autoHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + for _, opt := range options { + opt(h.options) + } +} + func (h *autoHandler) Handle(conn net.Conn) { br := bufio.NewReader(conn) b, err := br.Peek(1) @@ -109,25 +118,23 @@ func (h *autoHandler) Handle(conn net.Conn) { } cc := &bufferdConn{Conn: conn, br: br} + var handler Handler switch b[0] { case gosocks4.Ver4: - options := &HandlerOptions{} - for _, opt := range h.options { - opt(options) - } // SOCKS4(a) does not suppport authentication method, // so we ignore it when credentials are specified for security reason. - if len(options.Users) > 0 { + if len(h.options.Users) > 0 { cc.Close() return } - h := &socks4Handler{options} - h.Handle(cc) - case gosocks5.Ver5: - SOCKS5Handler(h.options...).Handle(cc) + handler = &socks4Handler{options: h.options} + case gosocks5.Ver5: // socks5 + handler = &socks5Handler{options: h.options} default: // http - HTTPHandler(h.options...).Handle(cc) + handler = &httpHandler{options: h.options} } + handler.Init() + handler.Handle(cc) } type bufferdConn struct { diff --git a/http.go b/http.go index c7f226f..7c9f86d 100644 --- a/http.go +++ b/http.go @@ -75,13 +75,18 @@ type httpHandler struct { // HTTPHandler creates a server Handler for HTTP proxy server. func HTTPHandler(opts ...HandlerOption) Handler { - h := &httpHandler{ - options: &HandlerOptions{}, + h := &httpHandler{} + h.Init(opts...) + return h +} + +func (h *httpHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} } - for _, opt := range opts { + for _, opt := range options { opt(h.options) } - return h } func (h *httpHandler) Handle(conn net.Conn) { diff --git a/http2.go b/http2.go index f3cef57..6501d83 100644 --- a/http2.go +++ b/http2.go @@ -250,16 +250,21 @@ type http2Handler struct { // HTTP2Handler creates a server Handler for HTTP2 proxy server. func HTTP2Handler(opts ...HandlerOption) Handler { - h := &http2Handler{ - options: new(HandlerOptions), - } - for _, opt := range opts { - opt(h.options) - } + h := &http2Handler{} + h.Init(opts...) return h } +func (h *http2Handler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + for _, opt := range options { + opt(h.options) + } +} + func (h *http2Handler) Handle(conn net.Conn) { defer conn.Close() diff --git a/redirect.go b/redirect.go index f7033c6..f1be354 100644 --- a/redirect.go +++ b/redirect.go @@ -17,15 +17,20 @@ type tcpRedirectHandler struct { // TCPRedirectHandler creates a server Handler for TCP redirect server. func TCPRedirectHandler(opts ...HandlerOption) Handler { - h := &tcpRedirectHandler{ - options: &HandlerOptions{ - Chain: new(Chain), - }, + h := &tcpRedirectHandler{} + h.Init(opts...) + + return h +} + +func (h *tcpRedirectHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} } - for _, opt := range opts { + + for _, opt := range options { opt(h.options) } - return h } func (h *tcpRedirectHandler) Handle(c net.Conn) { diff --git a/resolver.go b/resolver.go index 08c38ff..2764ca7 100644 --- a/resolver.go +++ b/resolver.go @@ -48,6 +48,12 @@ func (ns NameServer) String() string { return fmt.Sprintf("%s/%s %s", addr, prot, host) } +type nameServers struct { + Servers []NameServer + Timeout time.Duration + TTL time.Duration +} + type resolverCacheItem struct { IPs []net.IP ts int64 diff --git a/sni.go b/sni.go index 321737f..0d99ae2 100644 --- a/sni.go +++ b/sni.go @@ -38,13 +38,20 @@ type sniHandler struct { // SNIHandler creates a server Handler for SNI proxy server. func SNIHandler(opts ...HandlerOption) Handler { - h := &sniHandler{ - options: &HandlerOptions{}, + h := &sniHandler{} + h.Init(opts...) + + return h +} + +func (h *sniHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} } - for _, opt := range opts { + + for _, opt := range options { opt(h.options) } - return h } func (h *sniHandler) Handle(conn net.Conn) { diff --git a/socks.go b/socks.go index 177e24c..fa4a926 100644 --- a/socks.go +++ b/socks.go @@ -348,30 +348,36 @@ type socks5Handler struct { // SOCKS5Handler creates a server Handler for SOCKS5 proxy server. func SOCKS5Handler(opts ...HandlerOption) Handler { - options := &HandlerOptions{} - for _, opt := range opts { - opt(options) + h := &socks5Handler{} + h.Init(opts...) + + return h +} + +func (h *socks5Handler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} } - tlsConfig := options.TLSConfig + for _, opt := range options { + opt(h.options) + } + + tlsConfig := h.options.TLSConfig if tlsConfig == nil { tlsConfig = DefaultTLSConfig } - selector := &serverSelector{ // socks5 server selector - Users: options.Users, + h.selector = &serverSelector{ // socks5 server selector + Users: h.options.Users, TLSConfig: tlsConfig, } // methods that socks5 server supported - selector.AddMethod( + h.selector.AddMethod( gosocks5.MethodNoAuth, gosocks5.MethodUserPass, MethodTLS, MethodTLSAuth, ) - return &socks5Handler{ - options: options, - selector: selector, - } } func (h *socks5Handler) Handle(conn net.Conn) { @@ -1110,12 +1116,19 @@ type socks4Handler struct { // SOCKS4Handler creates a server Handler for SOCKS4(A) proxy server. func SOCKS4Handler(opts ...HandlerOption) Handler { - options := &HandlerOptions{} - for _, opt := range opts { - opt(options) + h := &socks4Handler{} + h.Init(opts...) + + return h +} + +func (h *socks4Handler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} } - return &socks4Handler{ - options: options, + + for _, opt := range options { + opt(h.options) } } diff --git a/ss.go b/ss.go index 6bfc694..6b682c8 100644 --- a/ss.go +++ b/ss.go @@ -97,13 +97,20 @@ type shadowHandler struct { // ShadowHandler creates a server Handler for shadowsocks proxy server. func ShadowHandler(opts ...HandlerOption) Handler { - h := &shadowHandler{ - options: &HandlerOptions{}, + h := &shadowHandler{} + h.Init(opts...) + + return h +} + +func (h *shadowHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} } - for _, opt := range opts { + + for _, opt := range options { opt(h.options) } - return h } func (h *shadowHandler) Handle(conn net.Conn) { @@ -326,13 +333,20 @@ type shadowUDPdHandler struct { // ShadowUDPdHandler creates a server Handler for shadowsocks UDP relay server. func ShadowUDPdHandler(opts ...HandlerOption) Handler { - h := &shadowUDPdHandler{ - options: &HandlerOptions{}, + h := &shadowUDPdHandler{} + h.Init(opts...) + + return h +} + +func (h *shadowUDPdHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} } - for _, opt := range opts { + + for _, opt := range options { opt(h.options) } - return h } func (h *shadowUDPdHandler) Handle(conn net.Conn) { diff --git a/ssh.go b/ssh.go index 58c0cbc..c66b9a3 100644 --- a/ssh.go +++ b/ssh.go @@ -408,13 +408,22 @@ type sshForwardHandler struct { // SSHForwardHandler creates a server Handler for SSH port forwarding server. func SSHForwardHandler(opts ...HandlerOption) Handler { - h := &sshForwardHandler{ - options: new(HandlerOptions), - config: new(ssh.ServerConfig), + h := &sshForwardHandler{} + h.Init(opts...) + + return h +} + +func (h *sshForwardHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} } - for _, opt := range opts { + + for _, opt := range options { opt(h.options) } + h.config = &ssh.ServerConfig{} + h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Users...) if len(h.options.Users) == 0 { h.config.NoClientAuth = true @@ -430,8 +439,6 @@ func SSHForwardHandler(opts ...HandlerOption) Handler { } h.config.AddHostKey(signer) } - - return h } func (h *sshForwardHandler) Handle(conn net.Conn) {