From 62663564cc38ff48997424998f77ca5d97eefdaa Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 9 Jan 2019 22:36:44 +0800 Subject: [PATCH] add reloader for authenticator --- {cmd/gost/.config => .config}/bypass.txt | 0 {cmd/gost/.config => .config}/dns.txt | 0 {cmd/gost/.config => .config}/gost.json | 0 {cmd/gost/.config => .config}/hosts.txt | 0 {cmd/gost/.config => .config}/kcp.json | 0 {cmd/gost/.config => .config}/peer.txt | 0 {.testdata => .config}/probe_resist.txt | 0 {cmd/gost/.config => .config}/secrets.txt | 3 + README.md | 1 + README_en.md | 1 + auth.go | 155 ++++++++++++++++++ auth_test.go | 191 ++++++++++++++++++++++ bypass.go | 40 +---- bypass_test.go | 4 +- cmd/gost/.config/probe_resist.txt | 1 - cmd/gost/cfg.go | 18 ++ cmd/gost/route.go | 14 +- gost.go | 24 ++- handler.go | 48 ++++-- hosts.go | 47 ++---- http.go | 19 +-- http2.go | 3 +- http2_test.go | 2 +- http_test.go | 4 +- reload.go | 21 +-- resolver.go | 21 +-- server.go | 5 + snapcraft.yaml | 2 +- socks.go | 30 ++-- socks_test.go | 2 +- ssh.go | 23 +-- 31 files changed, 492 insertions(+), 187 deletions(-) rename {cmd/gost/.config => .config}/bypass.txt (100%) rename {cmd/gost/.config => .config}/dns.txt (100%) rename {cmd/gost/.config => .config}/gost.json (100%) rename {cmd/gost/.config => .config}/hosts.txt (100%) rename {cmd/gost/.config => .config}/kcp.json (100%) rename {cmd/gost/.config => .config}/peer.txt (100%) rename {.testdata => .config}/probe_resist.txt (100%) rename {cmd/gost/.config => .config}/secrets.txt (74%) create mode 100644 auth.go create mode 100644 auth_test.go delete mode 100644 cmd/gost/.config/probe_resist.txt diff --git a/cmd/gost/.config/bypass.txt b/.config/bypass.txt similarity index 100% rename from cmd/gost/.config/bypass.txt rename to .config/bypass.txt diff --git a/cmd/gost/.config/dns.txt b/.config/dns.txt similarity index 100% rename from cmd/gost/.config/dns.txt rename to .config/dns.txt diff --git a/cmd/gost/.config/gost.json b/.config/gost.json similarity index 100% rename from cmd/gost/.config/gost.json rename to .config/gost.json diff --git a/cmd/gost/.config/hosts.txt b/.config/hosts.txt similarity index 100% rename from cmd/gost/.config/hosts.txt rename to .config/hosts.txt diff --git a/cmd/gost/.config/kcp.json b/.config/kcp.json similarity index 100% rename from cmd/gost/.config/kcp.json rename to .config/kcp.json diff --git a/cmd/gost/.config/peer.txt b/.config/peer.txt similarity index 100% rename from cmd/gost/.config/peer.txt rename to .config/peer.txt diff --git a/.testdata/probe_resist.txt b/.config/probe_resist.txt similarity index 100% rename from .testdata/probe_resist.txt rename to .config/probe_resist.txt diff --git a/cmd/gost/.config/secrets.txt b/.config/secrets.txt similarity index 74% rename from cmd/gost/.config/secrets.txt rename to .config/secrets.txt index 4b5f9d4..fe86322 100644 --- a/cmd/gost/.config/secrets.txt +++ b/.config/secrets.txt @@ -1,3 +1,6 @@ +# period for live reloading +reload 3s + # username password $test.admin$ $123456$ diff --git a/README.md b/README.md index 7ab48ad..952df50 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ gost - GO Simple Tunnel * 多端口监听 * 可设置转发代理,支持多级转发(代理链) * 支持标准HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5代理协议 +* Web代理支持[探测防御](https://docs.ginuerzh.xyz/gost/probe_resist/) * [支持多种隧道类型](https://docs.ginuerzh.xyz/gost/configuration/) * [SOCKS5代理支持TLS协商加密](https://docs.ginuerzh.xyz/gost/socks/) * [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/socks/) diff --git a/README_en.md b/README_en.md index af67768..ce3a257 100644 --- a/README_en.md +++ b/README_en.md @@ -16,6 +16,7 @@ Features * Listening on multiple ports * Multi-level forward proxy - proxy chain * Standard HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5 proxy protocols support +* [Probing resistance](https://docs.ginuerzh.xyz/gost/en/probe_resist/) support for web proxy * [Support multiple tunnel types](https://docs.ginuerzh.xyz/gost/en/configuration/) * [TLS encryption via negotiation support for SOCKS5 proxy](https://docs.ginuerzh.xyz/gost/en/socks/) * [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/en/socks/) diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..f3e4e31 --- /dev/null +++ b/auth.go @@ -0,0 +1,155 @@ +package gost + +import ( + "bufio" + "io" + "strings" + "sync" + "time" +) + +// Authenticator is an interface for user authentication. +type Authenticator interface { + Authenticate(user, password string) bool +} + +// LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs. +type LocalAuthenticator struct { + kvs map[string]string + period time.Duration + stopped chan struct{} + mux sync.RWMutex +} + +// NewLocalAuthenticator creates an Authenticator that authenticates client by local infos. +func NewLocalAuthenticator(kvs map[string]string) *LocalAuthenticator { + return &LocalAuthenticator{ + kvs: kvs, + stopped: make(chan struct{}), + } +} + +// Authenticate checks the validity of the provided user-password pair. +func (au *LocalAuthenticator) Authenticate(user, password string) bool { + if au == nil { + return true + } + + au.mux.RLock() + defer au.mux.RUnlock() + + if len(au.kvs) == 0 { + return true + } + + v, ok := au.kvs[user] + return ok && (v == "" || password == v) +} + +// Add adds a key-value pair to the Authenticator. +func (au *LocalAuthenticator) Add(k, v string) { + au.mux.Lock() + defer au.mux.Unlock() + if au.kvs == nil { + au.kvs = make(map[string]string) + } + au.kvs[k] = v +} + +// Reload parses config from r, then live reloads the bypass. +func (au *LocalAuthenticator) Reload(r io.Reader) error { + var period time.Duration + kvs := make(map[string]string) + + if r == nil || au.Stopped() { + return nil + } + + // splitLine splits a line text by white space. + // A line started with '#' will be ignored, otherwise it is valid. + split := func(line string) []string { + if line == "" { + return nil + } + line = strings.Replace(line, "\t", " ", -1) + line = strings.TrimSpace(line) + + if strings.IndexByte(line, '#') == 0 { + return nil + } + + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + return ss + } + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + ss := split(line) + if len(ss) == 0 { + continue + } + + switch ss[0] { + case "reload": // reload option + if len(ss) > 1 { + period, _ = time.ParseDuration(ss[1]) + } + default: + var k, v string + k = ss[0] + if len(ss) > 1 { + v = ss[1] + } + kvs[k] = v + } + } + + if err := scanner.Err(); err != nil { + return err + } + + au.mux.Lock() + defer au.mux.Unlock() + + au.period = period + au.kvs = kvs + + return nil +} + +// Period returns the reload period. +func (au *LocalAuthenticator) Period() time.Duration { + if au.Stopped() { + return -1 + } + + au.mux.RLock() + defer au.mux.RUnlock() + + return au.period +} + +// Stop stops reloading. +func (au *LocalAuthenticator) Stop() { + select { + case <-au.stopped: + default: + close(au.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (au *LocalAuthenticator) Stopped() bool { + select { + case <-au.stopped: + return true + default: + return false + } +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..68a1c93 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,191 @@ +package gost + +import ( + "bytes" + "fmt" + "io" + "net/url" + "testing" + "time" +) + +var localAuthenticatorTests = []struct { + clientUser *url.Userinfo + serverUsers []*url.Userinfo + valid bool +}{ + {nil, nil, true}, + {nil, []*url.Userinfo{url.User("admin")}, false}, + {nil, []*url.Userinfo{url.UserPassword("", "123456")}, false}, + {nil, []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, + + {url.User("admin"), nil, true}, + {url.User("admin"), []*url.Userinfo{url.User("admin")}, true}, + {url.User("admin"), []*url.Userinfo{url.User("test")}, false}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("test", "123456")}, false}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("", "123456")}, false}, + + {url.UserPassword("", ""), nil, true}, + {url.UserPassword("", "123456"), nil, true}, + {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true}, + {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, false}, + {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, + + {url.UserPassword("admin", "123456"), nil, true}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("test")}, false}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, true}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, false}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123")}, false}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("test", "123456")}, false}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true}, + + {url.UserPassword("admin", "123456"), []*url.Userinfo{ + url.UserPassword("test", "123"), + url.UserPassword("admin", "123456"), + }, true}, +} + +func TestLocalAuthenticator(t *testing.T) { + for i, tc := range localAuthenticatorTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + au := NewLocalAuthenticator(nil) + for _, u := range tc.serverUsers { + if u != nil { + p, _ := u.Password() + au.Add(u.Username(), p) + } + } + + var u, p string + if tc.clientUser != nil { + u = tc.clientUser.Username() + p, _ = tc.clientUser.Password() + } + if au.Authenticate(u, p) != tc.valid { + t.Error("authenticate result should be", tc.valid) + } + }) + } +} + +var localAuthenticatorReloadTests = []struct { + r io.Reader + period time.Duration + kvs map[string]string + stopped bool +}{ + { + r: nil, + period: 0, + kvs: nil, + }, + { + r: bytes.NewBufferString(""), + period: 0, + }, + { + r: bytes.NewBufferString("reload 10s"), + period: 10 * time.Second, + }, + { + r: bytes.NewBufferString("# reload 10s\n"), + }, + { + r: bytes.NewBufferString("reload 10s\n#admin"), + period: 10 * time.Second, + }, + { + r: bytes.NewBufferString("reload 10s\nadmin"), + period: 10 * time.Second, + kvs: map[string]string{ + "admin": "", + }, + }, + { + r: bytes.NewBufferString("# reload 10s\nadmin"), + kvs: map[string]string{ + "admin": "", + }, + }, + { + r: bytes.NewBufferString("# reload 10s\nadmin #123456"), + kvs: map[string]string{ + "admin": "#123456", + }, + stopped: true, + }, + { + r: bytes.NewBufferString("admin \t #123456\n\n\ntest \t 123456"), + kvs: map[string]string{ + "admin": "#123456", + "test": "123456", + }, + stopped: true, + }, + { + r: bytes.NewBufferString(` + $test.admin$ $123456$ + @test.admin@ @123456@ + test.admin# #123456# + test.admin\admin 123456 + `), + kvs: map[string]string{ + "$test.admin$": "$123456$", + "@test.admin@": "@123456@", + "test.admin#": "#123456#", + "test.admin\\admin": "123456", + }, + stopped: true, + }, +} + +func TestLocalAuthenticatorReload(t *testing.T) { + isEquals := func(a, b map[string]string) bool { + if len(a) == 0 && len(b) == 0 { + return true + } + if len(a) != len(b) { + return false + } + + for k, v := range a { + if b[k] != v { + return false + } + } + return true + } + for i, tc := range localAuthenticatorReloadTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + au := NewLocalAuthenticator(nil) + + if err := au.Reload(tc.r); err != nil { + t.Error(err) + } + if au.Period() != tc.period { + t.Errorf("#%d test failed: period value should be %v, got %v", + i, tc.period, au.Period()) + } + if !isEquals(au.kvs, tc.kvs) { + t.Errorf("#%d test failed: %v, %s", i, au.kvs, tc.kvs) + } + + if tc.stopped { + au.Stop() + if au.Period() >= 0 { + t.Errorf("period of the stopped reloader should be minus value") + } + au.Stop() + } + if au.Stopped() != tc.stopped { + t.Errorf("#%d test failed: stopped value should be %v, got %v", + i, tc.stopped, au.Stopped()) + } + }) + } +} diff --git a/bypass.go b/bypass.go index 6e8bed9..28ca8c8 100644 --- a/bypass.go +++ b/bypass.go @@ -223,44 +223,22 @@ func (bp *Bypass) Reload(r io.Reader) error { scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() - if n := strings.IndexByte(line, '#'); n >= 0 { - line = line[:n] - } - line = strings.Replace(line, "\t", " ", -1) - line = strings.TrimSpace(line) - if line == "" { + ss := splitLine(line) + if len(ss) == 0 { continue } - - // reload option - if strings.HasPrefix(line, "reload ") { - var ss []string - for _, s := range strings.Split(line, " ") { - if s = strings.TrimSpace(s); s != "" { - ss = append(ss, s) - } - } - if len(ss) == 2 { + switch ss[0] { + case "reload": // reload option + if len(ss) > 1 { period, _ = time.ParseDuration(ss[1]) - continue } - } - - // reverse option - if strings.HasPrefix(line, "reverse ") { - var ss []string - for _, s := range strings.Split(line, " ") { - if s = strings.TrimSpace(s); s != "" { - ss = append(ss, s) - } - } - if len(ss) == 2 { + case "reverse": // reverse option + if len(ss) > 1 { reversed, _ = strconv.ParseBool(ss[1]) - continue } + default: + matchers = append(matchers, NewMatcher(ss[0])) } - - matchers = append(matchers, NewMatcher(line)) } if err := scanner.Err(); err != nil { diff --git a/bypass_test.go b/bypass_test.go index 52b1305..d895121 100644 --- a/bypass_test.go +++ b/bypass_test.go @@ -220,7 +220,7 @@ var bypassReloadTests = []struct { stopped: true, }, { - r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.1.0/24"), + r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.1.0/24 #comment"), reversed: false, period: 0, addr: "192.168.10.2", @@ -244,7 +244,7 @@ var bypassReloadTests = []struct { stopped: true, }, { - r: bytes.NewBufferString("#reverse true\n#reload 10s\nexample.com"), + r: bytes.NewBufferString("#reverse true\n#reload 10s\nexample.com #comment"), reversed: false, period: 0, addr: "example.com", diff --git a/cmd/gost/.config/probe_resist.txt b/cmd/gost/.config/probe_resist.txt deleted file mode 100644 index c57eff5..0000000 --- a/cmd/gost/.config/probe_resist.txt +++ /dev/null @@ -1 +0,0 @@ -Hello World! \ No newline at end of file diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index 1c7d98e..aca4b2e 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -118,6 +118,24 @@ func parseUsers(authFile string) (users []*url.Userinfo, err error) { return } +func parseAuthenticator(s string) (gost.Authenticator, error) { + if s == "" { + return nil, nil + } + f, err := os.Open(s) + if err != nil { + return nil, err + } + defer f.Close() + + au := gost.NewLocalAuthenticator(nil) + au.Reload(f) + + go gost.PeriodReload(au, s) + + return au, nil +} + func parseIP(s string, port string) (ips []string) { if s == "" { return diff --git a/cmd/gost/route.go b/cmd/gost/route.go index c9cdb6b..adf8859 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -257,12 +257,14 @@ func (r *route) GenRouters() ([]router, error) { if err != nil { return nil, err } - users, err := parseUsers(node.Get("secrets")) + authenticator, err := parseAuthenticator(node.Get("secrets")) if err != nil { return nil, err } - if node.User != nil { - users = append(users, node.User) + if authenticator == nil && node.User != nil { + kvs := make(map[string]string) + kvs[node.User.Username()], _ = node.User.Password() + authenticator = gost.NewLocalAuthenticator(kvs) } certFile, keyFile := node.Get("cert"), node.Get("key") tlsCfg, err := tlsConfig(certFile, keyFile) @@ -298,8 +300,8 @@ func (r *route) GenRouters() ([]router, error) { ln, err = gost.KCPListener(node.Addr, config) case "ssh": config := &gost.SSHConfig{ - Users: users, - TLSConfig: tlsCfg, + Authenticator: authenticator, + TLSConfig: tlsCfg, } if node.Protocol == "forward" { ln, err = gost.TCPListener(node.Addr) @@ -416,7 +418,7 @@ func (r *route) GenRouters() ([]router, error) { // gost.AddrHandlerOption(node.Addr), gost.AddrHandlerOption(ln.Addr().String()), gost.ChainHandlerOption(chain), - gost.UsersHandlerOption(users...), + gost.AuthenticatorHandlerOption(authenticator), gost.TLSConfigHandlerOption(tlsCfg), gost.WhitelistHandlerOption(whitelist), gost.BlacklistHandlerOption(blacklist), diff --git a/gost.go b/gost.go index 499acc6..faa630e 100644 --- a/gost.go +++ b/gost.go @@ -11,6 +11,7 @@ import ( "io" "math/big" "net" + "strings" "sync" "time" @@ -18,7 +19,7 @@ import ( ) // Version is the gost version. -const Version = "2.7" +const Version = "2.7.1" // Debug is a flag that enables the debug log. var Debug bool @@ -180,7 +181,22 @@ func (c *nopConn) SetWriteDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -// Accepter represents a network endpoint that can accept connection from peer. -type Accepter interface { - Accept() (net.Conn, error) +// splitLine splits a line text by white space, mainly used by config parser. +func splitLine(line string) []string { + if line == "" { + return nil + } + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + line = strings.Replace(line, "\t", " ", -1) + line = strings.TrimSpace(line) + + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + return ss } diff --git a/handler.go b/handler.go index 14b534e..603d110 100644 --- a/handler.go +++ b/handler.go @@ -20,21 +20,22 @@ type Handler interface { // HandlerOptions describes the options for Handler. type HandlerOptions struct { - Addr string - Chain *Chain - Users []*url.Userinfo - TLSConfig *tls.Config - Whitelist *Permissions - Blacklist *Permissions - Strategy Strategy - Bypass *Bypass - Retries int - Timeout time.Duration - Resolver Resolver - Hosts *Hosts - ProbeResist string - Node Node - Host string + Addr string + Chain *Chain + Users []*url.Userinfo + Authenticator Authenticator + TLSConfig *tls.Config + Whitelist *Permissions + Blacklist *Permissions + Strategy Strategy + Bypass *Bypass + Retries int + Timeout time.Duration + Resolver Resolver + Hosts *Hosts + ProbeResist string + Node Node + Host string } // HandlerOption allows a common way to set handler options. @@ -58,6 +59,23 @@ func ChainHandlerOption(chain *Chain) HandlerOption { func UsersHandlerOption(users ...*url.Userinfo) HandlerOption { return func(opts *HandlerOptions) { opts.Users = users + + kvs := make(map[string]string) + for _, u := range users { + if u != nil { + kvs[u.Username()], _ = u.Password() + } + } + if len(kvs) > 0 { + opts.Authenticator = NewLocalAuthenticator(kvs) + } + } +} + +// AuthenticatorHandlerOption sets the Authenticator option of HandlerOptions. +func AuthenticatorHandlerOption(au Authenticator) HandlerOption { + return func(opts *HandlerOptions) { + opts.Authenticator = au } } diff --git a/hosts.go b/hosts.go index aa0ef99..6df0325 100644 --- a/hosts.go +++ b/hosts.go @@ -4,7 +4,6 @@ import ( "bufio" "io" "net" - "strings" "sync" "time" @@ -94,42 +93,28 @@ func (h *Hosts) Reload(r io.Reader) error { scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() - if n := strings.IndexByte(line, '#'); n >= 0 { - line = line[:n] - } - line = strings.Replace(line, "\t", " ", -1) - line = strings.TrimSpace(line) - if line == "" { - continue - } - var ss []string - for _, s := range strings.Split(line, " ") { - if s = strings.TrimSpace(s); s != "" { - ss = append(ss, s) - } - } + ss := splitLine(line) if len(ss) < 2 { continue // invalid lines are ignored } - // reload option - if strings.ToLower(ss[0]) == "reload" { + switch ss[0] { + case "reload": // reload option period, _ = time.ParseDuration(ss[1]) - continue + default: + ip := net.ParseIP(ss[0]) + if ip == nil { + break // invalid IP addresses are ignored + } + host := Host{ + IP: ip, + Hostname: ss[1], + } + if len(ss) > 2 { + host.Aliases = ss[2:] + } + hosts = append(hosts, host) } - - ip := net.ParseIP(ss[0]) - if ip == nil { - continue // invalid IP addresses are ignored - } - host := Host{ - IP: ip, - Hostname: ss[1], - } - if len(ss) > 2 { - host.Aliases = ss[2:] - } - hosts = append(hosts, host) } if err := scanner.Err(); err != nil { return err diff --git a/http.go b/http.go index d815ebe..01ec157 100644 --- a/http.go +++ b/http.go @@ -299,7 +299,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http. log.Logf("[http] %s -> %s : Authorization '%s' '%s'", conn.RemoteAddr(), conn.LocalAddr(), u, p) } - if authenticate(u, p, h.options.Users...) { + if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) { return true } @@ -423,20 +423,3 @@ func basicProxyAuth(proxyAuth string) (username, password string, ok bool) { return cs[:s], cs[s+1:], true } - -func authenticate(username, password string, users ...*url.Userinfo) bool { - if len(users) == 0 { - return true - } - - for _, user := range users { - u := user.Username() - p, _ := user.Password() - if (u == username && p == password) || - (u == username && p == "") || - (u == "" && p == password) { - return true - } - } - return false -} diff --git a/http2.go b/http2.go index 6902a62..3368cdd 100644 --- a/http2.go +++ b/http2.go @@ -457,9 +457,10 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp if Debug && (u != "" || p != "") { log.Logf("[http2] %s - %s : Authorization '%s' '%s'", r.RemoteAddr, laddr, u, p) } - if authenticate(u, p, h.options.Users...) { + if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) { return true } + // probing resistance is enabled if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 { switch ss[0] { diff --git a/http2_test.go b/http2_test.go index dcd6230..a1df2e3 100644 --- a/http2_test.go +++ b/http2_test.go @@ -1088,7 +1088,7 @@ func TestHTTP2ProxyWithFileProbeResist(t *testing.T) { Listener: ln, Handler: HTTP2Handler( UsersHandlerOption(url.UserPassword("admin", "123456")), - ProbeResistHandlerOption("file:.testdata/probe_resist.txt"), + ProbeResistHandlerOption("file:.config/probe_resist.txt"), ), } go server.Run() diff --git a/http_test.go b/http_test.go index 901c44e..a4f1041 100644 --- a/http_test.go +++ b/http_test.go @@ -26,7 +26,7 @@ var httpProxyTests = []struct { {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, ""}, {url.UserPassword("admin", "123456"), nil, ""}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, ""}, - {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, "407 Proxy Authentication Required"}, {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""}, @@ -312,7 +312,7 @@ func TestHTTPProxyWithFileProbeResist(t *testing.T) { Listener: ln, Handler: HTTPHandler( UsersHandlerOption(url.UserPassword("admin", "123456")), - ProbeResistHandlerOption("file:.testdata/probe_resist.txt"), + ProbeResistHandlerOption("file:.config/probe_resist.txt"), ), } go server.Run() diff --git a/reload.go b/reload.go index e6d5648..08d708a 100644 --- a/reload.go +++ b/reload.go @@ -17,26 +17,7 @@ type Reloader interface { // Stoppable is the interface that indicates a Reloader can be stopped. type Stoppable interface { Stop() -} - -//StopReloader is the interface that adds Stop method to the Reloader. -type StopReloader interface { - Reloader - Stoppable -} - -type nopStoppable struct { - Reloader -} - -func (nopStoppable) Stop() { - return -} - -// NopStoppable returns a StopReloader with a no-op Stop method, -// wrapping the provided Reloader r. -func NopStoppable(r Reloader) StopReloader { - return nopStoppable{r} + Stopped() bool } // PeriodReload reloads the config configFile periodically according to the period of the Reloader r. diff --git a/resolver.go b/resolver.go index b0f8659..ff2b933 100644 --- a/resolver.go +++ b/resolver.go @@ -278,29 +278,10 @@ func (r *resolver) Reload(rd io.Reader) error { return nil } - split := func(line string) []string { - if line == "" { - return nil - } - if n := strings.IndexByte(line, '#'); n >= 0 { - line = line[:n] - } - line = strings.Replace(line, "\t", " ", -1) - line = strings.TrimSpace(line) - - var ss []string - for _, s := range strings.Split(line, " ") { - if s = strings.TrimSpace(s); s != "" { - ss = append(ss, s) - } - } - return ss - } - scanner := bufio.NewScanner(rd) for scanner.Scan() { line := scanner.Text() - ss := split(line) + ss := splitLine(line) if len(ss) == 0 { continue } diff --git a/server.go b/server.go index 5af2e08..86edf3e 100644 --- a/server.go +++ b/server.go @@ -8,6 +8,11 @@ import ( "github.com/go-log/log" ) +// Accepter represents a network endpoint that can accept connection from peer. +type Accepter interface { + Accept() (net.Conn, error) +} + // Server is a proxy server. type Server struct { Listener Listener diff --git a/snapcraft.yaml b/snapcraft.yaml index d970183..125e932 100644 --- a/snapcraft.yaml +++ b/snapcraft.yaml @@ -14,7 +14,7 @@ apps: parts: go: - source-tag: go1.11 + source-tag: go1.10 gost: after: [go] source: . diff --git a/socks.go b/socks.go index c2cee6d..429194c 100644 --- a/socks.go +++ b/socks.go @@ -96,9 +96,10 @@ func (selector *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Con } type serverSelector struct { - methods []uint8 - Users []*url.Userinfo - TLSConfig *tls.Config + methods []uint8 + // Users []*url.Userinfo + Authenticator Authenticator + TLSConfig *tls.Config } func (selector *serverSelector) Methods() []uint8 { @@ -121,8 +122,8 @@ func (selector *serverSelector) Select(methods ...uint8) (method uint8) { } } - // when user/pass is set, auth is mandatory - if len(selector.Users) > 0 { + // when Authenticator is set, auth is mandatory + if selector.Authenticator != nil { if method == gosocks5.MethodNoAuth { method = gosocks5.MethodUserPass } @@ -155,18 +156,8 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con if Debug { log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String()) } - valid := false - for _, user := range selector.Users { - username := user.Username() - password, _ := user.Password() - if (req.Username == username && req.Password == password) || - (req.Username == username && password == "") || - (username == "" && req.Password == password) { - valid = true - break - } - } - if len(selector.Users) > 0 && !valid { + + if selector.Authenticator != nil && !selector.Authenticator.Authenticate(req.Username, req.Password) { resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) if err := resp.Write(conn); err != nil { log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) @@ -788,8 +779,9 @@ func (h *socks5Handler) Init(options ...HandlerOption) { tlsConfig = DefaultTLSConfig } h.selector = &serverSelector{ // socks5 server selector - Users: h.options.Users, - TLSConfig: tlsConfig, + // Users: h.options.Users, + Authenticator: h.options.Authenticator, + TLSConfig: tlsConfig, } // methods that socks5 server supported h.selector.AddMethod( diff --git a/socks_test.go b/socks_test.go index b2a94fa..c88d94f 100644 --- a/socks_test.go +++ b/socks_test.go @@ -25,7 +25,7 @@ var socks5ProxyTests = []struct { {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true}, {url.UserPassword("admin", "123456"), nil, true}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true}, - {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, false}, {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, true}, diff --git a/ssh.go b/ssh.go index 9195394..d434c31 100644 --- a/ssh.go +++ b/ssh.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "net" - "net/url" "strconv" "strings" "sync" @@ -466,8 +465,8 @@ func (h *sshForwardHandler) Init(options ...HandlerOption) { } h.config = &ssh.ServerConfig{} - h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Users...) - if len(h.options.Users) == 0 { + h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Authenticator) + if h.options.Authenticator == nil { h.config.NoClientAuth = true } tlsConfig := h.options.TLSConfig @@ -665,8 +664,8 @@ func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Reque // SSHConfig holds the SSH tunnel server config type SSHConfig struct { - Users []*url.Userinfo - TLSConfig *tls.Config + Authenticator Authenticator + TLSConfig *tls.Config } type sshTunnelListener struct { @@ -688,8 +687,8 @@ func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) { } sshConfig := &ssh.ServerConfig{} - sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Users...) - if len(config.Users) == 0 { + sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Authenticator) + if config.Authenticator == nil { sshConfig.NoClientAuth = true } tlsConfig := config.TLSConfig @@ -808,14 +807,10 @@ func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { // PasswordCallbackFunc is a callback function used by SSH server. type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) -func defaultSSHPasswordCallback(users ...*url.Userinfo) PasswordCallbackFunc { +func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { - for _, user := range users { - u := user.Username() - p, _ := user.Password() - if u == conn.User() && p == string(password) { - return nil, nil - } + if au.Authenticate(conn.User(), string(password)) { + return nil, nil } log.Logf("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User()) return nil, fmt.Errorf("password rejected for %s", conn.User())