support forward chain

This commit is contained in:
rui.zheng 2015-10-12 17:24:57 +08:00
parent d69a093004
commit f88eecbe37
5 changed files with 133 additions and 105 deletions

186
conn.go
View File

@ -211,50 +211,75 @@ func (r *reqReader) Read(p []byte) (n int, err error) {
return return
} }
func connect(connType, addr string) (conn net.Conn, err error) { func connect(addr string) (conn net.Conn, err error) {
if !strings.Contains(addr, ":") { if !strings.Contains(addr, ":") {
addr += ":80" addr += ":80"
} }
if len(forwardArgs) == 0 {
if len(proxyArgs) > 0 && len(forwardArgs) > 0 { return net.Dial("tcp", addr)
return connectProxyForward(connType, addr, proxyArgs[0], forwardArgs[0])
} }
return forwardChain(addr, forwardArgs[0], forwardArgs[1:]...)
if len(forwardArgs) > 0 {
// TODO: multi-foward
forward := forwardArgs[0]
return connectForward(connType, addr, forward)
}
if len(proxyArgs) > 0 {
proxy := proxyArgs[0]
return connectForward(connType, addr, proxy)
}
return net.Dial("tcp", addr)
} }
func connectProxyForward(connType, addr string, proxy, forward Args) (conn net.Conn, err error) { func forwardChain(addr string, level1 Args, chain ...Args) (conn net.Conn, err error) {
return nil, errors.New("Not implemented")
}
func connectForward(connType, addr string, forward Args) (conn net.Conn, err error) {
if glog.V(LINFO) { if glog.V(LINFO) {
glog.Infoln(forward.Protocol, "forward:", forward.Addr) glog.Infof("forward: %s/%s %s", level1.Protocol, level1.Transport, level1.Addr)
} }
conn, err = net.Dial("tcp", forward.Addr) if conn, err = net.Dial("tcp", level1.Addr); err != nil {
if err != nil {
return return
} }
c, err := forward(conn, level1)
if err != nil {
conn.Close()
return nil, err
}
conn = c
switch forward.Transport { if len(chain) == 0 {
case "ws": // websocket connection if err := establish(conn, addr, level1); err != nil {
c, err := wsClient(conn, forward.Addr)
if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }
return
}
cur := level1
for _, arg := range chain {
if err = establish(conn, arg.Addr, cur); err != nil {
goto exit
}
c, err = forward(conn, arg)
if err != nil {
goto exit
}
conn = c conn = c
cur = arg
}
exit:
if err != nil {
conn.Close()
return nil, err
}
if err := establish(conn, addr, cur); err != nil {
conn.Close()
return nil, err
}
return
}
func forward(conn net.Conn, arg Args) (net.Conn, error) {
var err error
switch arg.Transport {
case "ws": // websocket connection
conn, err = wsClient(conn, arg.Addr)
if err != nil {
return nil, err
}
case "tls": // tls connection case "tls": // tls connection
conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
case "tcp": case "tcp":
@ -262,83 +287,86 @@ func connectForward(connType, addr string, forward Args) (conn net.Conn, err err
default: default:
} }
switch forward.Protocol { switch arg.Protocol {
case "ss": // shadowsocks case "ss": // shadowsocks
conn.Close()
return nil, errors.New("Not implemented") return nil, errors.New("Not implemented")
case "socks", "socks5": case "socks", "socks5":
selector := &clientSelector{ selector := &clientSelector{
methods: []uint8{gosocks5.MethodNoAuth, gosocks5.MethodUserPass}, methods: []uint8{gosocks5.MethodNoAuth, gosocks5.MethodUserPass},
arg: forward, arg: arg,
} }
if forward.EncMeth == "tls" { if arg.EncMeth == "tls" {
selector.methods = []uint8{MethodTLS, MethodTLSAuth} selector.methods = []uint8{MethodTLS, MethodTLSAuth}
} }
c := gosocks5.ClientConn(conn, selector) c := gosocks5.ClientConn(conn, selector)
if err := c.Handleshake(); err != nil { if err := c.Handleshake(); err != nil {
c.Close()
return nil, err return nil, err
} }
conn = c conn = c
case "http":
fallthrough
default:
}
host, port, _ := net.SplitHostPort(addr) return conn, nil
p, _ := strconv.ParseUint(port, 10, 16) }
r := gosocks5.NewRequest(gosocks5.CmdConnect, &gosocks5.Addr{
func establish(conn net.Conn, addr string, arg Args) error {
switch arg.Protocol {
case "ss": // shadowsocks
return nil
case "socks", "socks5":
host, port, err := net.SplitHostPort(addr)
p, _ := strconv.Atoi(port)
// TODO: support bind and udp
req := gosocks5.NewRequest(gosocks5.CmdConnect, &gosocks5.Addr{
Type: gosocks5.AddrDomain, Type: gosocks5.AddrDomain,
Host: host, Host: host,
Port: uint16(p), Port: uint16(p),
}) })
rep, err := requestSocks5(conn, r) rep, err := requestSocks5(conn, req)
if err != nil { if err != nil {
conn.Close() return err
return nil, err
} }
if rep.Rep != gosocks5.Succeeded { if rep.Rep != gosocks5.Succeeded {
conn.Close() return errors.New("Service unavailable")
return nil, errors.New("Service unavailable")
} }
case "http": case "http":
fallthrough fallthrough
default: default:
if connType == ConnHttpConnect || connType == ConnSocks5 { req := &http.Request{
req := &http.Request{ Method: "CONNECT",
Method: "CONNECT", URL: &url.URL{Host: addr},
URL: &url.URL{Host: addr}, Host: addr,
Host: addr, ProtoMajor: 1,
ProtoMajor: 1, ProtoMinor: 1,
ProtoMinor: 1, Header: make(http.Header),
Header: make(http.Header), }
} req.Header.Set("Proxy-Connection", "keep-alive")
req.Header.Set("Proxy-Connection", "keep-alive") if arg.User != nil {
if forward.User != nil { req.Header.Set("Proxy-Authorization",
req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(arg.User.String())))
"Basic "+base64.StdEncoding.EncodeToString([]byte(forward.User.String()))) }
} if err := req.Write(conn); err != nil {
if err = req.Write(conn); err != nil { return err
conn.Close() }
return nil, err if glog.V(LDEBUG) {
} dump, _ := httputil.DumpRequest(req, false)
if glog.V(LDEBUG) { glog.Infoln(string(dump))
dump, _ := httputil.DumpRequest(req, false) }
glog.Infoln(string(dump))
}
resp, err := http.ReadResponse(bufio.NewReader(conn), req) resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil { if err != nil {
conn.Close() return err
return nil, err }
} if glog.V(LDEBUG) {
if glog.V(LDEBUG) { dump, _ := httputil.DumpResponse(resp, false)
dump, _ := httputil.DumpResponse(resp, false) glog.Infoln(string(dump))
glog.Infoln(string(dump)) }
} if resp.StatusCode != http.StatusOK {
if resp.StatusCode != http.StatusOK { return errors.New(resp.Status)
conn.Close()
//log.Println(resp.Status)
return nil, errors.New(resp.Status)
}
} }
} }
return return nil
} }

11
http.go
View File

@ -19,11 +19,6 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) {
} }
} }
connType := ConnHttp
if req.Method == "CONNECT" {
connType = ConnHttpConnect
}
var username, password string var username, password string
if arg.User != nil { if arg.User != nil {
username = arg.User.Username() username = arg.User.Username()
@ -52,7 +47,7 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) {
return return
} }
c, err := connect(connType, req.Host) c, err := connect(req.Host)
if err != nil { if err != nil {
if glog.V(LWARNING) { if glog.V(LWARNING) {
glog.Warningln(err) glog.Warningln(err)
@ -67,7 +62,7 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) {
} }
defer c.Close() defer c.Close()
if connType == ConnHttpConnect { if req.Method == "CONNECT" {
b := []byte("HTTP/1.1 200 Connection established\r\n" + b := []byte("HTTP/1.1 200 Connection established\r\n" +
"Proxy-Agent: gost/" + Version + "\r\n\r\n") "Proxy-Agent: gost/" + Version + "\r\n\r\n")
if glog.V(LDEBUG) { if glog.V(LDEBUG) {
@ -80,7 +75,7 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) {
return return
} }
} else { } else {
if len(proxyArgs) > 0 || len(forwardArgs) > 0 { if len(forwardArgs) > 0 {
err = req.WriteProxy(c) err = req.WriteProxy(c)
} else { } else {
err = req.Write(c) err = req.Write(c)

20
main.go
View File

@ -16,30 +16,30 @@ const (
) )
var ( var (
listenUrl, proxyUrl, forwardUrl string listenAddr, forwardAddr strSlice
pv bool // print version pv bool // print version
listenArgs []Args listenArgs []Args
proxyArgs []Args
forwardArgs []Args forwardArgs []Args
) )
func init() { func init() {
flag.StringVar(&listenUrl, "L", ":http", "local address") flag.Var(&listenAddr, "L", "listen address")
flag.StringVar(&forwardUrl, "S", "", "remote address") flag.Var(&forwardAddr, "F", "forward address, can make a forward chain")
flag.StringVar(&proxyUrl, "P", "", "proxy address")
flag.BoolVar(&pv, "V", false, "print version") flag.BoolVar(&pv, "V", false, "print version")
flag.Parse() flag.Parse()
listenArgs = parseArgs(listenUrl) listenArgs = parseArgs(listenAddr)
proxyArgs = parseArgs(proxyUrl) forwardArgs = parseArgs(forwardAddr)
forwardArgs = parseArgs(forwardUrl)
} }
func main() { func main() {
defer glog.Flush() defer glog.Flush()
if flag.NFlag() == 0 {
flag.PrintDefaults()
return
}
if pv { if pv {
printVersion() printVersion()
return return

View File

@ -138,7 +138,7 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con
return nil, err return nil, err
} }
if glog.V(LDEBUG) { if glog.V(LDEBUG) {
glog.Infoln(req) glog.Infoln(req.String())
} }
var username, password string var username, password string
@ -206,7 +206,7 @@ func handleSocks5Request(req *gosocks5.Request, conn net.Conn, arg Args) {
if glog.V(LINFO) { if glog.V(LINFO) {
glog.Infoln("socks5 connect:", req.Addr.String()) glog.Infoln("socks5 connect:", req.Addr.String())
} }
tconn, err := connect(ConnSocks5, req.Addr.String()) tconn, err := connect(req.Addr.String())
if err != nil { if err != nil {
if glog.V(LWARNING) { if glog.V(LWARNING) {
glog.Warningln("socks5 connect:", err) glog.Warningln("socks5 connect:", err)

17
util.go
View File

@ -10,6 +10,16 @@ import (
"strings" "strings"
) )
type strSlice []string
func (ss *strSlice) String() string {
return fmt.Sprintf("%s", *ss)
}
func (ss *strSlice) Set(value string) error {
*ss = append(*ss, value)
return nil
}
// socks://admin:123456@localhost:8080/tls // socks://admin:123456@localhost:8080/tls
type Args struct { type Args struct {
Addr string // host:port Addr string // host:port
@ -32,12 +42,7 @@ func (args Args) String() string {
args.EncMeth, args.EncPass) args.EncMeth, args.EncPass)
} }
func parseArgs(rawurl string) (args []Args) { func parseArgs(ss []string) (args []Args) {
ss := strings.Split(rawurl, ",")
if rawurl == "" || len(ss) == 0 {
return nil
}
for _, s := range ss { for _, s := range ss {
if !strings.Contains(s, "://") { if !strings.Contains(s, "://") {
s = "tcp://" + s s = "tcp://" + s