add reloader for authenticator

This commit is contained in:
ginuerzh 2019-01-09 22:36:44 +08:00
parent 1930da5210
commit 62663564cc
31 changed files with 492 additions and 187 deletions

View File

@ -1,3 +1,6 @@
# period for live reloading
reload 3s
# username password # username password
$test.admin$ $123456$ $test.admin$ $123456$

View File

@ -19,6 +19,7 @@ gost - GO Simple Tunnel
* 多端口监听 * 多端口监听
* 可设置转发代理,支持多级转发(代理链) * 可设置转发代理,支持多级转发(代理链)
* 支持标准HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5代理协议 * 支持标准HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5代理协议
* Web代理支持[探测防御](https://docs.ginuerzh.xyz/gost/probe_resist/)
* [支持多种隧道类型](https://docs.ginuerzh.xyz/gost/configuration/) * [支持多种隧道类型](https://docs.ginuerzh.xyz/gost/configuration/)
* [SOCKS5代理支持TLS协商加密](https://docs.ginuerzh.xyz/gost/socks/) * [SOCKS5代理支持TLS协商加密](https://docs.ginuerzh.xyz/gost/socks/)
* [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/socks/) * [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/socks/)

View File

@ -16,6 +16,7 @@ Features
* Listening on multiple ports * Listening on multiple ports
* Multi-level forward proxy - proxy chain * Multi-level forward proxy - proxy chain
* Standard HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5 proxy protocols support * Standard HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5 proxy protocols support
* [Probing resistance](https://docs.ginuerzh.xyz/gost/en/probe_resist/) support for web proxy
* [Support multiple tunnel types](https://docs.ginuerzh.xyz/gost/en/configuration/) * [Support multiple tunnel types](https://docs.ginuerzh.xyz/gost/en/configuration/)
* [TLS encryption via negotiation support for SOCKS5 proxy](https://docs.ginuerzh.xyz/gost/en/socks/) * [TLS encryption via negotiation support for SOCKS5 proxy](https://docs.ginuerzh.xyz/gost/en/socks/)
* [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/en/socks/) * [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/en/socks/)

155
auth.go Normal file
View File

@ -0,0 +1,155 @@
package gost
import (
"bufio"
"io"
"strings"
"sync"
"time"
)
// Authenticator is an interface for user authentication.
type Authenticator interface {
Authenticate(user, password string) bool
}
// LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs.
type LocalAuthenticator struct {
kvs map[string]string
period time.Duration
stopped chan struct{}
mux sync.RWMutex
}
// NewLocalAuthenticator creates an Authenticator that authenticates client by local infos.
func NewLocalAuthenticator(kvs map[string]string) *LocalAuthenticator {
return &LocalAuthenticator{
kvs: kvs,
stopped: make(chan struct{}),
}
}
// Authenticate checks the validity of the provided user-password pair.
func (au *LocalAuthenticator) Authenticate(user, password string) bool {
if au == nil {
return true
}
au.mux.RLock()
defer au.mux.RUnlock()
if len(au.kvs) == 0 {
return true
}
v, ok := au.kvs[user]
return ok && (v == "" || password == v)
}
// Add adds a key-value pair to the Authenticator.
func (au *LocalAuthenticator) Add(k, v string) {
au.mux.Lock()
defer au.mux.Unlock()
if au.kvs == nil {
au.kvs = make(map[string]string)
}
au.kvs[k] = v
}
// Reload parses config from r, then live reloads the bypass.
func (au *LocalAuthenticator) Reload(r io.Reader) error {
var period time.Duration
kvs := make(map[string]string)
if r == nil || au.Stopped() {
return nil
}
// splitLine splits a line text by white space.
// A line started with '#' will be ignored, otherwise it is valid.
split := func(line string) []string {
if line == "" {
return nil
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
if strings.IndexByte(line, '#') == 0 {
return nil
}
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
return ss
}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
ss := split(line)
if len(ss) == 0 {
continue
}
switch ss[0] {
case "reload": // reload option
if len(ss) > 1 {
period, _ = time.ParseDuration(ss[1])
}
default:
var k, v string
k = ss[0]
if len(ss) > 1 {
v = ss[1]
}
kvs[k] = v
}
}
if err := scanner.Err(); err != nil {
return err
}
au.mux.Lock()
defer au.mux.Unlock()
au.period = period
au.kvs = kvs
return nil
}
// Period returns the reload period.
func (au *LocalAuthenticator) Period() time.Duration {
if au.Stopped() {
return -1
}
au.mux.RLock()
defer au.mux.RUnlock()
return au.period
}
// Stop stops reloading.
func (au *LocalAuthenticator) Stop() {
select {
case <-au.stopped:
default:
close(au.stopped)
}
}
// Stopped checks whether the reloader is stopped.
func (au *LocalAuthenticator) Stopped() bool {
select {
case <-au.stopped:
return true
default:
return false
}
}

191
auth_test.go Normal file
View File

@ -0,0 +1,191 @@
package gost
import (
"bytes"
"fmt"
"io"
"net/url"
"testing"
"time"
)
var localAuthenticatorTests = []struct {
clientUser *url.Userinfo
serverUsers []*url.Userinfo
valid bool
}{
{nil, nil, true},
{nil, []*url.Userinfo{url.User("admin")}, false},
{nil, []*url.Userinfo{url.UserPassword("", "123456")}, false},
{nil, []*url.Userinfo{url.UserPassword("admin", "123456")}, false},
{url.User("admin"), nil, true},
{url.User("admin"), []*url.Userinfo{url.User("admin")}, true},
{url.User("admin"), []*url.Userinfo{url.User("test")}, false},
{url.User("admin"), []*url.Userinfo{url.UserPassword("test", "123456")}, false},
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false},
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true},
{url.User("admin"), []*url.Userinfo{url.UserPassword("", "123456")}, false},
{url.UserPassword("", ""), nil, true},
{url.UserPassword("", "123456"), nil, true},
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true},
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, false},
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false},
{url.UserPassword("admin", "123456"), nil, true},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("test")}, false},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, true},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, false},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123")}, false},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("test", "123456")}, false},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true},
{url.UserPassword("admin", "123456"), []*url.Userinfo{
url.UserPassword("test", "123"),
url.UserPassword("admin", "123456"),
}, true},
}
func TestLocalAuthenticator(t *testing.T) {
for i, tc := range localAuthenticatorTests {
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
au := NewLocalAuthenticator(nil)
for _, u := range tc.serverUsers {
if u != nil {
p, _ := u.Password()
au.Add(u.Username(), p)
}
}
var u, p string
if tc.clientUser != nil {
u = tc.clientUser.Username()
p, _ = tc.clientUser.Password()
}
if au.Authenticate(u, p) != tc.valid {
t.Error("authenticate result should be", tc.valid)
}
})
}
}
var localAuthenticatorReloadTests = []struct {
r io.Reader
period time.Duration
kvs map[string]string
stopped bool
}{
{
r: nil,
period: 0,
kvs: nil,
},
{
r: bytes.NewBufferString(""),
period: 0,
},
{
r: bytes.NewBufferString("reload 10s"),
period: 10 * time.Second,
},
{
r: bytes.NewBufferString("# reload 10s\n"),
},
{
r: bytes.NewBufferString("reload 10s\n#admin"),
period: 10 * time.Second,
},
{
r: bytes.NewBufferString("reload 10s\nadmin"),
period: 10 * time.Second,
kvs: map[string]string{
"admin": "",
},
},
{
r: bytes.NewBufferString("# reload 10s\nadmin"),
kvs: map[string]string{
"admin": "",
},
},
{
r: bytes.NewBufferString("# reload 10s\nadmin #123456"),
kvs: map[string]string{
"admin": "#123456",
},
stopped: true,
},
{
r: bytes.NewBufferString("admin \t #123456\n\n\ntest \t 123456"),
kvs: map[string]string{
"admin": "#123456",
"test": "123456",
},
stopped: true,
},
{
r: bytes.NewBufferString(`
$test.admin$ $123456$
@test.admin@ @123456@
test.admin# #123456#
test.admin\admin 123456
`),
kvs: map[string]string{
"$test.admin$": "$123456$",
"@test.admin@": "@123456@",
"test.admin#": "#123456#",
"test.admin\\admin": "123456",
},
stopped: true,
},
}
func TestLocalAuthenticatorReload(t *testing.T) {
isEquals := func(a, b map[string]string) bool {
if len(a) == 0 && len(b) == 0 {
return true
}
if len(a) != len(b) {
return false
}
for k, v := range a {
if b[k] != v {
return false
}
}
return true
}
for i, tc := range localAuthenticatorReloadTests {
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
au := NewLocalAuthenticator(nil)
if err := au.Reload(tc.r); err != nil {
t.Error(err)
}
if au.Period() != tc.period {
t.Errorf("#%d test failed: period value should be %v, got %v",
i, tc.period, au.Period())
}
if !isEquals(au.kvs, tc.kvs) {
t.Errorf("#%d test failed: %v, %s", i, au.kvs, tc.kvs)
}
if tc.stopped {
au.Stop()
if au.Period() >= 0 {
t.Errorf("period of the stopped reloader should be minus value")
}
au.Stop()
}
if au.Stopped() != tc.stopped {
t.Errorf("#%d test failed: stopped value should be %v, got %v",
i, tc.stopped, au.Stopped())
}
})
}
}

View File

@ -223,44 +223,22 @@ func (bp *Bypass) Reload(r io.Reader) error {
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 { ss := splitLine(line)
line = line[:n] if len(ss) == 0 {
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
if line == "" {
continue continue
} }
switch ss[0] {
// reload option case "reload": // reload option
if strings.HasPrefix(line, "reload ") { if len(ss) > 1 {
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
if len(ss) == 2 {
period, _ = time.ParseDuration(ss[1]) period, _ = time.ParseDuration(ss[1])
continue
} }
} case "reverse": // reverse option
if len(ss) > 1 {
// reverse option
if strings.HasPrefix(line, "reverse ") {
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
if len(ss) == 2 {
reversed, _ = strconv.ParseBool(ss[1]) reversed, _ = strconv.ParseBool(ss[1])
continue
} }
default:
matchers = append(matchers, NewMatcher(ss[0]))
} }
matchers = append(matchers, NewMatcher(line))
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {

View File

@ -220,7 +220,7 @@ var bypassReloadTests = []struct {
stopped: true, stopped: true,
}, },
{ {
r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.1.0/24"), r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.1.0/24 #comment"),
reversed: false, reversed: false,
period: 0, period: 0,
addr: "192.168.10.2", addr: "192.168.10.2",
@ -244,7 +244,7 @@ var bypassReloadTests = []struct {
stopped: true, stopped: true,
}, },
{ {
r: bytes.NewBufferString("#reverse true\n#reload 10s\nexample.com"), r: bytes.NewBufferString("#reverse true\n#reload 10s\nexample.com #comment"),
reversed: false, reversed: false,
period: 0, period: 0,
addr: "example.com", addr: "example.com",

View File

@ -1 +0,0 @@
Hello World!

View File

@ -118,6 +118,24 @@ func parseUsers(authFile string) (users []*url.Userinfo, err error) {
return return
} }
func parseAuthenticator(s string) (gost.Authenticator, error) {
if s == "" {
return nil, nil
}
f, err := os.Open(s)
if err != nil {
return nil, err
}
defer f.Close()
au := gost.NewLocalAuthenticator(nil)
au.Reload(f)
go gost.PeriodReload(au, s)
return au, nil
}
func parseIP(s string, port string) (ips []string) { func parseIP(s string, port string) (ips []string) {
if s == "" { if s == "" {
return return

View File

@ -257,12 +257,14 @@ func (r *route) GenRouters() ([]router, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
users, err := parseUsers(node.Get("secrets")) authenticator, err := parseAuthenticator(node.Get("secrets"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if node.User != nil { if authenticator == nil && node.User != nil {
users = append(users, node.User) kvs := make(map[string]string)
kvs[node.User.Username()], _ = node.User.Password()
authenticator = gost.NewLocalAuthenticator(kvs)
} }
certFile, keyFile := node.Get("cert"), node.Get("key") certFile, keyFile := node.Get("cert"), node.Get("key")
tlsCfg, err := tlsConfig(certFile, keyFile) tlsCfg, err := tlsConfig(certFile, keyFile)
@ -298,8 +300,8 @@ func (r *route) GenRouters() ([]router, error) {
ln, err = gost.KCPListener(node.Addr, config) ln, err = gost.KCPListener(node.Addr, config)
case "ssh": case "ssh":
config := &gost.SSHConfig{ config := &gost.SSHConfig{
Users: users, Authenticator: authenticator,
TLSConfig: tlsCfg, TLSConfig: tlsCfg,
} }
if node.Protocol == "forward" { if node.Protocol == "forward" {
ln, err = gost.TCPListener(node.Addr) ln, err = gost.TCPListener(node.Addr)
@ -416,7 +418,7 @@ func (r *route) GenRouters() ([]router, error) {
// gost.AddrHandlerOption(node.Addr), // gost.AddrHandlerOption(node.Addr),
gost.AddrHandlerOption(ln.Addr().String()), gost.AddrHandlerOption(ln.Addr().String()),
gost.ChainHandlerOption(chain), gost.ChainHandlerOption(chain),
gost.UsersHandlerOption(users...), gost.AuthenticatorHandlerOption(authenticator),
gost.TLSConfigHandlerOption(tlsCfg), gost.TLSConfigHandlerOption(tlsCfg),
gost.WhitelistHandlerOption(whitelist), gost.WhitelistHandlerOption(whitelist),
gost.BlacklistHandlerOption(blacklist), gost.BlacklistHandlerOption(blacklist),

24
gost.go
View File

@ -11,6 +11,7 @@ import (
"io" "io"
"math/big" "math/big"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
@ -18,7 +19,7 @@ import (
) )
// Version is the gost version. // Version is the gost version.
const Version = "2.7" const Version = "2.7.1"
// Debug is a flag that enables the debug log. // Debug is a flag that enables the debug log.
var Debug bool var Debug bool
@ -180,7 +181,22 @@ func (c *nopConn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
} }
// Accepter represents a network endpoint that can accept connection from peer. // splitLine splits a line text by white space, mainly used by config parser.
type Accepter interface { func splitLine(line string) []string {
Accept() (net.Conn, error) if line == "" {
return nil
}
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
return ss
} }

View File

@ -20,21 +20,22 @@ type Handler interface {
// HandlerOptions describes the options for Handler. // HandlerOptions describes the options for Handler.
type HandlerOptions struct { type HandlerOptions struct {
Addr string Addr string
Chain *Chain Chain *Chain
Users []*url.Userinfo Users []*url.Userinfo
TLSConfig *tls.Config Authenticator Authenticator
Whitelist *Permissions TLSConfig *tls.Config
Blacklist *Permissions Whitelist *Permissions
Strategy Strategy Blacklist *Permissions
Bypass *Bypass Strategy Strategy
Retries int Bypass *Bypass
Timeout time.Duration Retries int
Resolver Resolver Timeout time.Duration
Hosts *Hosts Resolver Resolver
ProbeResist string Hosts *Hosts
Node Node ProbeResist string
Host string Node Node
Host string
} }
// HandlerOption allows a common way to set handler options. // HandlerOption allows a common way to set handler options.
@ -58,6 +59,23 @@ func ChainHandlerOption(chain *Chain) HandlerOption {
func UsersHandlerOption(users ...*url.Userinfo) HandlerOption { func UsersHandlerOption(users ...*url.Userinfo) HandlerOption {
return func(opts *HandlerOptions) { return func(opts *HandlerOptions) {
opts.Users = users opts.Users = users
kvs := make(map[string]string)
for _, u := range users {
if u != nil {
kvs[u.Username()], _ = u.Password()
}
}
if len(kvs) > 0 {
opts.Authenticator = NewLocalAuthenticator(kvs)
}
}
}
// AuthenticatorHandlerOption sets the Authenticator option of HandlerOptions.
func AuthenticatorHandlerOption(au Authenticator) HandlerOption {
return func(opts *HandlerOptions) {
opts.Authenticator = au
} }
} }

View File

@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"io" "io"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
@ -94,42 +93,28 @@ func (h *Hosts) Reload(r io.Reader) error {
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 { ss := splitLine(line)
line = line[:n]
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
if line == "" {
continue
}
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
if len(ss) < 2 { if len(ss) < 2 {
continue // invalid lines are ignored continue // invalid lines are ignored
} }
// reload option switch ss[0] {
if strings.ToLower(ss[0]) == "reload" { case "reload": // reload option
period, _ = time.ParseDuration(ss[1]) period, _ = time.ParseDuration(ss[1])
continue default:
ip := net.ParseIP(ss[0])
if ip == nil {
break // invalid IP addresses are ignored
}
host := Host{
IP: ip,
Hostname: ss[1],
}
if len(ss) > 2 {
host.Aliases = ss[2:]
}
hosts = append(hosts, host)
} }
ip := net.ParseIP(ss[0])
if ip == nil {
continue // invalid IP addresses are ignored
}
host := Host{
IP: ip,
Hostname: ss[1],
}
if len(ss) > 2 {
host.Aliases = ss[2:]
}
hosts = append(hosts, host)
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
return err return err

19
http.go
View File

@ -299,7 +299,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
log.Logf("[http] %s -> %s : Authorization '%s' '%s'", log.Logf("[http] %s -> %s : Authorization '%s' '%s'",
conn.RemoteAddr(), conn.LocalAddr(), u, p) conn.RemoteAddr(), conn.LocalAddr(), u, p)
} }
if authenticate(u, p, h.options.Users...) { if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) {
return true return true
} }
@ -423,20 +423,3 @@ func basicProxyAuth(proxyAuth string) (username, password string, ok bool) {
return cs[:s], cs[s+1:], true return cs[:s], cs[s+1:], true
} }
func authenticate(username, password string, users ...*url.Userinfo) bool {
if len(users) == 0 {
return true
}
for _, user := range users {
u := user.Username()
p, _ := user.Password()
if (u == username && p == password) ||
(u == username && p == "") ||
(u == "" && p == password) {
return true
}
}
return false
}

View File

@ -457,9 +457,10 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp
if Debug && (u != "" || p != "") { if Debug && (u != "" || p != "") {
log.Logf("[http2] %s - %s : Authorization '%s' '%s'", r.RemoteAddr, laddr, u, p) log.Logf("[http2] %s - %s : Authorization '%s' '%s'", r.RemoteAddr, laddr, u, p)
} }
if authenticate(u, p, h.options.Users...) { if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) {
return true return true
} }
// probing resistance is enabled // probing resistance is enabled
if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 { if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 {
switch ss[0] { switch ss[0] {

View File

@ -1088,7 +1088,7 @@ func TestHTTP2ProxyWithFileProbeResist(t *testing.T) {
Listener: ln, Listener: ln,
Handler: HTTP2Handler( Handler: HTTP2Handler(
UsersHandlerOption(url.UserPassword("admin", "123456")), UsersHandlerOption(url.UserPassword("admin", "123456")),
ProbeResistHandlerOption("file:.testdata/probe_resist.txt"), ProbeResistHandlerOption("file:.config/probe_resist.txt"),
), ),
} }
go server.Run() go server.Run()

View File

@ -26,7 +26,7 @@ var httpProxyTests = []struct {
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, ""}, {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, ""},
{url.UserPassword("admin", "123456"), nil, ""}, {url.UserPassword("admin", "123456"), nil, ""},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, ""}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, ""},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, "407 Proxy Authentication Required"},
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""}, {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""},
@ -312,7 +312,7 @@ func TestHTTPProxyWithFileProbeResist(t *testing.T) {
Listener: ln, Listener: ln,
Handler: HTTPHandler( Handler: HTTPHandler(
UsersHandlerOption(url.UserPassword("admin", "123456")), UsersHandlerOption(url.UserPassword("admin", "123456")),
ProbeResistHandlerOption("file:.testdata/probe_resist.txt"), ProbeResistHandlerOption("file:.config/probe_resist.txt"),
), ),
} }
go server.Run() go server.Run()

View File

@ -17,26 +17,7 @@ type Reloader interface {
// Stoppable is the interface that indicates a Reloader can be stopped. // Stoppable is the interface that indicates a Reloader can be stopped.
type Stoppable interface { type Stoppable interface {
Stop() Stop()
} Stopped() bool
//StopReloader is the interface that adds Stop method to the Reloader.
type StopReloader interface {
Reloader
Stoppable
}
type nopStoppable struct {
Reloader
}
func (nopStoppable) Stop() {
return
}
// NopStoppable returns a StopReloader with a no-op Stop method,
// wrapping the provided Reloader r.
func NopStoppable(r Reloader) StopReloader {
return nopStoppable{r}
} }
// PeriodReload reloads the config configFile periodically according to the period of the Reloader r. // PeriodReload reloads the config configFile periodically according to the period of the Reloader r.

View File

@ -278,29 +278,10 @@ func (r *resolver) Reload(rd io.Reader) error {
return nil return nil
} }
split := func(line string) []string {
if line == "" {
return nil
}
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
return ss
}
scanner := bufio.NewScanner(rd) scanner := bufio.NewScanner(rd)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
ss := split(line) ss := splitLine(line)
if len(ss) == 0 { if len(ss) == 0 {
continue continue
} }

View File

@ -8,6 +8,11 @@ import (
"github.com/go-log/log" "github.com/go-log/log"
) )
// Accepter represents a network endpoint that can accept connection from peer.
type Accepter interface {
Accept() (net.Conn, error)
}
// Server is a proxy server. // Server is a proxy server.
type Server struct { type Server struct {
Listener Listener Listener Listener

View File

@ -14,7 +14,7 @@ apps:
parts: parts:
go: go:
source-tag: go1.11 source-tag: go1.10
gost: gost:
after: [go] after: [go]
source: . source: .

View File

@ -96,9 +96,10 @@ func (selector *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Con
} }
type serverSelector struct { type serverSelector struct {
methods []uint8 methods []uint8
Users []*url.Userinfo // Users []*url.Userinfo
TLSConfig *tls.Config Authenticator Authenticator
TLSConfig *tls.Config
} }
func (selector *serverSelector) Methods() []uint8 { func (selector *serverSelector) Methods() []uint8 {
@ -121,8 +122,8 @@ func (selector *serverSelector) Select(methods ...uint8) (method uint8) {
} }
} }
// when user/pass is set, auth is mandatory // when Authenticator is set, auth is mandatory
if len(selector.Users) > 0 { if selector.Authenticator != nil {
if method == gosocks5.MethodNoAuth { if method == gosocks5.MethodNoAuth {
method = gosocks5.MethodUserPass method = gosocks5.MethodUserPass
} }
@ -155,18 +156,8 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con
if Debug { if Debug {
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String()) log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String())
} }
valid := false
for _, user := range selector.Users { if selector.Authenticator != nil && !selector.Authenticator.Authenticate(req.Username, req.Password) {
username := user.Username()
password, _ := user.Password()
if (req.Username == username && req.Password == password) ||
(req.Username == username && password == "") ||
(username == "" && req.Password == password) {
valid = true
break
}
}
if len(selector.Users) > 0 && !valid {
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure)
if err := resp.Write(conn); err != nil { if err := resp.Write(conn); err != nil {
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
@ -788,8 +779,9 @@ func (h *socks5Handler) Init(options ...HandlerOption) {
tlsConfig = DefaultTLSConfig tlsConfig = DefaultTLSConfig
} }
h.selector = &serverSelector{ // socks5 server selector h.selector = &serverSelector{ // socks5 server selector
Users: h.options.Users, // Users: h.options.Users,
TLSConfig: tlsConfig, Authenticator: h.options.Authenticator,
TLSConfig: tlsConfig,
} }
// methods that socks5 server supported // methods that socks5 server supported
h.selector.AddMethod( h.selector.AddMethod(

View File

@ -25,7 +25,7 @@ var socks5ProxyTests = []struct {
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true}, {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true},
{url.UserPassword("admin", "123456"), nil, true}, {url.UserPassword("admin", "123456"), nil, true},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, false},
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true}, {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, true}, {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, true},

23
ssh.go
View File

@ -7,7 +7,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/url"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -466,8 +465,8 @@ func (h *sshForwardHandler) Init(options ...HandlerOption) {
} }
h.config = &ssh.ServerConfig{} h.config = &ssh.ServerConfig{}
h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Users...) h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Authenticator)
if len(h.options.Users) == 0 { if h.options.Authenticator == nil {
h.config.NoClientAuth = true h.config.NoClientAuth = true
} }
tlsConfig := h.options.TLSConfig tlsConfig := h.options.TLSConfig
@ -665,8 +664,8 @@ func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Reque
// SSHConfig holds the SSH tunnel server config // SSHConfig holds the SSH tunnel server config
type SSHConfig struct { type SSHConfig struct {
Users []*url.Userinfo Authenticator Authenticator
TLSConfig *tls.Config TLSConfig *tls.Config
} }
type sshTunnelListener struct { type sshTunnelListener struct {
@ -688,8 +687,8 @@ func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) {
} }
sshConfig := &ssh.ServerConfig{} sshConfig := &ssh.ServerConfig{}
sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Users...) sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Authenticator)
if len(config.Users) == 0 { if config.Authenticator == nil {
sshConfig.NoClientAuth = true sshConfig.NoClientAuth = true
} }
tlsConfig := config.TLSConfig tlsConfig := config.TLSConfig
@ -808,14 +807,10 @@ func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
// PasswordCallbackFunc is a callback function used by SSH server. // PasswordCallbackFunc is a callback function used by SSH server.
type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error)
func defaultSSHPasswordCallback(users ...*url.Userinfo) PasswordCallbackFunc { func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
for _, user := range users { if au.Authenticate(conn.User(), string(password)) {
u := user.Username() return nil, nil
p, _ := user.Password()
if u == conn.User() && p == string(password) {
return nil, nil
}
} }
log.Logf("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User()) log.Logf("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User())
return nil, fmt.Errorf("password rejected for %s", conn.User()) return nil, fmt.Errorf("password rejected for %s", conn.User())