diff --git a/conn.go b/conn.go index e203f92..7eb4fd2 100644 --- a/conn.go +++ b/conn.go @@ -211,50 +211,75 @@ func (r *reqReader) Read(p []byte) (n int, err error) { return } -func connect(connType, addr string) (conn net.Conn, err error) { +func connect(addr string) (conn net.Conn, err error) { if !strings.Contains(addr, ":") { addr += ":80" } - - if len(proxyArgs) > 0 && len(forwardArgs) > 0 { - return connectProxyForward(connType, addr, proxyArgs[0], forwardArgs[0]) + if len(forwardArgs) == 0 { + return net.Dial("tcp", addr) } - - if len(forwardArgs) > 0 { - // TODO: multi-foward - forward := forwardArgs[0] - return connectForward(connType, addr, forward) - } - - if len(proxyArgs) > 0 { - proxy := proxyArgs[0] - return connectForward(connType, addr, proxy) - } - - return net.Dial("tcp", addr) + return forwardChain(addr, forwardArgs[0], forwardArgs[1:]...) } -func connectProxyForward(connType, addr string, proxy, forward Args) (conn net.Conn, err error) { - return nil, errors.New("Not implemented") -} - -func connectForward(connType, addr string, forward Args) (conn net.Conn, err error) { +func forwardChain(addr string, level1 Args, chain ...Args) (conn net.Conn, err error) { if glog.V(LINFO) { - glog.Infoln(forward.Protocol, "forward:", forward.Addr) + glog.Infof("forward: %s/%s %s", level1.Protocol, level1.Transport, level1.Addr) } - conn, err = net.Dial("tcp", forward.Addr) - if err != nil { + if conn, err = net.Dial("tcp", level1.Addr); err != nil { return } + c, err := forward(conn, level1) + if err != nil { + conn.Close() + return nil, err + } + conn = c - switch forward.Transport { - case "ws": // websocket connection - c, err := wsClient(conn, forward.Addr) - if err != nil { + if len(chain) == 0 { + if err := establish(conn, addr, level1); err != nil { conn.Close() return nil, err } + return + } + + cur := level1 + for _, arg := range chain { + if err = establish(conn, arg.Addr, cur); err != nil { + goto exit + } + + c, err = forward(conn, arg) + if err != nil { + goto exit + } conn = c + cur = arg + } + +exit: + if err != nil { + conn.Close() + return nil, err + } + + if err := establish(conn, addr, cur); err != nil { + conn.Close() + return nil, err + } + + return +} + +func forward(conn net.Conn, arg Args) (net.Conn, error) { + var err error + + switch arg.Transport { + case "ws": // websocket connection + conn, err = wsClient(conn, arg.Addr) + if err != nil { + return nil, err + } case "tls": // tls connection conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) case "tcp": @@ -262,83 +287,86 @@ func connectForward(connType, addr string, forward Args) (conn net.Conn, err err default: } - switch forward.Protocol { + switch arg.Protocol { case "ss": // shadowsocks - conn.Close() return nil, errors.New("Not implemented") case "socks", "socks5": selector := &clientSelector{ methods: []uint8{gosocks5.MethodNoAuth, gosocks5.MethodUserPass}, - arg: forward, + arg: arg, } - if forward.EncMeth == "tls" { + if arg.EncMeth == "tls" { selector.methods = []uint8{MethodTLS, MethodTLSAuth} } c := gosocks5.ClientConn(conn, selector) if err := c.Handleshake(); err != nil { - c.Close() return nil, err } conn = c + case "http": + fallthrough + default: + } - host, port, _ := net.SplitHostPort(addr) - p, _ := strconv.ParseUint(port, 10, 16) - r := gosocks5.NewRequest(gosocks5.CmdConnect, &gosocks5.Addr{ + return conn, nil +} + +func establish(conn net.Conn, addr string, arg Args) error { + switch arg.Protocol { + case "ss": // shadowsocks + return nil + case "socks", "socks5": + host, port, err := net.SplitHostPort(addr) + p, _ := strconv.Atoi(port) + // TODO: support bind and udp + req := gosocks5.NewRequest(gosocks5.CmdConnect, &gosocks5.Addr{ Type: gosocks5.AddrDomain, Host: host, Port: uint16(p), }) - rep, err := requestSocks5(conn, r) + rep, err := requestSocks5(conn, req) if err != nil { - conn.Close() - return nil, err + return err } if rep.Rep != gosocks5.Succeeded { - conn.Close() - return nil, errors.New("Service unavailable") + return errors.New("Service unavailable") } case "http": fallthrough default: - if connType == ConnHttpConnect || connType == ConnSocks5 { - req := &http.Request{ - Method: "CONNECT", - URL: &url.URL{Host: addr}, - Host: addr, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - } - req.Header.Set("Proxy-Connection", "keep-alive") - if forward.User != nil { - req.Header.Set("Proxy-Authorization", - "Basic "+base64.StdEncoding.EncodeToString([]byte(forward.User.String()))) - } - if err = req.Write(conn); err != nil { - conn.Close() - return nil, err - } - if glog.V(LDEBUG) { - dump, _ := httputil.DumpRequest(req, false) - glog.Infoln(string(dump)) - } + req := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Host: addr}, + Host: addr, + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + } + req.Header.Set("Proxy-Connection", "keep-alive") + if arg.User != nil { + req.Header.Set("Proxy-Authorization", + "Basic "+base64.StdEncoding.EncodeToString([]byte(arg.User.String()))) + } + if err := req.Write(conn); err != nil { + return err + } + if glog.V(LDEBUG) { + dump, _ := httputil.DumpRequest(req, false) + glog.Infoln(string(dump)) + } - resp, err := http.ReadResponse(bufio.NewReader(conn), req) - if err != nil { - conn.Close() - return nil, err - } - if glog.V(LDEBUG) { - dump, _ := httputil.DumpResponse(resp, false) - glog.Infoln(string(dump)) - } - if resp.StatusCode != http.StatusOK { - conn.Close() - //log.Println(resp.Status) - return nil, errors.New(resp.Status) - } + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return err + } + if glog.V(LDEBUG) { + dump, _ := httputil.DumpResponse(resp, false) + glog.Infoln(string(dump)) + } + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) } } - return + return nil } diff --git a/http.go b/http.go index c4be6a2..e67d433 100644 --- a/http.go +++ b/http.go @@ -19,11 +19,6 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) { } } - connType := ConnHttp - if req.Method == "CONNECT" { - connType = ConnHttpConnect - } - var username, password string if arg.User != nil { username = arg.User.Username() @@ -52,7 +47,7 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) { return } - c, err := connect(connType, req.Host) + c, err := connect(req.Host) if err != nil { if glog.V(LWARNING) { glog.Warningln(err) @@ -67,7 +62,7 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) { } defer c.Close() - if connType == ConnHttpConnect { + if req.Method == "CONNECT" { b := []byte("HTTP/1.1 200 Connection established\r\n" + "Proxy-Agent: gost/" + Version + "\r\n\r\n") if glog.V(LDEBUG) { @@ -80,7 +75,7 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) { return } } else { - if len(proxyArgs) > 0 || len(forwardArgs) > 0 { + if len(forwardArgs) > 0 { err = req.WriteProxy(c) } else { err = req.Write(c) diff --git a/main.go b/main.go index 16b8755..d8210da 100644 --- a/main.go +++ b/main.go @@ -16,30 +16,30 @@ const ( ) var ( - listenUrl, proxyUrl, forwardUrl string - pv bool // print version + listenAddr, forwardAddr strSlice + pv bool // print version listenArgs []Args - proxyArgs []Args forwardArgs []Args ) func init() { - flag.StringVar(&listenUrl, "L", ":http", "local address") - flag.StringVar(&forwardUrl, "S", "", "remote address") - flag.StringVar(&proxyUrl, "P", "", "proxy address") + flag.Var(&listenAddr, "L", "listen address") + flag.Var(&forwardAddr, "F", "forward address, can make a forward chain") flag.BoolVar(&pv, "V", false, "print version") - flag.Parse() - listenArgs = parseArgs(listenUrl) - proxyArgs = parseArgs(proxyUrl) - forwardArgs = parseArgs(forwardUrl) + listenArgs = parseArgs(listenAddr) + forwardArgs = parseArgs(forwardAddr) } func main() { defer glog.Flush() + if flag.NFlag() == 0 { + flag.PrintDefaults() + return + } if pv { printVersion() return diff --git a/socks.go b/socks.go index 73ab490..3a23d44 100644 --- a/socks.go +++ b/socks.go @@ -138,7 +138,7 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con return nil, err } if glog.V(LDEBUG) { - glog.Infoln(req) + glog.Infoln(req.String()) } var username, password string @@ -206,7 +206,7 @@ func handleSocks5Request(req *gosocks5.Request, conn net.Conn, arg Args) { if glog.V(LINFO) { glog.Infoln("socks5 connect:", req.Addr.String()) } - tconn, err := connect(ConnSocks5, req.Addr.String()) + tconn, err := connect(req.Addr.String()) if err != nil { if glog.V(LWARNING) { glog.Warningln("socks5 connect:", err) diff --git a/util.go b/util.go index 909eacb..11566c4 100644 --- a/util.go +++ b/util.go @@ -10,6 +10,16 @@ import ( "strings" ) +type strSlice []string + +func (ss *strSlice) String() string { + return fmt.Sprintf("%s", *ss) +} +func (ss *strSlice) Set(value string) error { + *ss = append(*ss, value) + return nil +} + // socks://admin:123456@localhost:8080/tls type Args struct { Addr string // host:port @@ -32,12 +42,7 @@ func (args Args) String() string { args.EncMeth, args.EncPass) } -func parseArgs(rawurl string) (args []Args) { - ss := strings.Split(rawurl, ",") - if rawurl == "" || len(ss) == 0 { - return nil - } - +func parseArgs(ss []string) (args []Args) { for _, s := range ss { if !strings.Contains(s, "://") { s = "tcp://" + s