add reloader for authenticator
This commit is contained in:
parent
1930da5210
commit
62663564cc
@ -1,3 +1,6 @@
|
|||||||
|
# period for live reloading
|
||||||
|
reload 3s
|
||||||
|
|
||||||
# username password
|
# username password
|
||||||
|
|
||||||
$test.admin$ $123456$
|
$test.admin$ $123456$
|
@ -19,6 +19,7 @@ gost - GO Simple Tunnel
|
|||||||
* 多端口监听
|
* 多端口监听
|
||||||
* 可设置转发代理,支持多级转发(代理链)
|
* 可设置转发代理,支持多级转发(代理链)
|
||||||
* 支持标准HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5代理协议
|
* 支持标准HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5代理协议
|
||||||
|
* Web代理支持[探测防御](https://docs.ginuerzh.xyz/gost/probe_resist/)
|
||||||
* [支持多种隧道类型](https://docs.ginuerzh.xyz/gost/configuration/)
|
* [支持多种隧道类型](https://docs.ginuerzh.xyz/gost/configuration/)
|
||||||
* [SOCKS5代理支持TLS协商加密](https://docs.ginuerzh.xyz/gost/socks/)
|
* [SOCKS5代理支持TLS协商加密](https://docs.ginuerzh.xyz/gost/socks/)
|
||||||
* [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/socks/)
|
* [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/socks/)
|
||||||
|
@ -16,6 +16,7 @@ Features
|
|||||||
* Listening on multiple ports
|
* Listening on multiple ports
|
||||||
* Multi-level forward proxy - proxy chain
|
* Multi-level forward proxy - proxy chain
|
||||||
* Standard HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5 proxy protocols support
|
* Standard HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5 proxy protocols support
|
||||||
|
* [Probing resistance](https://docs.ginuerzh.xyz/gost/en/probe_resist/) support for web proxy
|
||||||
* [Support multiple tunnel types](https://docs.ginuerzh.xyz/gost/en/configuration/)
|
* [Support multiple tunnel types](https://docs.ginuerzh.xyz/gost/en/configuration/)
|
||||||
* [TLS encryption via negotiation support for SOCKS5 proxy](https://docs.ginuerzh.xyz/gost/en/socks/)
|
* [TLS encryption via negotiation support for SOCKS5 proxy](https://docs.ginuerzh.xyz/gost/en/socks/)
|
||||||
* [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/en/socks/)
|
* [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/en/socks/)
|
||||||
|
155
auth.go
Normal file
155
auth.go
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
package gost
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Authenticator is an interface for user authentication.
|
||||||
|
type Authenticator interface {
|
||||||
|
Authenticate(user, password string) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs.
|
||||||
|
type LocalAuthenticator struct {
|
||||||
|
kvs map[string]string
|
||||||
|
period time.Duration
|
||||||
|
stopped chan struct{}
|
||||||
|
mux sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLocalAuthenticator creates an Authenticator that authenticates client by local infos.
|
||||||
|
func NewLocalAuthenticator(kvs map[string]string) *LocalAuthenticator {
|
||||||
|
return &LocalAuthenticator{
|
||||||
|
kvs: kvs,
|
||||||
|
stopped: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate checks the validity of the provided user-password pair.
|
||||||
|
func (au *LocalAuthenticator) Authenticate(user, password string) bool {
|
||||||
|
if au == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
au.mux.RLock()
|
||||||
|
defer au.mux.RUnlock()
|
||||||
|
|
||||||
|
if len(au.kvs) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok := au.kvs[user]
|
||||||
|
return ok && (v == "" || password == v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a key-value pair to the Authenticator.
|
||||||
|
func (au *LocalAuthenticator) Add(k, v string) {
|
||||||
|
au.mux.Lock()
|
||||||
|
defer au.mux.Unlock()
|
||||||
|
if au.kvs == nil {
|
||||||
|
au.kvs = make(map[string]string)
|
||||||
|
}
|
||||||
|
au.kvs[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload parses config from r, then live reloads the bypass.
|
||||||
|
func (au *LocalAuthenticator) Reload(r io.Reader) error {
|
||||||
|
var period time.Duration
|
||||||
|
kvs := make(map[string]string)
|
||||||
|
|
||||||
|
if r == nil || au.Stopped() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitLine splits a line text by white space.
|
||||||
|
// A line started with '#' will be ignored, otherwise it is valid.
|
||||||
|
split := func(line string) []string {
|
||||||
|
if line == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
line = strings.Replace(line, "\t", " ", -1)
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
|
if strings.IndexByte(line, '#') == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ss []string
|
||||||
|
for _, s := range strings.Split(line, " ") {
|
||||||
|
if s = strings.TrimSpace(s); s != "" {
|
||||||
|
ss = append(ss, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ss
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
ss := split(line)
|
||||||
|
if len(ss) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ss[0] {
|
||||||
|
case "reload": // reload option
|
||||||
|
if len(ss) > 1 {
|
||||||
|
period, _ = time.ParseDuration(ss[1])
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
var k, v string
|
||||||
|
k = ss[0]
|
||||||
|
if len(ss) > 1 {
|
||||||
|
v = ss[1]
|
||||||
|
}
|
||||||
|
kvs[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
au.mux.Lock()
|
||||||
|
defer au.mux.Unlock()
|
||||||
|
|
||||||
|
au.period = period
|
||||||
|
au.kvs = kvs
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Period returns the reload period.
|
||||||
|
func (au *LocalAuthenticator) Period() time.Duration {
|
||||||
|
if au.Stopped() {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
au.mux.RLock()
|
||||||
|
defer au.mux.RUnlock()
|
||||||
|
|
||||||
|
return au.period
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops reloading.
|
||||||
|
func (au *LocalAuthenticator) Stop() {
|
||||||
|
select {
|
||||||
|
case <-au.stopped:
|
||||||
|
default:
|
||||||
|
close(au.stopped)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stopped checks whether the reloader is stopped.
|
||||||
|
func (au *LocalAuthenticator) Stopped() bool {
|
||||||
|
select {
|
||||||
|
case <-au.stopped:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
191
auth_test.go
Normal file
191
auth_test.go
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
package gost
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var localAuthenticatorTests = []struct {
|
||||||
|
clientUser *url.Userinfo
|
||||||
|
serverUsers []*url.Userinfo
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
{nil, nil, true},
|
||||||
|
{nil, []*url.Userinfo{url.User("admin")}, false},
|
||||||
|
{nil, []*url.Userinfo{url.UserPassword("", "123456")}, false},
|
||||||
|
{nil, []*url.Userinfo{url.UserPassword("admin", "123456")}, false},
|
||||||
|
|
||||||
|
{url.User("admin"), nil, true},
|
||||||
|
{url.User("admin"), []*url.Userinfo{url.User("admin")}, true},
|
||||||
|
{url.User("admin"), []*url.Userinfo{url.User("test")}, false},
|
||||||
|
{url.User("admin"), []*url.Userinfo{url.UserPassword("test", "123456")}, false},
|
||||||
|
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false},
|
||||||
|
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true},
|
||||||
|
{url.User("admin"), []*url.Userinfo{url.UserPassword("", "123456")}, false},
|
||||||
|
|
||||||
|
{url.UserPassword("", ""), nil, true},
|
||||||
|
{url.UserPassword("", "123456"), nil, true},
|
||||||
|
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true},
|
||||||
|
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, false},
|
||||||
|
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false},
|
||||||
|
|
||||||
|
{url.UserPassword("admin", "123456"), nil, true},
|
||||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true},
|
||||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("test")}, false},
|
||||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, true},
|
||||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, false},
|
||||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123")}, false},
|
||||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("test", "123456")}, false},
|
||||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true},
|
||||||
|
|
||||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{
|
||||||
|
url.UserPassword("test", "123"),
|
||||||
|
url.UserPassword("admin", "123456"),
|
||||||
|
}, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalAuthenticator(t *testing.T) {
|
||||||
|
for i, tc := range localAuthenticatorTests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
|
||||||
|
au := NewLocalAuthenticator(nil)
|
||||||
|
for _, u := range tc.serverUsers {
|
||||||
|
if u != nil {
|
||||||
|
p, _ := u.Password()
|
||||||
|
au.Add(u.Username(), p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var u, p string
|
||||||
|
if tc.clientUser != nil {
|
||||||
|
u = tc.clientUser.Username()
|
||||||
|
p, _ = tc.clientUser.Password()
|
||||||
|
}
|
||||||
|
if au.Authenticate(u, p) != tc.valid {
|
||||||
|
t.Error("authenticate result should be", tc.valid)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var localAuthenticatorReloadTests = []struct {
|
||||||
|
r io.Reader
|
||||||
|
period time.Duration
|
||||||
|
kvs map[string]string
|
||||||
|
stopped bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
r: nil,
|
||||||
|
period: 0,
|
||||||
|
kvs: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
r: bytes.NewBufferString(""),
|
||||||
|
period: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
r: bytes.NewBufferString("reload 10s"),
|
||||||
|
period: 10 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
r: bytes.NewBufferString("# reload 10s\n"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
r: bytes.NewBufferString("reload 10s\n#admin"),
|
||||||
|
period: 10 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
r: bytes.NewBufferString("reload 10s\nadmin"),
|
||||||
|
period: 10 * time.Second,
|
||||||
|
kvs: map[string]string{
|
||||||
|
"admin": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
r: bytes.NewBufferString("# reload 10s\nadmin"),
|
||||||
|
kvs: map[string]string{
|
||||||
|
"admin": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
r: bytes.NewBufferString("# reload 10s\nadmin #123456"),
|
||||||
|
kvs: map[string]string{
|
||||||
|
"admin": "#123456",
|
||||||
|
},
|
||||||
|
stopped: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
r: bytes.NewBufferString("admin \t #123456\n\n\ntest \t 123456"),
|
||||||
|
kvs: map[string]string{
|
||||||
|
"admin": "#123456",
|
||||||
|
"test": "123456",
|
||||||
|
},
|
||||||
|
stopped: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
r: bytes.NewBufferString(`
|
||||||
|
$test.admin$ $123456$
|
||||||
|
@test.admin@ @123456@
|
||||||
|
test.admin# #123456#
|
||||||
|
test.admin\admin 123456
|
||||||
|
`),
|
||||||
|
kvs: map[string]string{
|
||||||
|
"$test.admin$": "$123456$",
|
||||||
|
"@test.admin@": "@123456@",
|
||||||
|
"test.admin#": "#123456#",
|
||||||
|
"test.admin\\admin": "123456",
|
||||||
|
},
|
||||||
|
stopped: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalAuthenticatorReload(t *testing.T) {
|
||||||
|
isEquals := func(a, b map[string]string) bool {
|
||||||
|
if len(a) == 0 && len(b) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range a {
|
||||||
|
if b[k] != v {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for i, tc := range localAuthenticatorReloadTests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
|
||||||
|
au := NewLocalAuthenticator(nil)
|
||||||
|
|
||||||
|
if err := au.Reload(tc.r); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if au.Period() != tc.period {
|
||||||
|
t.Errorf("#%d test failed: period value should be %v, got %v",
|
||||||
|
i, tc.period, au.Period())
|
||||||
|
}
|
||||||
|
if !isEquals(au.kvs, tc.kvs) {
|
||||||
|
t.Errorf("#%d test failed: %v, %s", i, au.kvs, tc.kvs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.stopped {
|
||||||
|
au.Stop()
|
||||||
|
if au.Period() >= 0 {
|
||||||
|
t.Errorf("period of the stopped reloader should be minus value")
|
||||||
|
}
|
||||||
|
au.Stop()
|
||||||
|
}
|
||||||
|
if au.Stopped() != tc.stopped {
|
||||||
|
t.Errorf("#%d test failed: stopped value should be %v, got %v",
|
||||||
|
i, tc.stopped, au.Stopped())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
40
bypass.go
40
bypass.go
@ -223,44 +223,22 @@ func (bp *Bypass) Reload(r io.Reader) error {
|
|||||||
scanner := bufio.NewScanner(r)
|
scanner := bufio.NewScanner(r)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if n := strings.IndexByte(line, '#'); n >= 0 {
|
ss := splitLine(line)
|
||||||
line = line[:n]
|
if len(ss) == 0 {
|
||||||
}
|
|
||||||
line = strings.Replace(line, "\t", " ", -1)
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
if line == "" {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
switch ss[0] {
|
||||||
// reload option
|
case "reload": // reload option
|
||||||
if strings.HasPrefix(line, "reload ") {
|
if len(ss) > 1 {
|
||||||
var ss []string
|
|
||||||
for _, s := range strings.Split(line, " ") {
|
|
||||||
if s = strings.TrimSpace(s); s != "" {
|
|
||||||
ss = append(ss, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(ss) == 2 {
|
|
||||||
period, _ = time.ParseDuration(ss[1])
|
period, _ = time.ParseDuration(ss[1])
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
}
|
case "reverse": // reverse option
|
||||||
|
if len(ss) > 1 {
|
||||||
// reverse option
|
|
||||||
if strings.HasPrefix(line, "reverse ") {
|
|
||||||
var ss []string
|
|
||||||
for _, s := range strings.Split(line, " ") {
|
|
||||||
if s = strings.TrimSpace(s); s != "" {
|
|
||||||
ss = append(ss, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(ss) == 2 {
|
|
||||||
reversed, _ = strconv.ParseBool(ss[1])
|
reversed, _ = strconv.ParseBool(ss[1])
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
matchers = append(matchers, NewMatcher(ss[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
matchers = append(matchers, NewMatcher(line))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
|
@ -220,7 +220,7 @@ var bypassReloadTests = []struct {
|
|||||||
stopped: true,
|
stopped: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.1.0/24"),
|
r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.1.0/24 #comment"),
|
||||||
reversed: false,
|
reversed: false,
|
||||||
period: 0,
|
period: 0,
|
||||||
addr: "192.168.10.2",
|
addr: "192.168.10.2",
|
||||||
@ -244,7 +244,7 @@ var bypassReloadTests = []struct {
|
|||||||
stopped: true,
|
stopped: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
r: bytes.NewBufferString("#reverse true\n#reload 10s\nexample.com"),
|
r: bytes.NewBufferString("#reverse true\n#reload 10s\nexample.com #comment"),
|
||||||
reversed: false,
|
reversed: false,
|
||||||
period: 0,
|
period: 0,
|
||||||
addr: "example.com",
|
addr: "example.com",
|
||||||
|
@ -1 +0,0 @@
|
|||||||
Hello World!
|
|
@ -118,6 +118,24 @@ func parseUsers(authFile string) (users []*url.Userinfo, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseAuthenticator(s string) (gost.Authenticator, error) {
|
||||||
|
if s == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
f, err := os.Open(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
au := gost.NewLocalAuthenticator(nil)
|
||||||
|
au.Reload(f)
|
||||||
|
|
||||||
|
go gost.PeriodReload(au, s)
|
||||||
|
|
||||||
|
return au, nil
|
||||||
|
}
|
||||||
|
|
||||||
func parseIP(s string, port string) (ips []string) {
|
func parseIP(s string, port string) (ips []string) {
|
||||||
if s == "" {
|
if s == "" {
|
||||||
return
|
return
|
||||||
|
@ -257,12 +257,14 @@ func (r *route) GenRouters() ([]router, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
users, err := parseUsers(node.Get("secrets"))
|
authenticator, err := parseAuthenticator(node.Get("secrets"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if node.User != nil {
|
if authenticator == nil && node.User != nil {
|
||||||
users = append(users, node.User)
|
kvs := make(map[string]string)
|
||||||
|
kvs[node.User.Username()], _ = node.User.Password()
|
||||||
|
authenticator = gost.NewLocalAuthenticator(kvs)
|
||||||
}
|
}
|
||||||
certFile, keyFile := node.Get("cert"), node.Get("key")
|
certFile, keyFile := node.Get("cert"), node.Get("key")
|
||||||
tlsCfg, err := tlsConfig(certFile, keyFile)
|
tlsCfg, err := tlsConfig(certFile, keyFile)
|
||||||
@ -298,8 +300,8 @@ func (r *route) GenRouters() ([]router, error) {
|
|||||||
ln, err = gost.KCPListener(node.Addr, config)
|
ln, err = gost.KCPListener(node.Addr, config)
|
||||||
case "ssh":
|
case "ssh":
|
||||||
config := &gost.SSHConfig{
|
config := &gost.SSHConfig{
|
||||||
Users: users,
|
Authenticator: authenticator,
|
||||||
TLSConfig: tlsCfg,
|
TLSConfig: tlsCfg,
|
||||||
}
|
}
|
||||||
if node.Protocol == "forward" {
|
if node.Protocol == "forward" {
|
||||||
ln, err = gost.TCPListener(node.Addr)
|
ln, err = gost.TCPListener(node.Addr)
|
||||||
@ -416,7 +418,7 @@ func (r *route) GenRouters() ([]router, error) {
|
|||||||
// gost.AddrHandlerOption(node.Addr),
|
// gost.AddrHandlerOption(node.Addr),
|
||||||
gost.AddrHandlerOption(ln.Addr().String()),
|
gost.AddrHandlerOption(ln.Addr().String()),
|
||||||
gost.ChainHandlerOption(chain),
|
gost.ChainHandlerOption(chain),
|
||||||
gost.UsersHandlerOption(users...),
|
gost.AuthenticatorHandlerOption(authenticator),
|
||||||
gost.TLSConfigHandlerOption(tlsCfg),
|
gost.TLSConfigHandlerOption(tlsCfg),
|
||||||
gost.WhitelistHandlerOption(whitelist),
|
gost.WhitelistHandlerOption(whitelist),
|
||||||
gost.BlacklistHandlerOption(blacklist),
|
gost.BlacklistHandlerOption(blacklist),
|
||||||
|
24
gost.go
24
gost.go
@ -11,6 +11,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -18,7 +19,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Version is the gost version.
|
// Version is the gost version.
|
||||||
const Version = "2.7"
|
const Version = "2.7.1"
|
||||||
|
|
||||||
// Debug is a flag that enables the debug log.
|
// Debug is a flag that enables the debug log.
|
||||||
var Debug bool
|
var Debug bool
|
||||||
@ -180,7 +181,22 @@ func (c *nopConn) SetWriteDeadline(t time.Time) error {
|
|||||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accepter represents a network endpoint that can accept connection from peer.
|
// splitLine splits a line text by white space, mainly used by config parser.
|
||||||
type Accepter interface {
|
func splitLine(line string) []string {
|
||||||
Accept() (net.Conn, error)
|
if line == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if n := strings.IndexByte(line, '#'); n >= 0 {
|
||||||
|
line = line[:n]
|
||||||
|
}
|
||||||
|
line = strings.Replace(line, "\t", " ", -1)
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
|
var ss []string
|
||||||
|
for _, s := range strings.Split(line, " ") {
|
||||||
|
if s = strings.TrimSpace(s); s != "" {
|
||||||
|
ss = append(ss, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ss
|
||||||
}
|
}
|
||||||
|
48
handler.go
48
handler.go
@ -20,21 +20,22 @@ type Handler interface {
|
|||||||
|
|
||||||
// HandlerOptions describes the options for Handler.
|
// HandlerOptions describes the options for Handler.
|
||||||
type HandlerOptions struct {
|
type HandlerOptions struct {
|
||||||
Addr string
|
Addr string
|
||||||
Chain *Chain
|
Chain *Chain
|
||||||
Users []*url.Userinfo
|
Users []*url.Userinfo
|
||||||
TLSConfig *tls.Config
|
Authenticator Authenticator
|
||||||
Whitelist *Permissions
|
TLSConfig *tls.Config
|
||||||
Blacklist *Permissions
|
Whitelist *Permissions
|
||||||
Strategy Strategy
|
Blacklist *Permissions
|
||||||
Bypass *Bypass
|
Strategy Strategy
|
||||||
Retries int
|
Bypass *Bypass
|
||||||
Timeout time.Duration
|
Retries int
|
||||||
Resolver Resolver
|
Timeout time.Duration
|
||||||
Hosts *Hosts
|
Resolver Resolver
|
||||||
ProbeResist string
|
Hosts *Hosts
|
||||||
Node Node
|
ProbeResist string
|
||||||
Host string
|
Node Node
|
||||||
|
Host string
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandlerOption allows a common way to set handler options.
|
// HandlerOption allows a common way to set handler options.
|
||||||
@ -58,6 +59,23 @@ func ChainHandlerOption(chain *Chain) HandlerOption {
|
|||||||
func UsersHandlerOption(users ...*url.Userinfo) HandlerOption {
|
func UsersHandlerOption(users ...*url.Userinfo) HandlerOption {
|
||||||
return func(opts *HandlerOptions) {
|
return func(opts *HandlerOptions) {
|
||||||
opts.Users = users
|
opts.Users = users
|
||||||
|
|
||||||
|
kvs := make(map[string]string)
|
||||||
|
for _, u := range users {
|
||||||
|
if u != nil {
|
||||||
|
kvs[u.Username()], _ = u.Password()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(kvs) > 0 {
|
||||||
|
opts.Authenticator = NewLocalAuthenticator(kvs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticatorHandlerOption sets the Authenticator option of HandlerOptions.
|
||||||
|
func AuthenticatorHandlerOption(au Authenticator) HandlerOption {
|
||||||
|
return func(opts *HandlerOptions) {
|
||||||
|
opts.Authenticator = au
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
47
hosts.go
47
hosts.go
@ -4,7 +4,6 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -94,42 +93,28 @@ func (h *Hosts) Reload(r io.Reader) error {
|
|||||||
scanner := bufio.NewScanner(r)
|
scanner := bufio.NewScanner(r)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if n := strings.IndexByte(line, '#'); n >= 0 {
|
ss := splitLine(line)
|
||||||
line = line[:n]
|
|
||||||
}
|
|
||||||
line = strings.Replace(line, "\t", " ", -1)
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
if line == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
var ss []string
|
|
||||||
for _, s := range strings.Split(line, " ") {
|
|
||||||
if s = strings.TrimSpace(s); s != "" {
|
|
||||||
ss = append(ss, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(ss) < 2 {
|
if len(ss) < 2 {
|
||||||
continue // invalid lines are ignored
|
continue // invalid lines are ignored
|
||||||
}
|
}
|
||||||
|
|
||||||
// reload option
|
switch ss[0] {
|
||||||
if strings.ToLower(ss[0]) == "reload" {
|
case "reload": // reload option
|
||||||
period, _ = time.ParseDuration(ss[1])
|
period, _ = time.ParseDuration(ss[1])
|
||||||
continue
|
default:
|
||||||
|
ip := net.ParseIP(ss[0])
|
||||||
|
if ip == nil {
|
||||||
|
break // invalid IP addresses are ignored
|
||||||
|
}
|
||||||
|
host := Host{
|
||||||
|
IP: ip,
|
||||||
|
Hostname: ss[1],
|
||||||
|
}
|
||||||
|
if len(ss) > 2 {
|
||||||
|
host.Aliases = ss[2:]
|
||||||
|
}
|
||||||
|
hosts = append(hosts, host)
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := net.ParseIP(ss[0])
|
|
||||||
if ip == nil {
|
|
||||||
continue // invalid IP addresses are ignored
|
|
||||||
}
|
|
||||||
host := Host{
|
|
||||||
IP: ip,
|
|
||||||
Hostname: ss[1],
|
|
||||||
}
|
|
||||||
if len(ss) > 2 {
|
|
||||||
host.Aliases = ss[2:]
|
|
||||||
}
|
|
||||||
hosts = append(hosts, host)
|
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
19
http.go
19
http.go
@ -299,7 +299,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
|
|||||||
log.Logf("[http] %s -> %s : Authorization '%s' '%s'",
|
log.Logf("[http] %s -> %s : Authorization '%s' '%s'",
|
||||||
conn.RemoteAddr(), conn.LocalAddr(), u, p)
|
conn.RemoteAddr(), conn.LocalAddr(), u, p)
|
||||||
}
|
}
|
||||||
if authenticate(u, p, h.options.Users...) {
|
if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -423,20 +423,3 @@ func basicProxyAuth(proxyAuth string) (username, password string, ok bool) {
|
|||||||
|
|
||||||
return cs[:s], cs[s+1:], true
|
return cs[:s], cs[s+1:], true
|
||||||
}
|
}
|
||||||
|
|
||||||
func authenticate(username, password string, users ...*url.Userinfo) bool {
|
|
||||||
if len(users) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, user := range users {
|
|
||||||
u := user.Username()
|
|
||||||
p, _ := user.Password()
|
|
||||||
if (u == username && p == password) ||
|
|
||||||
(u == username && p == "") ||
|
|
||||||
(u == "" && p == password) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
3
http2.go
3
http2.go
@ -457,9 +457,10 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp
|
|||||||
if Debug && (u != "" || p != "") {
|
if Debug && (u != "" || p != "") {
|
||||||
log.Logf("[http2] %s - %s : Authorization '%s' '%s'", r.RemoteAddr, laddr, u, p)
|
log.Logf("[http2] %s - %s : Authorization '%s' '%s'", r.RemoteAddr, laddr, u, p)
|
||||||
}
|
}
|
||||||
if authenticate(u, p, h.options.Users...) {
|
if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// probing resistance is enabled
|
// probing resistance is enabled
|
||||||
if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 {
|
if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 {
|
||||||
switch ss[0] {
|
switch ss[0] {
|
||||||
|
@ -1088,7 +1088,7 @@ func TestHTTP2ProxyWithFileProbeResist(t *testing.T) {
|
|||||||
Listener: ln,
|
Listener: ln,
|
||||||
Handler: HTTP2Handler(
|
Handler: HTTP2Handler(
|
||||||
UsersHandlerOption(url.UserPassword("admin", "123456")),
|
UsersHandlerOption(url.UserPassword("admin", "123456")),
|
||||||
ProbeResistHandlerOption("file:.testdata/probe_resist.txt"),
|
ProbeResistHandlerOption("file:.config/probe_resist.txt"),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
go server.Run()
|
go server.Run()
|
||||||
|
@ -26,7 +26,7 @@ var httpProxyTests = []struct {
|
|||||||
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, ""},
|
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, ""},
|
||||||
{url.UserPassword("admin", "123456"), nil, ""},
|
{url.UserPassword("admin", "123456"), nil, ""},
|
||||||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, ""},
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, ""},
|
||||||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""},
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, "407 Proxy Authentication Required"},
|
||||||
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""},
|
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""},
|
||||||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""},
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""},
|
||||||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""},
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""},
|
||||||
@ -312,7 +312,7 @@ func TestHTTPProxyWithFileProbeResist(t *testing.T) {
|
|||||||
Listener: ln,
|
Listener: ln,
|
||||||
Handler: HTTPHandler(
|
Handler: HTTPHandler(
|
||||||
UsersHandlerOption(url.UserPassword("admin", "123456")),
|
UsersHandlerOption(url.UserPassword("admin", "123456")),
|
||||||
ProbeResistHandlerOption("file:.testdata/probe_resist.txt"),
|
ProbeResistHandlerOption("file:.config/probe_resist.txt"),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
go server.Run()
|
go server.Run()
|
||||||
|
21
reload.go
21
reload.go
@ -17,26 +17,7 @@ type Reloader interface {
|
|||||||
// Stoppable is the interface that indicates a Reloader can be stopped.
|
// Stoppable is the interface that indicates a Reloader can be stopped.
|
||||||
type Stoppable interface {
|
type Stoppable interface {
|
||||||
Stop()
|
Stop()
|
||||||
}
|
Stopped() bool
|
||||||
|
|
||||||
//StopReloader is the interface that adds Stop method to the Reloader.
|
|
||||||
type StopReloader interface {
|
|
||||||
Reloader
|
|
||||||
Stoppable
|
|
||||||
}
|
|
||||||
|
|
||||||
type nopStoppable struct {
|
|
||||||
Reloader
|
|
||||||
}
|
|
||||||
|
|
||||||
func (nopStoppable) Stop() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// NopStoppable returns a StopReloader with a no-op Stop method,
|
|
||||||
// wrapping the provided Reloader r.
|
|
||||||
func NopStoppable(r Reloader) StopReloader {
|
|
||||||
return nopStoppable{r}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeriodReload reloads the config configFile periodically according to the period of the Reloader r.
|
// PeriodReload reloads the config configFile periodically according to the period of the Reloader r.
|
||||||
|
21
resolver.go
21
resolver.go
@ -278,29 +278,10 @@ func (r *resolver) Reload(rd io.Reader) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
split := func(line string) []string {
|
|
||||||
if line == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if n := strings.IndexByte(line, '#'); n >= 0 {
|
|
||||||
line = line[:n]
|
|
||||||
}
|
|
||||||
line = strings.Replace(line, "\t", " ", -1)
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
|
|
||||||
var ss []string
|
|
||||||
for _, s := range strings.Split(line, " ") {
|
|
||||||
if s = strings.TrimSpace(s); s != "" {
|
|
||||||
ss = append(ss, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ss
|
|
||||||
}
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(rd)
|
scanner := bufio.NewScanner(rd)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
ss := split(line)
|
ss := splitLine(line)
|
||||||
if len(ss) == 0 {
|
if len(ss) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,11 @@ import (
|
|||||||
"github.com/go-log/log"
|
"github.com/go-log/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Accepter represents a network endpoint that can accept connection from peer.
|
||||||
|
type Accepter interface {
|
||||||
|
Accept() (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
// Server is a proxy server.
|
// Server is a proxy server.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Listener Listener
|
Listener Listener
|
||||||
|
@ -14,7 +14,7 @@ apps:
|
|||||||
|
|
||||||
parts:
|
parts:
|
||||||
go:
|
go:
|
||||||
source-tag: go1.11
|
source-tag: go1.10
|
||||||
gost:
|
gost:
|
||||||
after: [go]
|
after: [go]
|
||||||
source: .
|
source: .
|
||||||
|
30
socks.go
30
socks.go
@ -96,9 +96,10 @@ func (selector *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Con
|
|||||||
}
|
}
|
||||||
|
|
||||||
type serverSelector struct {
|
type serverSelector struct {
|
||||||
methods []uint8
|
methods []uint8
|
||||||
Users []*url.Userinfo
|
// Users []*url.Userinfo
|
||||||
TLSConfig *tls.Config
|
Authenticator Authenticator
|
||||||
|
TLSConfig *tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (selector *serverSelector) Methods() []uint8 {
|
func (selector *serverSelector) Methods() []uint8 {
|
||||||
@ -121,8 +122,8 @@ func (selector *serverSelector) Select(methods ...uint8) (method uint8) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// when user/pass is set, auth is mandatory
|
// when Authenticator is set, auth is mandatory
|
||||||
if len(selector.Users) > 0 {
|
if selector.Authenticator != nil {
|
||||||
if method == gosocks5.MethodNoAuth {
|
if method == gosocks5.MethodNoAuth {
|
||||||
method = gosocks5.MethodUserPass
|
method = gosocks5.MethodUserPass
|
||||||
}
|
}
|
||||||
@ -155,18 +156,8 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con
|
|||||||
if Debug {
|
if Debug {
|
||||||
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String())
|
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String())
|
||||||
}
|
}
|
||||||
valid := false
|
|
||||||
for _, user := range selector.Users {
|
if selector.Authenticator != nil && !selector.Authenticator.Authenticate(req.Username, req.Password) {
|
||||||
username := user.Username()
|
|
||||||
password, _ := user.Password()
|
|
||||||
if (req.Username == username && req.Password == password) ||
|
|
||||||
(req.Username == username && password == "") ||
|
|
||||||
(username == "" && req.Password == password) {
|
|
||||||
valid = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(selector.Users) > 0 && !valid {
|
|
||||||
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure)
|
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure)
|
||||||
if err := resp.Write(conn); err != nil {
|
if err := resp.Write(conn); err != nil {
|
||||||
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
|
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||||
@ -788,8 +779,9 @@ func (h *socks5Handler) Init(options ...HandlerOption) {
|
|||||||
tlsConfig = DefaultTLSConfig
|
tlsConfig = DefaultTLSConfig
|
||||||
}
|
}
|
||||||
h.selector = &serverSelector{ // socks5 server selector
|
h.selector = &serverSelector{ // socks5 server selector
|
||||||
Users: h.options.Users,
|
// Users: h.options.Users,
|
||||||
TLSConfig: tlsConfig,
|
Authenticator: h.options.Authenticator,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
}
|
}
|
||||||
// methods that socks5 server supported
|
// methods that socks5 server supported
|
||||||
h.selector.AddMethod(
|
h.selector.AddMethod(
|
||||||
|
@ -25,7 +25,7 @@ var socks5ProxyTests = []struct {
|
|||||||
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true},
|
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true},
|
||||||
{url.UserPassword("admin", "123456"), nil, true},
|
{url.UserPassword("admin", "123456"), nil, true},
|
||||||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true},
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true},
|
||||||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true},
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, false},
|
||||||
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true},
|
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true},
|
||||||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true},
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true},
|
||||||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, true},
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, true},
|
||||||
|
23
ssh.go
23
ssh.go
@ -7,7 +7,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -466,8 +465,8 @@ func (h *sshForwardHandler) Init(options ...HandlerOption) {
|
|||||||
}
|
}
|
||||||
h.config = &ssh.ServerConfig{}
|
h.config = &ssh.ServerConfig{}
|
||||||
|
|
||||||
h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Users...)
|
h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Authenticator)
|
||||||
if len(h.options.Users) == 0 {
|
if h.options.Authenticator == nil {
|
||||||
h.config.NoClientAuth = true
|
h.config.NoClientAuth = true
|
||||||
}
|
}
|
||||||
tlsConfig := h.options.TLSConfig
|
tlsConfig := h.options.TLSConfig
|
||||||
@ -665,8 +664,8 @@ func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Reque
|
|||||||
|
|
||||||
// SSHConfig holds the SSH tunnel server config
|
// SSHConfig holds the SSH tunnel server config
|
||||||
type SSHConfig struct {
|
type SSHConfig struct {
|
||||||
Users []*url.Userinfo
|
Authenticator Authenticator
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
type sshTunnelListener struct {
|
type sshTunnelListener struct {
|
||||||
@ -688,8 +687,8 @@ func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sshConfig := &ssh.ServerConfig{}
|
sshConfig := &ssh.ServerConfig{}
|
||||||
sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Users...)
|
sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Authenticator)
|
||||||
if len(config.Users) == 0 {
|
if config.Authenticator == nil {
|
||||||
sshConfig.NoClientAuth = true
|
sshConfig.NoClientAuth = true
|
||||||
}
|
}
|
||||||
tlsConfig := config.TLSConfig
|
tlsConfig := config.TLSConfig
|
||||||
@ -808,14 +807,10 @@ func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
|
|||||||
// PasswordCallbackFunc is a callback function used by SSH server.
|
// PasswordCallbackFunc is a callback function used by SSH server.
|
||||||
type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error)
|
type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error)
|
||||||
|
|
||||||
func defaultSSHPasswordCallback(users ...*url.Userinfo) PasswordCallbackFunc {
|
func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
|
||||||
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||||
for _, user := range users {
|
if au.Authenticate(conn.User(), string(password)) {
|
||||||
u := user.Username()
|
return nil, nil
|
||||||
p, _ := user.Password()
|
|
||||||
if u == conn.User() && p == string(password) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
log.Logf("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User())
|
log.Logf("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User())
|
||||||
return nil, fmt.Errorf("password rejected for %s", conn.User())
|
return nil, fmt.Errorf("password rejected for %s", conn.User())
|
||||||
|
Loading…
Reference in New Issue
Block a user