add bypass support for server nodes

This commit is contained in:
ginuerzh 2018-04-22 12:46:28 +08:00
parent 08003d296d
commit f6e09eaae6
3 changed files with 137 additions and 8 deletions

View File

@ -1,6 +1,8 @@
package gost package gost
import ( import (
"bytes"
"fmt"
"net" "net"
glob "github.com/gobwas/glob" glob "github.com/gobwas/glob"
@ -10,6 +12,7 @@ import (
// it gives the match result of the given pattern for specific v. // it gives the match result of the given pattern for specific v.
type Matcher interface { type Matcher interface {
Match(v string) bool Match(v string) bool
String() string
} }
// NewMatcher creates a Matcher for the given pattern. // NewMatcher creates a Matcher for the given pattern.
@ -48,6 +51,10 @@ func (m *ipMatcher) Match(ip string) bool {
return m.ip.Equal(net.ParseIP(ip)) return m.ip.Equal(net.ParseIP(ip))
} }
func (m *ipMatcher) String() string {
return "ip " + m.ip.String()
}
type cidrMatcher struct { type cidrMatcher struct {
ipNet *net.IPNet ipNet *net.IPNet
} }
@ -66,8 +73,13 @@ func (m *cidrMatcher) Match(ip string) bool {
return m.ipNet.Contains(net.ParseIP(ip)) return m.ipNet.Contains(net.ParseIP(ip))
} }
func (m *cidrMatcher) String() string {
return "cidr " + m.ipNet.String()
}
type domainMatcher struct { type domainMatcher struct {
glob glob.Glob pattern string
glob glob.Glob
} }
// DomainMatcher creates a Matcher for a specific domain pattern, // DomainMatcher creates a Matcher for a specific domain pattern,
@ -75,7 +87,8 @@ type domainMatcher struct {
// or a wildcard such as '*.exmaple.com'. // or a wildcard such as '*.exmaple.com'.
func DomainMatcher(pattern string) Matcher { func DomainMatcher(pattern string) Matcher {
return &domainMatcher{ return &domainMatcher{
glob: glob.MustCompile(pattern), pattern: pattern,
glob: glob.MustCompile(pattern),
} }
} }
@ -86,6 +99,10 @@ func (m *domainMatcher) Match(domain string) bool {
return m.glob.Match(domain) return m.glob.Match(domain)
} }
func (m *domainMatcher) String() string {
return "domain " + m.pattern
}
// Bypass is a filter for address (IP or domain). // Bypass is a filter for address (IP or domain).
// It contains a list of matchers. // It contains a list of matchers.
type Bypass struct { type Bypass struct {
@ -116,15 +133,45 @@ func NewBypassPatterns(patterns []string, reverse bool) *Bypass {
// Contains reports whether the bypass includes addr. // Contains reports whether the bypass includes addr.
func (bp *Bypass) Contains(addr string) bool { func (bp *Bypass) Contains(addr string) bool {
if bp == nil {
return false
}
var matched bool
for _, matcher := range bp.matchers { for _, matcher := range bp.matchers {
if matcher == nil { if matcher == nil {
continue continue
} }
matched := matcher.Match(addr) if matcher.Match(addr) {
if (matched && !bp.reverse) || matched = true
(!matched && bp.reverse) { break
return true
} }
} }
return false return !bp.reverse && matched ||
bp.reverse && !matched
}
// AddMatchers appends matchers to the bypass matcher list.
func (bp *Bypass) AddMatchers(matchers ...Matcher) {
bp.matchers = append(bp.matchers, matchers...)
}
// Matchers return the bypass matcher list.
func (bp *Bypass) Matchers() []Matcher {
return bp.matchers
}
// Reversed reports whether the rules of the bypass are reversed.
func (bp *Bypass) Reversed() bool {
return bp.reverse
}
func (bp *Bypass) String() string {
b := &bytes.Buffer{}
fmt.Fprintf(b, "reversed: %v\n", bp.Reversed())
for _, m := range bp.Matchers() {
b.WriteString(m.String())
b.WriteByte('\n')
}
return b.String()
} }

View File

@ -484,7 +484,15 @@ func (r *route) serve() error {
} }
} }
fBypass := node.Get("bypass")
if fBypass == "" {
fBypass = "bypass" // default bypass file
}
srv := &gost.Server{Listener: ln} srv := &gost.Server{Listener: ln}
srv.Init(
gost.BypassServerOption(parseBypass(fBypass)),
)
go srv.Serve(handler) go srv.Serve(handler)
} }
@ -685,3 +693,29 @@ func parseStrategy(s string) gost.Strategy {
} }
} }
func parseBypass(fpath string) (bypass *gost.Bypass) {
if fpath == "" {
return
}
f, err := os.Open(fpath)
if err != nil {
return
}
var matchers []gost.Matcher
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
matchers = append(matchers, gost.NewMatcher(line))
}
bypass = gost.NewBypass(matchers, strings.HasPrefix(fpath, "~"))
return
}

View File

@ -11,6 +11,17 @@ import (
// Server is a proxy server. // Server is a proxy server.
type Server struct { type Server struct {
Listener Listener Listener Listener
options *ServerOptions
}
// Init intializes server with given options.
func (s *Server) Init(opts ...ServerOption) {
if s.options == nil {
s.options = &ServerOptions{}
}
for _, opt := range opts {
opt(s.options)
}
} }
// Addr returns the address of the server // Addr returns the address of the server
@ -24,7 +35,9 @@ func (s *Server) Close() error {
} }
// Serve serves as a proxy server. // Serve serves as a proxy server.
func (s *Server) Serve(h Handler) error { func (s *Server) Serve(h Handler, opts ...ServerOption) error {
s.Init(opts...)
if s.Listener == nil { if s.Listener == nil {
ln, err := TCPListener("") ln, err := TCPListener("")
if err != nil { if err != nil {
@ -57,9 +70,31 @@ func (s *Server) Serve(h Handler) error {
return e return e
} }
tempDelay = 0 tempDelay = 0
ip := extractIP(conn.RemoteAddr())
if s.options.Bypass.Contains(ip.String()) {
log.Log("bypass", ip.String())
conn.Close()
continue
}
go h.Handle(conn) go h.Handle(conn)
} }
}
// ServerOptions holds the options for Server.
type ServerOptions struct {
Bypass *Bypass
}
// ServerOption allows a common way to set server options.
type ServerOption func(opts *ServerOptions)
// BypassServerOption sets the bypass option of ServerOptions.
func BypassServerOption(bypass *Bypass) ServerOption {
return func(opts *ServerOptions) {
opts.Bypass = bypass
}
} }
// Listener is a proxy server listener, just like a net.Listener. // Listener is a proxy server listener, just like a net.Listener.
@ -116,3 +151,16 @@ func transport(rw1, rw2 io.ReadWriter) error {
} }
return err return err
} }
func extractIP(addr net.Addr) net.IP {
switch v := addr.(type) {
case *net.IPAddr:
return v.IP
case *net.TCPAddr:
return v.IP
case *net.UDPAddr:
return v.IP
default:
}
return net.IPv4zero
}