diff --git a/conn.go b/conn.go index b807f6d..3f3d64f 100644 --- a/conn.go +++ b/conn.go @@ -111,7 +111,7 @@ func handleConn(conn net.Conn, arg Args) { } return } - handleSocks5Request(req, conn, arg) + handleSocks5Request(req, conn) return } @@ -162,7 +162,7 @@ func handleConn(conn net.Conn, arg Args) { } return } - handleSocks5Request(req, conn, arg) + handleSocks5Request(req, conn) return } @@ -198,41 +198,43 @@ func (r *reqReader) Read(p []byte) (n int, err error) { return } -func connect(addr string) (conn net.Conn, err error) { +func Connect(addr string) (conn net.Conn, err error) { if !strings.Contains(addr, ":") { addr += ":80" } if len(forwardArgs) == 0 { return net.Dial("tcp", addr) } - return forwardChain(addr, forwardArgs[0], forwardArgs[1:]...) -} -func forwardChain(addr string, level1 Args, chain ...Args) (conn net.Conn, err error) { - if glog.V(LINFO) { - glog.Infof("forward: %s/%s %s", level1.Protocol, level1.Transport, level1.Addr) - } - if conn, err = net.Dial("tcp", level1.Addr); err != nil { - return - } - c, err := forward(conn, level1) + var end Args + conn, end, err = forwardChain(forwardArgs...) if err != nil { + if conn != nil { + conn.Close() + } + return nil, err + } + if err := establish(conn, addr, end); err != nil { conn.Close() return nil, err } - conn = c + return conn, nil +} - if len(chain) == 0 { - if err := establish(conn, addr, level1); err != nil { - conn.Close() - return nil, err - } +func forwardChain(chain ...Args) (conn net.Conn, end Args, err error) { + end = chain[0] + if conn, err = net.Dial("tcp", end.Addr); err != nil { return } + c, err := forward(conn, end) + if err != nil { + return + } + conn = c - cur := level1 + chain = chain[1:] for _, arg := range chain { - if err = establish(conn, arg.Addr, cur); err != nil { + if err = establish(conn, arg.Addr, end); err != nil { goto exit } @@ -241,26 +243,18 @@ func forwardChain(addr string, level1 Args, chain ...Args) (conn net.Conn, err e goto exit } conn = c - cur = arg + end = arg } exit: - if err != nil { - conn.Close() - return nil, err - } - - if err := establish(conn, addr, cur); err != nil { - conn.Close() - return nil, err - } - return } func forward(conn net.Conn, arg Args) (net.Conn, error) { var err error - + if glog.V(LINFO) { + glog.Infof("forward: %s/%s %s", arg.Protocol, arg.Transport, arg.Addr) + } switch arg.Transport { case "ws": // websocket connection conn, err = wsClient(conn, arg.Addr) @@ -312,10 +306,19 @@ func establish(conn net.Conn, addr string, arg Args) error { Host: host, Port: uint16(p), }) - rep, err := requestSocks5(conn, req) + if err := req.Write(conn); err != nil { + return err + } + if glog.V(LDEBUG) { + glog.Infoln(req) + } + rep, err := gosocks5.ReadReply(conn) if err != nil { return err } + if glog.V(LDEBUG) { + glog.Infoln(rep) + } if rep.Rep != gosocks5.Succeeded { return errors.New("Service unavailable") } diff --git a/http.go b/http.go index e67d433..47a9ee9 100644 --- a/http.go +++ b/http.go @@ -47,7 +47,7 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) { return } - c, err := connect(req.Host) + c, err := Connect(req.Host) if err != nil { if glog.V(LWARNING) { glog.Warningln(err) diff --git a/socks.go b/socks.go index 1ed9f72..4e573e3 100644 --- a/socks.go +++ b/socks.go @@ -2,6 +2,7 @@ package main import ( "crypto/tls" + "errors" "github.com/ginuerzh/gosocks5" "github.com/golang/glog" "net" @@ -175,24 +176,7 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con return conn, nil } -func requestSocks5(conn net.Conn, req *gosocks5.Request) (*gosocks5.Reply, error) { - if err := req.Write(conn); err != nil { - return nil, err - } - if glog.V(LDEBUG) { - glog.Infoln(req.String()) - } - rep, err := gosocks5.ReadReply(conn) - if err != nil { - return nil, err - } - if glog.V(LDEBUG) { - glog.Infoln(rep.String()) - } - return rep, nil -} - -func handleSocks5Request(req *gosocks5.Request, conn net.Conn, arg Args) { +func handleSocks5Request(req *gosocks5.Request, conn net.Conn) { if glog.V(LDEBUG) { glog.Infoln(req) } @@ -202,7 +186,7 @@ func handleSocks5Request(req *gosocks5.Request, conn net.Conn, arg Args) { if glog.V(LINFO) { glog.Infoln("socks5 connect:", req.Addr.String()) } - tconn, err := connect(req.Addr.String()) + tconn, err := Connect(req.Addr.String()) if err != nil { if glog.V(LWARNING) { glog.Warningln("socks5 connect:", err) @@ -234,78 +218,10 @@ func handleSocks5Request(req *gosocks5.Request, conn net.Conn, arg Args) { Transport(conn, tconn) case gosocks5.CmdBind: - l, err := net.ListenTCP("tcp", nil) - if err != nil { - if glog.V(LWARNING) { - glog.Warningln("socks5 bind listen:", err) - } - rep := gosocks5.NewReply(gosocks5.Failure, nil) - if err := rep.Write(conn); err != nil { - if glog.V(LWARNING) { - glog.Warningln("socks5 bind listen:", err) - } - } else { - if glog.V(LDEBUG) { - glog.Infoln(rep) - } - } - return - } - - addr := ToSocksAddr(l.Addr()) - addr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) - if glog.V(LINFO) { - glog.Infoln("socks5 bind:", addr) - } - rep := gosocks5.NewReply(gosocks5.Succeeded, addr) - if err := rep.Write(conn); err != nil { - if glog.V(LWARNING) { - glog.Warningln("socks5 bind:", err) - } - l.Close() - return - } - if glog.V(LDEBUG) { - glog.Infoln(rep) - } - - tconn, err := l.AcceptTCP() - l.Close() // only accept one peer - if err != nil { - if glog.V(LWARNING) { - glog.Warningln("socks5 bind accept:", err) - } - rep = gosocks5.NewReply(gosocks5.Failure, nil) - if err := rep.Write(conn); err != nil { - if glog.V(LWARNING) { - glog.Warningln("socks5 bind accept:", err) - } - } else { - if glog.V(LDEBUG) { - glog.Infoln(rep) - } - } - return - } - defer tconn.Close() - - addr = ToSocksAddr(tconn.RemoteAddr()) - if glog.V(LINFO) { - glog.Infoln("socks5 bind accept:", addr.String()) - } - rep = gosocks5.NewReply(gosocks5.Succeeded, addr) - if err := rep.Write(conn); err != nil { - if glog.V(LWARNING) { - glog.Warningln("socks5 bind accept:", err) - } - return - } - if glog.V(LDEBUG) { - glog.Infoln(rep) - } - - if err := Transport(conn, tconn); err != nil { - //log.Println(err) + if len(forwardArgs) > 0 { + forwardBind(req, conn) + } else { + serveBind(conn) } case gosocks5.CmdUdp: uconn, err := net.ListenUDP("udp", nil) @@ -347,6 +263,171 @@ func handleSocks5Request(req *gosocks5.Request, conn net.Conn, arg Args) { } } +func serveBind(conn net.Conn) error { + l, err := net.ListenTCP("tcp", nil) + if err != nil { + if glog.V(LWARNING) { + glog.Warningln("socks5 bind listen:", err) + } + rep := gosocks5.NewReply(gosocks5.Failure, nil) + if err := rep.Write(conn); err != nil { + if glog.V(LWARNING) { + glog.Warningln("socks5 bind listen:", err) + } + } else { + if glog.V(LDEBUG) { + glog.Infoln(rep) + } + } + return err + } + + addr := ToSocksAddr(l.Addr()) + // Issue: may not reachable when host has two interfaces + addr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + if glog.V(LINFO) { + glog.Infoln("socks5 bind:", addr) + } + rep := gosocks5.NewReply(gosocks5.Succeeded, addr) + if err := rep.Write(conn); err != nil { + if glog.V(LWARNING) { + glog.Warningln("socks5 bind:", err) + } + l.Close() + return err + } + if glog.V(LDEBUG) { + glog.Infoln(rep) + } + + tconn, err := l.AcceptTCP() + l.Close() // only accept one peer + if err != nil { + if glog.V(LWARNING) { + glog.Warningln("socks5 bind accept:", err) + } + rep = gosocks5.NewReply(gosocks5.Failure, nil) + if err := rep.Write(conn); err != nil { + if glog.V(LWARNING) { + glog.Warningln("socks5 bind accept:", err) + } + } else { + if glog.V(LDEBUG) { + glog.Infoln(rep) + } + } + return err + } + defer tconn.Close() + + addr = ToSocksAddr(tconn.RemoteAddr()) + if glog.V(LINFO) { + glog.Infoln("socks5 bind accept:", addr.String()) + } + rep = gosocks5.NewReply(gosocks5.Succeeded, addr) + if err := rep.Write(conn); err != nil { + if glog.V(LWARNING) { + glog.Warningln("socks5 bind accept:", err) + } + return err + } + if glog.V(LDEBUG) { + glog.Infoln(rep) + } + + return Transport(conn, tconn) +} + +func forwardBind(req *gosocks5.Request, conn net.Conn) error { + fc, _, err := forwardChain(forwardArgs...) + if err != nil { + if fc != nil { + fc.Close() + } + rep := gosocks5.NewReply(gosocks5.Failure, nil) + if err := rep.Write(conn); err != nil { + if glog.V(LWARNING) { + glog.Warningln("socks5 bind:", err) + } + } else { + if glog.V(LDEBUG) { + glog.Infoln(rep) + } + } + return err + } + defer fc.Close() + + if err := req.Write(fc); err != nil { + if glog.V(LWARNING) { + glog.Warningln("socks5 bind:", err) + } + gosocks5.NewReply(gosocks5.Failure, nil).Write(conn) + return err + } + if glog.V(LDEBUG) { + glog.Infoln(req) + } + + // first reply + if err := peekBindReply(conn, fc); err != nil { + return err + } + // second reply + if err := peekBindReply(conn, fc); err != nil { + return err + } + + return Transport(conn, fc) +} + +func peekBindReply(conn, fc net.Conn) error { + rep, err := gosocks5.ReadReply(fc) + if err != nil { + if glog.V(LWARNING) { + glog.Warningln("socks5 bind:", err) + } + rep = gosocks5.NewReply(gosocks5.Failure, nil) + } + if err := rep.Write(conn); err != nil { + if glog.V(LWARNING) { + glog.Warningln("socks5 bind:", err) + } + return err + } + if glog.V(LDEBUG) { + glog.Infoln(rep) + } + if rep.Rep != gosocks5.Succeeded { + return errors.New("Bind failure") + } + + return nil +} + +/* +func forwardUDP() error { + fc, _, err := forwardChain(forwardArgs...) + if err != nil { + if fc != nil { + fc.Close() + } + rep := gosocks5.NewReply(gosocks5.Failure, nil) + if err := rep.Write(conn); err != nil { + if glog.V(LWARNING) { + glog.Warningln("socks5 bind:", err) + } + } else { + if glog.V(LDEBUG) { + glog.Infoln(rep) + } + } + return err + } + defer fc.Close() + +} +*/ func srvTunnelUDP(conn net.Conn, uconn *net.UDPConn) { go func() { b := make([]byte, 16*1024)